Port HIL SERL by AdilZouitine · Pull Request #644 · huggingface/lerobot · GitHub

Port HIL SERL #644


Open
wants to merge 203 commits into main

Conversation

AdilZouitine
Member
@AdilZouitine AdilZouitine commented Jan 17, 2025

Implementing HIL-SERL

This PR implements the HIL-SERL approach as described in the paper. HIL-SERL combines human-in-the-loop intervention with reinforcement learning to enable efficient learning from human demonstrations.

The implementation includes:

  • Reward classifier training with pretrained architecture: Added a lightweight classification head built on top of a frozen, pretrained image encoder from HuggingFace. This classifier processes robot camera images to predict rewards, supporting binary and multi-class classification. The implementation includes metrics tracking with WandB.

  • Environment configurations for HILSerlRobotEnv: Added configuration classes for the HIL environment including VideoRecordConfig, WrapperConfig, EEActionSpaceConfig, and EnvWrapperConfig. These handle parameters for video recording, action space constraints, end-effector control, and environment-specific settings.

  • SAC-based reinforcement learning algorithm: Implemented the Soft Actor-Critic (SAC) algorithm with configurable network architectures and optimization settings. The implementation includes actor and critic networks, policy configurations, temperature auto-tuning, and target network updates via exponential moving averages (a generic sketch of the EMA update follows this list).

  • Actor-learner architecture with efficient communication protocols: Added actor server script that establishes connection with the learner, creating queues for parameters, transitions, and interactions. Implemented LearnerService class with gRPC for efficient streaming of parameters and transitions between components.

  • Replay buffer for storing transitions: Added ReplayBuffer class for storing and sampling transitions in reinforcement learning. Includes functions for random cropping and shifting of images, memory optimization, and batch sampling capabilities.

  • End-effector control utilities: Implemented input controllers (KeyboardController and GamepadController) that generate motion deltas for robot control. Added utilities for finding joint and end-effector bounds, and for selecting regions of interest in images.

  • Human intervention support: Added RobotEnv class that wraps robot interfaces to provide a consistent API for policy evaluation with integrated human intervention. Created PyTorch-compatible action space wrappers for seamless integration with PyTorch tensors.
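
For reference, a minimal generic sketch of the EMA ("soft") target-network update mentioned in the SAC bullet above (plain PyTorch, not this PR's exact code; the tau value is illustrative):

import torch
from torch import nn

def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    # Exponential-moving-average update of the target network toward the online network.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)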

Engineering Design Choices for HIL-SERL Implementation

Environment Abstraction and Entry Points

Currently, environment building for both simulation and real robot training is embedded within gym_manipulator.py. This creates a clean interface for robot interaction. While this approach works well for our immediate needs, future discussions may consider consolidating all environment creation through a single entry point in lerobot.common.envs.factory::make_env for consistency across the codebase and better maintainability.

Gym Manipulator

The gym_manipulator.py script contains the main RobotEnv class, which defines a gym-based interface for the Manipulator robot class. It also contains a set of wrappers that can be used on top of the RobotEnv class to provide additional functionality necessary for training. For example, the ImageCropResizeWrapper class is used to crop the image to a region of interest and resize it to a fixed size, EEActionWrapper is used to convert the end-effector action space to joint position commands, and so on.
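
For illustration, a hypothetical sketch in the spirit of ImageCropResizeWrapper (not the actual implementation; the observation key, region of interest, and output size below are assumptions):

import gymnasium as gym
from torchvision.transforms.functional import crop, resize

class CropResizeWrapperSketch(gym.ObservationWrapper):
    def __init__(self, env, image_key="observation.images.front", roi=(0, 0, 128, 128), size=(128, 128)):
        super().__init__(env)
        self.image_key = image_key
        self.top, self.left, self.height, self.width = roi
        self.size = size

    def observation(self, obs):
        # Crop the camera image to the region of interest, then resize to a fixed shape.
        img = obs[self.image_key]  # torch.Tensor of shape (C, H, W)
        img = crop(img, self.top, self.left, self.height, self.width)
        obs[self.image_key] = resize(img, list(self.size))
        return obs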

The script contains three additional functions:

  • make_robot_env: This function builds a gymnasium environment with the RobotEnv base and the requested wrappers.
  • record_dataset: This function allows you to record the offline dataset of demonstrations by recording the robot's actions in the environment. This dataset can be used to train the reward classifier or as the offline dataset for the RL.
  • replay_dataset: This function allows you to replay a dataset which can be useful for debugging the action space on the robot.

You can record/replay a dataset by setting the mode- and dataset-related arguments of HILSerlRobotEnvConfig in lerobot/common/envs/configs.py (more details in the guide).

Q: Why not use control_robot.py for collecting and replaying data?

A: Since we mostly use end-effector control and different teleoperation devices (gamepad, keyboard, or leader), it is more convenient to collect and replay data using the gym env interface in gym_manipulator.py.
After PR #777 we might be able to seamlessly change the teleoperation device and action space. Then we can revert to using control_robot.py for collecting and replaying data.

Optional Dataset in TrainPipelineConfig

The TrainPipelineConfig class has been modified to make the dataset parameter optional. This reflects the reality that while imitation learning requires demonstration data, pure reinforcement learning algorithms can function without an offline dataset. This makes the training pipeline more versatile and better aligned with the various learning paradigms supported by HIL-SERL.
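
As a rough illustration of the pattern (simplified, hypothetical field names, not the actual TrainPipelineConfig definition):

from dataclasses import dataclass

@dataclass
class TrainConfigSketch:
    # The offline dataset is optional: imitation learning needs it, pure RL does not.
    dataset_repo_id: str | None = None
    online_steps: int = 100_000

cfg = TrainConfigSketch()  # no dataset -> pure online RL
if cfg.dataset_repo_id is not None:
    ...  # load offline demonstrations, e.g. to seed the replay buffer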

Consolidation of Implementation Files

For actor_server.py, learner_server.py, and gym_manipulator.py, we deliberately chose to create larger, more comprehensive files rather than splitting functionality across multiple smaller files. While this approach goes against some code organization principles, it significantly reduces the cognitive load required to understand these critical components. Each file represents a complete, coherent system with clear boundaries of responsibility.

Organization of Server-Side Components

We've placed multiple related files in the lerobot/script/server folder as a first step toward better organization. This groups related functionality for the actor-learner architecture. We're waiting for reviewer feedback before proceeding with further organization to ensure our approach aligns with the project's overall structure.

MultiAdamConfig for Optimizer Management

We introduced the MultiAdamConfig class to simplify handling multiple optimizers. Reinforcement learning methods like SAC typically rely on different networks (actor, critic, temperature) that are optimized at different frequencies and with different hyperparameters (a rough usage sketch follows the list below). This class:

  • Provides a clean interface for creating and managing multiple optimizers
  • Reduces error-prone boilerplate code when updating different networks
  • Enables more sophisticated optimization strategies with minimal code changes
  • Simplifies checkpoint saving and loading for training resumption
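
A minimal sketch of the idea in plain PyTorch (illustrative names and hyperparameters, not the actual MultiAdamConfig API):

import torch
from torch import nn

# Stand-ins for SAC's actor, critic, and temperature parameter.
actor = nn.Linear(8, 2)
critic = nn.Linear(10, 1)
log_alpha = nn.Parameter(torch.zeros(1))

# One optimizer per group, each with its own hyperparameters.
optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
    "temperature": torch.optim.Adam([log_alpha], lr=3e-4),
}

# Checkpointing becomes a simple map over the dict - the kind of boilerplate
# a MultiAdam-style config can hide behind a single save/load call.
checkpoint = {name: opt.state_dict() for name, opt in optimizers.items()}
for name, opt in optimizers.items():
    opt.load_state_dict(checkpoint[name])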

Gradient Flow Through Normalization

We removed the torch.no_grad() decorator from normalization functions to allow gradients to flow through these operations. This is essential for end-to-end training where normalized inputs need to contribute to the gradient computation. Without this change, backpropagation would be blocked at normalization boundaries, preventing the model from learning to account for input normalization during training.
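
A minimal sketch of the difference, with made-up statistics (generic PyTorch, not the actual lerobot normalization code):

import torch

mean, std = torch.tensor([0.5]), torch.tensor([0.25])
x = torch.tensor([1.0], requires_grad=True)

# Under torch.no_grad() the normalized output is detached from the autograd graph,
# so backpropagation cannot reach x through this operation.
with torch.no_grad():
    y_blocked = (x - mean) / std
print(y_blocked.requires_grad)  # False

# Without it, gradients flow through the normalization as expected.
y = (x - mean) / std
y.sum().backward()
print(x.grad)  # tensor([4.]) i.e. 1 / std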


How it was tested

  • We trained an agent on ManiSkill using this actor-learner architecture. The main task is PushCube-v1. The point of the ManiSkill experiments is to validate that the implementation of Soft Actor-Critic is correct; this baseline does not involve any human interventions. We validate that the implementation works with both sparse and dense rewards, with and without an offline dataset.


Reward on ManiSkill, training without offline data or human intervention

  • Another baseline is the MuJoCo-based simulation of the Franka Panda arm in the repo HuggingFace/gym-hil. We have implemented the ability to teleoperate the simulated robot with an external keyboard or gamepad device.


Plots of the intervention rate and reward vs. time during one training run. We are able to train a policy to 100% success within 10-30 minutes.

Other videos using this implementation:

IMG_5714.mov

How to check out & try it (for the reviewer) 😃

Follow this guide 😄

@michel-aractingi michel-aractingi force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from b1be31a to 2211209 Compare February 3, 2025 15:11
@AdilZouitine AdilZouitine changed the title [WIP] Fix SAC and port HIL SERL [WIP] Port HIL SERL Mar 18, 2025
@AdilZouitine AdilZouitine force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from 9a68f20 to ae12807 Compare March 24, 2025 11:05
@AdilZouitine AdilZouitine changed the base branch from user/michel-aractingi/2024-11-27-port-hil-serl to main March 24, 2025 11:07
@AdilZouitine AdilZouitine force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch 2 times, most recently from ad51d89 to 808cf63 Compare March 28, 2025 17:20
@imstevenpmwork imstevenpmwork added enhancement Suggestions for new features or improvements policies Items related to robot policies labels Apr 17, 2025
@AdilZouitine AdilZouitine changed the title [WIP] Port HIL SERL Port HIL SERL Apr 18, 2025
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
@helper2424
Contributor

Some tests are missing, but they are on the way - #1074.

The networking part will come after that.

@Ke-Wang1017

The cube insertion task video cannot be watched.

dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.config.target_entropy = -np.prod(dim) / 2

def _init_temperature(self):
Contributor

Test comment. Maybe test with a softplus parametrization.
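
For reference, a rough sketch of what a softplus parametrization of the SAC temperature could look like (illustrative only, not the code in this PR; target_entropy and log_probs are dummy values):

import torch
import torch.nn.functional as F

# Unconstrained parameter; softplus keeps the effective temperature positive.
raw_temperature = torch.nn.Parameter(torch.tensor(0.0))

def temperature() -> torch.Tensor:
    return F.softplus(raw_temperature)

# Standard SAC-style temperature loss on dummy actor log-probs.
target_entropy = -2.0
log_probs = torch.tensor([-1.5, -2.5])
temperature_loss = -(temperature() * (log_probs + target_entropy).detach()).mean()
temperature_loss.backward()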

pyproject.toml Outdated
@@ -84,6 +85,7 @@ dora = [
]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "gym-hil>=0.1.2", "protobuf>=5.29.3"]
Contributor

Suggested change
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "gym-hil>=0.1.2", "protobuf>=5.29.3"]
hilserl = ["transformers>=4.48.0", "gym-hil>=0.1.2", "protobuf>=5.29.3"]

We don't need torchmetrics anymore.


type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"


Also, it is a bit strange that task is PandaPickCubeKeyboard-v0 while use_gamepad is True.

)


class SpatialLearnedEmbeddings(nn.Module):
Contributor

It looks similar to the one in sac_modeling; I think that one of them should be dropped.
We can reuse the one from sac_modeling, or maybe better, extract such classes to some shared module, like policies/shared/embeddings/, and import it in both places. Also, it would be cool to cover that part with tests in such a case.

Collaborator

imo it's ok to have duplicate modules. Following the transformers spirit, each modeling file should be self-contained. You can adapt them so they look similar (possibly copy the code) but do not import ;)

Contributor

You are right, but it seems that for SAC all the networks are similar - Critic/Actor and Reward just have different final layers. So it would make sense to have common code for all of them.

In any case, this looks like refactoring and could be done separately.

return output


class Classifier(PreTrainedPolicy):
Contributor

Also, generally, as I have written, it seems that the Reward classifier has a very similar structure to the Policy/Critic. So we can refactor it, reuse the code, and probably drop some repeated code.

Collaborator

Same comment as above on transformers style. Duplicating code is not a problem if done right.

return observations


class DefaultImageEncoder(nn.Module):
Contributor

Maybe we can rename it to something like CNN or SimpleCNN and extract it to some module that shares nn parts.

class DefaultImageEncoder(nn.Module):
    def __init__(self, config: SACConfig):
        super().__init__()
        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
Contributor

Instead of key.startswith("observation.image"), I would create a separate method called something like is_image_feature and move the implementation there. The implementation could change; also, I think that checking the feature type this way could lead to some errors in the future.
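
A hypothetical sketch of the suggested helper (the name is the reviewer's suggestion, not existing lerobot code):

def is_image_feature(key: str) -> bool:
    # True for camera image observations such as "observation.images.front".
    return key.startswith("observation.image")

# Usage inside the encoder, instead of the inline check:
# image_key = next(key for key in config.input_features if is_image_feature(key))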

return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


class Identity(nn.Module):
Contributor

Oh, ok, it's not used; I think we can delete it.

Collaborator

Should we have a style pre-commit for unused function/class?
cc @imstevenpmwork @aliberts

Contributor

100% - let's add it.


service_config_json = json.dumps(service_config)

channel = grpc.insecure_channel(
Contributor

This config-handling part should be extracted to a separate module to make it reusable.

@helper2424
Contributor

I think that https://github.com/michel-aractingi/lerobot-hilserl-guide should be mentioned somewhere in the README.md or in some docs. Once the current PR is merged, it will be impossible to find the guide and completely unclear how to use HIL-SERL.

Also, we removed all docs from the repo, but just imagine: the PR is merged, I open the repo, and I don't understand what is going on with HIL-SERL. I think that we should provide some basic documentation; it shouldn't be super extensive, maybe just an overview and a link to https://github.com/michel-aractingi/lerobot-hilserl-guide.

Also, it would be great to mention HIL-SERL in the main README.

@helper2424
Contributor

Except for the comments I have already provided, the changes look good.

I would also like to add several points:

  1. Security for the networking part. This probably shouldn't be part of the current PR, but it definitely should be implemented. At the moment any actor can connect to any learner. That's not great, because in the future we may have learners that serve different networks, so actors running one type of nn could connect to learners serving another nn. Another issue is that anybody who knows the learner's address can steal the neural network launched there. So, it would be great to implement two things:
    a. A mechanism that checks that the learner and actor have the same neural network architecture. One way to do it is to generate a hash for the NN: for example, we can take the nn state dict, zero out all the weights, and calculate the hash of that dictionary. That looks like a good unique signature - if two nns have the same architecture, the signatures should be equal. We can send such a signature whenever the actor connects to the learner and reject the connection on the learner side if the signatures don't match (a rough sketch of this idea follows the list below).
    b. Authorization for gRPC. We need some mechanism for checking that the actor is authorized to connect to the learner at all. Different approaches exist: we can use TLS (but it will generate more traffic and load), tokens, or some custom mechanism based on gRPC metadata. Here are the docs https://grpc.io/docs/guides/auth/ and examples https://github.com/grpc/grpc/tree/master/examples/python/auth. Another possible way - https://chatgpt.com/share/681e26b0-b6bc-8002-b2fd-5803aff4c806 with metadata-based creds.
  2. System tests. It would be great to have some system tests (https://en.wikipedia.org/wiki/System_testing or end-to-end). They shouldn't be part of the current PR, but they are nice to have, as they could save a lot of time in the future. The idea is to have a CI that checks end-to-end that the scripts (for example learner and actor) run without errors, and that can also run training for some nns and check that they converge. We wasted a lot of time checking basic things after changing the nn architecture; it would be great to have automation for that. One way: create a docker compose setup that launches the actor/learner and runs them, and if either of them reports an error, the CI stops and flags it. If everything works and after some time the collected reward is good enough, we mark the CI as green. Such an approach allows checking everything together - the networking part, nn convergence, and also the performance. To avoid long waits we can use small tasks with small nns. It also allows testing other scripts, like train.py. It would speed up the development process by 10x.
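
A rough sketch of the architecture-signature idea from point 1a (generic PyTorch; this variant hashes parameter names, shapes, and dtypes directly, which is equivalent to hashing a zeroed state dict; not part of this PR):

import hashlib
from torch import nn

def architecture_signature(model: nn.Module) -> str:
    # Depends only on parameter names, shapes, and dtypes - i.e. the architecture,
    # never on the weight values.
    description = [
        (name, tuple(tensor.shape), str(tensor.dtype))
        for name, tensor in sorted(model.state_dict().items())
    ]
    return hashlib.sha256(repr(description).encode()).hexdigest()

# The actor could send this string in gRPC metadata when connecting; the learner
# would reject the connection if it differs from its own model's signature.
print(architecture_signature(nn.Linear(4, 2)))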

@Ke-Wang1017

I think that https://github.com/michel-aractingi/lerobot-hilserl-guide should be mentioned somewhere in the README.md or in some docs. Once the current PR is merged, it will be impossible to find the guide and completely unclear how to use HIL-SERL.

Also, we removed all docs from the repo, but just imagine: the PR is merged, I open the repo, and I don't understand what is going on with HIL-SERL. I think that we should provide some basic documentation; it shouldn't be super extensive, maybe just an overview and a link to https://github.com/michel-aractingi/lerobot-hilserl-guide.

Also, it would be great to mention HIL-SERL in the main README.

Agree that the guide should be included in the README. Also, it would be great if gym_hil instructions were included in the guide, given that we plan to include the HIL environment.

really_done = success_steps_collected >= cfg.number_of_steps_after_success

frame["next.done"] = np.array([really_done], dtype=bool)
frame["task"] = cfg.task
Contributor

Potential misuse of task

We use the task as part of the gym_handle to create the env. This is not the same as the string describing the task being performed in the episode.

Collaborator
@Cadene Cadene left a comment

First review! Thanks :)

@@ -46,7 +46,7 @@ repos:
rev: v3.19.1
hooks:
- id: pyupgrade

exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
Collaborator

Nit: what is this file useful for?

Contributor

These files are generated by the gRPC library. It doesn't make sense to check them with linters, as that code shouldn't be touched or edited by anything other than the gRPC code generators.

@@ -47,6 +47,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)

if img.dim() == 3:
Collaborator

Suggested change
if img.dim() == 3:
if img.ndim == 3:

dim() is deprecated, use ndim in torch to fit the numpy API.
Look for other examples like this in the code.

Comment on lines +50 to +51
if img.dim() == 3:
    img = img.unsqueeze(0)
Collaborator

Add your comment in the code?

Comment on lines +82 to +110
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
    def __init__(self, env: gym.vector.VectorEnv):
        super().__init__(env)

    def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
        return preprocess_observation(observations)

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ):
        """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
        observations, infos = self.env.reset(seed=seed, options=options)
        return self._observations(observations), infos

    def step(self, actions):
        """Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        return (
            self._observations(observations),
            rewards,
            terminations,
            truncations,
            infos,
        )


Collaborator

Nit: creating an env wrapper for 1 line of code is a bit overkill imo

@@ -44,7 +45,7 @@ def default_choice_name(cls) -> str | None:
return "adam"

@abc.abstractmethod
def build(self) -> torch.optim.Optimizer:
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
Collaborator

Add this possibly in the docstring?

Comment on lines +51 to +53
def pose_difference_se3(pose1, pose2):
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
Collaborator

What does SE(3) mean? Special Euclidean? What is each dimension?
Let's assume people don't know much about classic robotics :)


LOG_PREFIX = "[LEARNER]"

logging.basicConfig(level=logging.INFO)
Collaborator

Same comment: init_logging() in __main__

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Collaborator

Same comment, lack of docstring, example commands, etc.

self.shutdown_event.wait(self.seconds_between_pushes)

logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
Collaborator

Why call it hilserl_pb2?

Why call the directory server since there is a service inside of it?

Contributor

The naming is not the best, but hilserl_pb2 follows the naming from the proto file. We can change it to something better.

Contributor

The naming for the directory is not the best either, yeah. We can do something better here.

def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    # Deserialize a state dict that was streamed as raw bytes between the actor and learner.
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return torch.load(buffer, weights_only=False)  # nosec B614: Safe usage of torch.load
Collaborator

Maybe add more of a comment than just # nosec B614: Safe usage of torch.load?

Labels: enhancement (Suggestions for new features or improvements), policies (Items related to robot policies)

9 participants