[WIP] HIL SERL port grasp critic by AdilZouitine · Pull Request #937 · huggingface/lerobot · GitHub
[go: up one dir, main page]
Skip to content

[WIP] HIL SERL port grasp critic #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4a1c26d
Add grasp critic
s1lent4gnt Mar 31, 2025
007fee9
Add complementary info in the replay buffer
s1lent4gnt Mar 31, 2025
7452f9b
Add gripper penalty wrapper
s1lent4gnt Mar 31, 2025
2c1e5fa
Add get_gripper_action method to GamepadController
s1lent4gnt Mar 31, 2025
c774bbe
Add grasp critic to the training loop
s1lent4gnt Mar 31, 2025
7983baf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2025
fe2ff51
Added Gripper quantization wrapper and grasp penalty
michel-aractingi Apr 1, 2025
6a215f4
Refactor SAC configuration and policy to support discrete actions
AdilZouitine Apr 1, 2025
306c735
Refactor SAC policy and training loop to enhance discrete action support
AdilZouitine Apr 1, 2025
451a7b0
Add mock gripper support and enhance SAC policy action handling
AdilZouitine Apr 1, 2025
699d374
Refactor SACPolicy for improved readability and action dimension hand…
AdilZouitine Apr 1, 2025
0ed7ff1
Enhance SACPolicy and learner server for improved grasp critic integr…
AdilZouitine Apr 2, 2025
51f1625
Enhance SACPolicy to support shared encoder and optimize action selec…
AdilZouitine Apr 3, 2025
38a8dbd
Enhance SAC configuration and replay buffer with asynchronous prefetc…
AdilZouitine Apr 3, 2025
e86fe66
fix indentation issue
AdilZouitine Apr 3, 2025
037ecae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2025
7741526
fix caching
AdilZouitine Apr 4, 2025
4621f4e
Handle gripper penalty
AdilZouitine Apr 7, 2025
6c10390
Refactor complementary_info handling in ReplayBuffer
AdilZouitine Apr 7, 2025
632b2b4
fix sign issue
AdilZouitine Apr 7, 2025
a7be613
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2025
a813562
Add rounding for safety
AdilZouitine Apr 8, 2025
d948b95
fix caching and dataset stats is optional
AdilZouitine Apr 9, 2025
e7edf2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
5428ab9
General fixes in code, removed delta action, fixed grasp penalty, add…
michel-aractingi Apr 9, 2025
ba09f44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
854bfb4
fix encoder training
AdilZouitine Apr 11, 2025
320a1a9
Refactor modeling_sac and parameter handling for clarity and reusabil…
AdilZouitine Apr 14, 2025
35ecaae
stick to hil serl nn architecture
AdilZouitine Apr 15, 2025
a850d43
match target entropy hil serl
AdilZouitine Apr 15, 2025
157f719
change the tanh distribution to match hil serl
AdilZouitine Apr 15, 2025
0c9a3ec
Handle caching
AdilZouitine Apr 15, 2025
d4f341e
fix caching
AdilZouitine Apr 15, 2025
a6f612e
Update log_std_min type to float in PolicyConfig for consistency
AdilZouitine Apr 15, 2025
7191bbb
Fix init temp
AdilZouitine Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lerobot/common/envs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ class VideoRecordConfig:
class WrapperConfig:
"""Configuration for environment wrappers."""

delta_action: float | None = None
joint_masking_action_space: list[bool] | None = None


Expand All @@ -191,7 +190,6 @@ class EnvWrapperConfig:
"""Configuration for environment wrappers."""

display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
Expand All @@ -203,6 +201,10 @@ class EnvWrapperConfig:
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False


@EnvConfig.register_subclass(name="gym_manipulator")
Expand Down Expand Up @@ -254,6 +256,7 @@ class ManiskillEnvConfig(EnvConfig):
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
Expand Down
14 changes: 10 additions & 4 deletions lerobot/common/policies/sac/configuration_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class ActorNetworkConfig:
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: int = -5
log_std_max: int = 2
log_std_min: float = 1e-5
log_std_max: float = 10.0
init_final: float = 0.05


Expand Down Expand Up @@ -85,12 +85,15 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, eg for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
Expand Down Expand Up @@ -118,7 +121,7 @@ class SACConfig(PreTrainedConfig):
}
)

dataset_stats: dict[str, dict[str, list[float]]] = field(
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
Expand All @@ -144,12 +147,15 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
image_embedding_pooling_dim: int = 8

# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
policy_update_freq: int = 1

Expand All @@ -173,7 +179,7 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)

grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

Expand Down
Loading
Loading
0