Extend reward classifier for multiple camera views by michel-aractingi · Pull Request #626 · huggingface/lerobot · GitHub

Extend reward classifier for multiple camera views #626


Merged
2 changes: 1 addition & 1 deletion lerobot/common/logger.py
@@ -25,13 +25,13 @@
from pathlib import Path

import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

@@ -13,6 +13,7 @@ class ClassifierConfig:
model_name: str = "microsoft/resnet-50"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2

def save_pretrained(self, save_dir):
"""Save config to json file."""
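For reference, a minimal sketch of constructing the extended config (the import path is assumed from the repo layout, and ClassifierConfig is assumed to be a dataclass accepting keyword overrides; only the field names and model name come from this PR):

# Hypothetical usage sketch; the import path and dataclass behavior are assumptions.
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig

config = ClassifierConfig(
    model_name="facebook/convnext-base-224",
    model_type="cnn",
    num_cameras=2,  # new field: one encoder pass per camera view
)
config.save_pretrained("outputs/classifier")  # writes the config to a json file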
16 changes: 11 additions & 5 deletions lerobot/common/policies/hilserl/classifier/modeling_classifier.py
@@ -97,7 +97,7 @@ def _build_classifier_head(self) -> None:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")

self.classifier_head = nn.Sequential(
nn.Linear(input_dim, self.config.hidden_dim),
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
@@ -130,16 +130,22 @@ def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]

def forward(self, x: torch.Tensor) -> ClassifierOutput:
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_output = self._get_encoder_output(x)
logits = self.classifier_head(encoder_output)
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
logits = self.classifier_head(encoder_outputs)

if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)

return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

def predict_reward(self, x):
if self.config.num_classes == 2:
return (self.forward(x).probabilities > 0.5).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)
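To make the new input contract concrete, here is a minimal sketch of calling the multi-camera forward pass (the batch size, resolution, and classifier construction are assumptions for illustration; the list-of-tensors input and the predict_reward call follow the diff above):

import torch

# One (batch, channels, height, width) tensor per camera view; each view is
# encoded separately and the pooled features are hstack-ed before the head.
xs = [torch.randn(8, 3, 224, 224) for _ in range(2)]  # assumed: 2 cameras, batch of 8

# classifier = Classifier(config)          # assumed constructor; config.num_cameras must equal len(xs)
# output = classifier(xs)                  # ClassifierOutput(logits, probabilities, hidden_states)
# rewards = classifier.predict_reward(xs)  # 0./1. per sample when num_classes == 2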
9 changes: 9 additions & 0 deletions lerobot/common/robot_devices/control_utils.py
@@ -11,6 +11,7 @@
from functools import cache

import cv2
import numpy as np
import torch
import tqdm
from deepdiff import DeepDiff
@@ -332,6 +333,14 @@ def reset_environment(robot, events, reset_time_s):
break


def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)


def stop_recording(robot, listener, display_cameras):
robot.disconnect()

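Design note: the new helper linearly interpolates 30 intermediate poses and sends one roughly every 15 ms, so the follower glides back to its start pose over about half a second instead of jumping there in one step. A standalone sketch of the interpolation (the joint values are made up for illustration; robot.send_action and busy_wait come from the surrounding code and are left commented out):

import numpy as np
import torch

current_position = np.zeros(6)                                 # assumed: 6 joint readings from the follower
target_position = np.array([0.0, 30.0, 60.0, 0.0, 0.0, 0.0])   # made-up target pose
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # shape (30, 6)
# for pose in trajectory:
#     robot.send_action(pose)  # stream each interpolated waypoint to the follower arm
#     busy_wait(0.015)         # ~15 ms per step -> ~0.45 s total move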
9 changes: 5 additions & 4 deletions lerobot/configs/policy/hilserl_classifier.yaml
@@ -4,7 +4,7 @@ defaults:
- _self_

seed: 13
dataset_repo_id: "dataset_repo_id"
dataset_repo_id: aractingi/pick_place_lego_cube_1
train_split_proportion: 0.8

# Required by logger
@@ -24,17 +24,18 @@ training:
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_key: "observation.images.phone"
image_keys: ["observation.images.top", "observation.images.wrist"]
label_key: "next.reward"

eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table

policy:
name: "hilserl/classifier"
name: "hilserl/classifier/pick_place_lego_cube_1"
model_name: "facebook/convnext-base-224"
model_type: "cnn"
num_cameras: 2 # Must equal len(training.image_keys)

wandb:
enable: false
@@ -44,4 +45,4 @@

device: "mps"
resume: false
output_dir: "output"
output_dir: "outputs/classifier"
13 changes: 13 additions & 0 deletions lerobot/scripts/control_robot.py
@@ -109,6 +109,7 @@
log_control_info,
record_episode,
reset_environment,
reset_follower_position,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
@@ -205,6 +206,7 @@ def record(
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
reset_follower: bool = False,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
@@ -265,6 +267,9 @@
robot.connect()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)

if reset_follower:
initial_position = robot.follower_arms["main"].read("Present_Position")

# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
@@ -307,6 +312,8 @@
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
if reset_follower:
reset_follower_position(robot, initial_position)
reset_environment(robot, events, reset_time_s)

if events["rerecord_episode"]:
@@ -527,6 +534,12 @@ def replay(
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_record.add_argument(
"--reset-follower",
type=int,
default=0,
help="Resets the follower to the initial position during while reseting the evironment, this is to avoid having the follower start at an awkward position in the next episode",
)
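A minimal Python sketch of how the new flag behaves inside record() (the robot object and helper are stubbed here so the control flow runs on its own; the names mirror the diff, everything else is an assumption for illustration):

import numpy as np

class _Arm:  # stub mimicking a follower arm
    def read(self, register):
        return np.zeros(6)  # stands in for "Present_Position"

class _Robot:  # stub mimicking the connected Robot
    follower_arms = {"main": _Arm()}

def reset_follower_position(robot, target_position):  # stub for the real helper
    print("gliding follower back to", target_position)

robot = _Robot()
reset_follower = bool(1)  # argparse passes --reset-follower as an int (0 or 1)

if reset_follower:
    initial_position = robot.follower_arms["main"].read("Present_Position")

# ... an episode is recorded here ...

if reset_follower:
    reset_follower_position(robot, initial_position)  # runs right before reset_environment()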

parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(