8000 Remove offline training, refactor `train.py` and logging/checkpointing by aliberts · Pull Request #670 · huggingface/lerobot · GitHub

Remove offline training, refactor train.py and logging/checkpointing #670


Merged — 29 commits into main on Feb 11, 2025

Conversation

@aliberts (Collaborator) commented Jan 31, 2025

What this does

  • ⚠️ Removes the offline training part from the train.py script: online training will be handled by the training scripts from Port HIL SERL #644
  • As a consequence, .offline and .online are removed from TrainPipelineConfig. To set the number of training steps, simply use --steps:
python lerobot/scripts/train.py \
- --offline.steps=200000
+ --steps=200000
  • Adds wandb_utils.py and turns Logger into WandBLogger, removing responsibilities from this class so that it only manages wandb-related logic.
  • Replaces torch.save/load with safetensors.save_file/load_file for training_state serialization. We shouldn't use torch.load() for this, and in fact it breaks in torch 2.6 due to weights_only=True being the default.
/checkpoints/005000
  ├── pretrained_model
- └── training_state.pth
+ └── training_state
+     ├── optimizer_param_groups.json
+     ├── optimizer_state.safetensors
+     ├── rng_state.safetensors
+     ├── scheduler_state.json
+     └── training_step.json
  • Adds train_utils.py to handle training checkpoints logic (including training state).
  • Cleans up functions related to rng and groups them together in random_utils.py.
  • Saves the checkpoint before eval during training rather than after (safer in case eval crashes).
  • Fixes logging where displayed values would only be the last one measured instead of the average over the steps since the previous logging step.
  • Changes the policies' main forward() output format for clarity. It now returns a tuple[Tensor, dict | None] instead of just a dict, the first element being the loss:
- output_dict = policy.forward(batch)
- loss = output_dict["loss"]
+ loss, output_dict = policy.forward(batch)
loss.backward()

How it was tested

Adds the following tests:

  • tests/test_schedulers.py
  • tests/test_optimizers.py
  • tests/test_train_utils.py
  • tests/test_random_utils.py
  • tests/test_io_utils.py

How to checkout & try? (for the reviewer)

Examples:

pytest -v \
    tests/test_schedulers.py \
    tests/test_optimizers.py \
    tests/test_train_utils.py \
    tests/test_random_utils.py \
    tests/test_io_utils.py

@aliberts aliberts changed the title Update safetensors `training_state` Update training_state serialization to safetensors Jan 31, 2025
@aliberts aliberts changed the title Update training_state serialization to safetensors Refactor Logger Feb 4, 2025
@aliberts aliberts changed the title Refactor Logger Refactor train.py and logging/checkpointing Feb 8, 2025
@aliberts aliberts changed the title Refactor train.py and logging/checkpointing Remove offline training, refactor train.py and logging/checkpointing Feb 8, 2025
@aliberts aliberts added the refactor Code cleanup or restructuring without changing behavior label Feb 8, 2025
@aliberts aliberts requested a review from Cadene February 8, 2025 21:48
@aliberts aliberts marked this pull request as ready for review February 8, 2025 21:48
@Cadene (Collaborator) left a comment:
Beautiful

Could you remove all appearances of ema?
They were added by default

aliberts and others added 2 commits February 11, 2025 10:08
Co-authored-by: Remi <remi.cadene@huggingface.co>
@aliberts aliberts merged commit 90e099b into main Feb 11, 2025
7 checks passed
@aliberts aliberts deleted the user/aliberts/2025_01_31_safetensors_training_state branch February 11, 2025 09:36
aliberts added a commit that referenced this pull request Feb 12, 2025
JIy3AHKO pushed a commit to vertix/lerobot that referenced this pull request Feb 27, 2025
JIy3AHKO pushed a commit to vertix/lerobot that referenced this pull request Feb 27, 2025