Fix ACT temporal ensembling #319
online_avg = ensembler.update(actions)
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
# What we want to do is take diagonal slices across it starting from the left.
FYI: I think this gets a little hairy for a "simple" test, but I really wanted to make sure it's properly checked. I hope the explanation is enough to make the reviewer feel comfortable that this test is doing what it's supposed to. Perhaps it's enough to know that we do the same thing with two approaches and get the same answer.
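To make the diagonal slicing in that comment concrete, here is a minimal NumPy sketch of the offline calculation. It is not the test code from this PR: the function name `offline_ensemble`, the array layout `preds[s, k]` (the prediction made at step `s` for step `s + k`), and the convention that the oldest prediction gets weight w₀ = 1 are all assumptions made for illustration.

```python
import numpy as np


def offline_ensemble(preds, coeff):
    """Weighted average over every prediction that targets the same time step.

    preds: array of shape (episode_length, chunk_size, action_dim), where
           preds[s, k] is the action predicted at step s for step s + k
           (hypothetical layout, not taken from the PR).
    coeff: exponential weighting coefficient m in w_i = exp(-m * i).
    """
    episode_length, chunk_size = preds.shape[:2]
    weights = np.exp(-coeff * np.arange(chunk_size))
    out = np.empty((episode_length,) + preds.shape[2:])
    for t in range(episode_length):
        # All predictions for step t lie on the diagonal s + k = t of the grid.
        s = np.arange(max(0, t - chunk_size + 1), t + 1)  # oldest chunk first
        k = t - s
        actions = preds[s, k]        # shape (n, action_dim)
        w = weights[: len(s)]        # oldest prediction gets w_0
        out[t] = np.tensordot(w, actions, axes=1) / w.sum()
    return out
```

Early in the episode the diagonal is shorter than `chunk_size`, which is why the weights are truncated to `len(s)` before normalising.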
@Alternmill for review please.
Thanks!
It seems that this change isn't backward compatible, no?
If this is the case, I am fine with that, but we should warn people on Discord that they can't load a checkpoint that has temporal_ensemble_momentum in the config. And ideally provide a minimal procedure to update their checkpoint config.
Also, I am wondering why this backward-compatibility-breaking change is not caught by our unit tests when we load a model checkpoint.
@Cadene it doesn't break unit tests because of this: lerobot/lerobot/common/policies/factory.py, lines 25 to 44 in 5ffcb48.
temporal_ensemble_momentum will be ignored and temporal_ensemble_coeff will get a warning. I think we should probably consider making the former case raise an exception, but there may be a good reason I didn't do that in the first place. Yes, I'll mention it on Discord.
@alexander-soare It could be better to raise an exception, so that people don't override an argument from the command line that is actually ignored without having seen the warning. When we go out of
@Cadene I'm not sure I understand. I suggested raising an exception if someone provides an unknown param. Are you saying something else? Let's take this discussion off this PR though :)
What this does
Fixes an issue with the weighting scheme for the temporal ensembling:
The fix is to directly use the exponential weighting scheme from Algorithm 2 of https://arxiv.org/abs/2304.13705
Here I implement it as an online update so we don't have the ugliness of storing a cache of actions.
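For readers who want to see what an online update of this kind can look like, here is a rough NumPy sketch. It is not the code from this PR: the class name `OnlineTemporalEnsembler`, the unbatched `(chunk_size, action_dim)` input shape, and the convention that weights are w_i = exp(-coeff * i) with w₀ on the oldest prediction are assumptions for illustration. The idea is to keep only a running weighted average and a count per pending time step, instead of a cache of all past action chunks.

```python
import numpy as np


class OnlineTemporalEnsembler:
    """Running exponentially weighted average of overlapping action chunks (sketch)."""

    def __init__(self, coeff, chunk_size):
        self.weights = np.exp(-coeff * np.arange(chunk_size))
        self.weights_cumsum = np.cumsum(self.weights)
        self.avg = None    # running averages for the pending time steps
        self.count = None  # how many predictions each average already contains

    def update(self, chunk):
        """Fold in a new chunk of shape (chunk_size, action_dim); return the action to run now."""
        chunk = np.asarray(chunk, dtype=float)
        if self.avg is None:
            self.avg = chunk.copy()
            self.count = np.ones(len(chunk), dtype=int)
        else:
            n = self.count
            # Undo the old normalisation, add the new weighted term, renormalise.
            old_sum = self.avg * self.weights_cumsum[n - 1][:, None]
            new_term = chunk[:-1] * self.weights[n][:, None]
            self.avg = (old_sum + new_term) / self.weights_cumsum[n][:, None]
            # The last action of the new chunk targets a step with no average yet.
            self.avg = np.concatenate([self.avg, chunk[-1:]])
            self.count = np.concatenate([n + 1, [1]])
        # Pop the average for the current step; the rest stay pending.
        action, self.avg, self.count = self.avg[0], self.avg[1:], self.count[1:]
        return action
```

In this sketch a pending step never holds more than `chunk_size - 1` predictions before the next update, so indexing `weights` by the count stays in bounds without clamping.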
How it was tested
I added a test for CI.
I tried eval'ing https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human/tree/main for 500 episodes with temporal ensembling.
Edit: There was a bug in ACT, so this table and the subsequent commentary were edited on 17 July with the bug fixed.
Here's an episode for m=0.01:
eval_episode_2.mp4
How to checkout & try? (for the reviewer)
Try it with
python lerobot/scripts/eval.py -p lerobot/act_aloha_sim_transfer_cube_human eval.n_episodes=10 eval.batch_size=10 +policy.temporal_ensemble_coeff=0.01 policy.n_action_steps=1