Improvement: Update normalization in ACT + remove in-place modification + add test #281
Conversation
lerobot/common/policies/normalize.py
```diff
@@ -78,10 +79,14 @@ def create_stats_buffers(
     # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
     if mode == "mean_std":
         buffer["mean"].data = stats[key]["mean"].clone()
-        buffer["std"].data = stats[key]["std"].clone()
+        buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
```
Hmm, so I think that if the buffer says "std" we should trust that it is that, and not a modified version of it. Same with "max" below. I can think of an example where this would have been important given an experience of my own: when porting weights from the original work, I take their stats and set them in the state dict here. I would really hope that their stats are true to their names, and I would assume ours are.
If you agree, my ask would be to move the clamping logic to the normalization function.
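The suggestion above can be sketched in miniature. This is illustrative plain-Python pseudocode, not the actual LeRobot implementation (which operates on torch tensors inside the normalization module): the stored stats stay untouched, and the clamp happens only at use time.

```python
def normalize_mean_std(x, mean, std, std_epsilon=1e-2):
    # Clamp only when dividing, so the stored "std" stays true to its name.
    return (x - mean) / max(std, std_epsilon)

# A degenerate std of 0.0 would otherwise divide by zero.
stats = {"mean": 2.0, "std": 0.0}
y = normalize_mean_std(2.5, stats["mean"], stats["std"])
print(y)  # 0.5 / 0.01 = 50.0
assert stats["std"] == 0.0  # the stats themselves are never modified in place
```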
Fine with me. I did the clamping when storing rather than when computing for a small speed-up, but it's probably an epsilon speedup anyway (pun intended).
Updated following this suggestion. Feel free to check and merge if OK.
```diff
@@ -120,18 +121,23 @@ def __init__(
         not provided, as expected for finetuning or evaluation, the default buffers should to be
         overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
         dataset is not needed to get the stats, since they are already in the policy state_dict.
+    std_epsilon (float, optional): A small minimal value for the standard deviation to avoid division by
```
nit: Would it be more suitable to generalize the variable name (to `epsilon` maybe) and the description? It pertains not just to the standard deviation, but also to the min/max. In fact, I think it's very commonplace to use min/max normalization, so it would be nice to make sure `std_epsilon` is not dismissed up front by the reader.
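To illustrate the min/max point: the same guard protects the range denominator, which is why a generic name like `epsilon` would fit both modes. A plain-Python sketch with made-up values (not the actual LeRobot API):

```python
def normalize_min_max(x, min_val, max_val, epsilon=1e-2):
    # Map x into [-1, 1]; epsilon guards against max_val == min_val,
    # exactly as the std guard protects mean/std normalization.
    span = max(max_val - min_val, epsilon)
    return 2.0 * (x - min_val) / span - 1.0

print(normalize_min_max(3.0, 1.0, 5.0))  # 0.0 (mid-range)
print(normalize_min_max(1.0, 1.0, 1.0))  # -1.0, no ZeroDivisionError
```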
```python
std_epsilon = 1e-2
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)

# check that the stats are correctly set including the min capping
```
Is this needed any more now that we have moved the clamping to the normalize function?
"max": torch.ones(1) * 3, | ||
}, | ||
"action_test_std_cap": { | ||
"mean": torch.ones(2) * 2, |
If this is testing both the capped and uncapped versions (if I understand correctly), I'm unclear on why we have the above two: "action_test_std", "action_test_min_max". Aren't we doubling up without adding value?
```python
    normalized_output["action_test_std"],
    (input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
    / dataset_stats["action_test_std"]["std"],
    rtol=0.1,
```
Do you really want this `rtol=0.1`? AFAIK the rtol and atol are added, meaning that you are allowing over 10% deviation.
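For context, `torch.isclose` (like `numpy.isclose`, and unlike `math.isclose`) uses the asymmetric test |a − b| ≤ atol + rtol · |b|, so the two tolerances are indeed summed. A plain-Python sketch of that check:

```python
def torch_style_isclose(a, b, rtol=1e-5, atol=1e-8):
    # numpy/torch semantics: absolute and relative tolerances add up.
    return abs(a - b) <= atol + rtol * abs(b)

print(torch_style_isclose(1.09, 1.0, rtol=0.1))  # True: 9% deviation passes
print(torch_style_isclose(1.11, 1.0, rtol=0.1))  # False: 11% deviation fails
```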
```python
    if insert_temporal_dim:
        assert torch.isclose(
            normalized_output["action_test_std_cap"][0, 0, 0],
```
Why are we doing this `[0, 0, 0]` indexing instead of `[..., 0]`? Shouldn't we be checking everything? I also think using the ellipsis notation will allow you to avoid conditionally branching on `insert_temporal_dim`.
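The ellipsis behavior the comment refers to can be demonstrated with numpy, whose basic indexing semantics match torch here (the arrays below are made up for illustration):

```python
import numpy as np

batch = np.arange(6).reshape(2, 3)   # shape (batch, dim)
with_time = batch[:, None, :]        # shape (batch, time, dim) after inserting a temporal dim

# `[..., 0]` picks index 0 along the last axis for any rank,
# so one assertion covers both shapes without branching on insert_temporal_dim.
print(batch[..., 0].tolist())      # [0, 3]
print(with_time[..., 0].tolist())  # [[0], [3]]
```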
What this does
This PR works on the normalization in ACT to:
The reasons for these proposals are as follows:
How it was tested
Added extensive tests for the behavior in
tests/test_policies.py::test_normalize
How to check out & try (for the reviewer)
Run the tests: