Improvement: Update normalization in ACT + remove in-place modification + add test by thomwolf · Pull Request #281 · huggingface/lerobot · GitHub

Improvement: Update normalization in ACT + remove in-place modification + add test #281


Open
wants to merge 6 commits into main
Conversation

thomwolf (Member) commented Jun 18, 2024

What this does

This PR works on the normalization in ACT to:

  • remove the default added epsilon in favor of capping the std (and the min difference between max/min) to an epsilon value
  • make the normalization not in-place
  • add tests for the behavior of the normalization

The reasons for these proposals are as follows:

  • remove the default added epsilon in favor of capping the std:
    • in both torchvision and the original ACT code, the minimum value of std is capped instead of a constant epsilon being added; this brings us in line with what other well-used repos are doing
    • in most cases std is not null, so we stay closer to the theoretical behavior of the normalization
  • make the normalization not in-place:
    • right now the normalization function returns a tensor but also modifies its input in-place. This can be error-prone, as the user typically doesn't expect a function returning a value to also modify its input in-place.
    • for now I decided to make it not in-place. An alternative would be to make it fully in-place and remove the returned value, which I find less intuitive for probably minor speed/memory gains. A minimal sketch of both points is shown right after this list.
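
For illustration, a minimal sketch (not the PR's actual code; function and variable names are made up) of the added-epsilon behavior versus the proposed clamped-std, not-in-place behavior:

import torch

def normalize_added_eps(x, mean, std, eps=1e-8):
    # previous behavior: an epsilon is always added to std, so the result deviates
    # slightly from (x - mean) / std even when std is far from zero
    return (x - mean) / (std + eps)

def normalize_clamped_std(x, mean, std, eps=1e-2):
    # proposed behavior: std is only capped from below, so whenever std >= eps
    # the result is the exact theoretical normalization (x - mean) / std
    return (x - mean) / std.clamp_min(eps)

def normalize_batch(batch, stats, eps=1e-2):
    # not-in-place: build a new dict instead of mutating the caller's batch
    out = dict(batch)
    for key, s in stats.items():
        out[key] = (batch[key] - s["mean"]) / s["std"].clamp_min(eps)
    return out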

How it was tested

Added extensive tests for the behavior in tests/test_policies.py::test_normalize

How to checkout & try? (for the reviewer)

Run the tests:

pytest ./tests/test_policies.py::test_normalize


@@ -78,10 +79,14 @@ def create_stats_buffers(
        # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
        if mode == "mean_std":
            buffer["mean"].data = stats[key]["mean"].clone()
-           buffer["std"].data = stats[key]["std"].clone()
+           buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
Contributor

Hmm, so I think that if the buffer says "std" we should trust that it is that, and not a modified version of it. Same with "max" below. I can think of an example where this would have been important given an experience of my own: when porting weights from the original work, I take their stats and set them in the state dict here. I would really hope that their stats are true to their names, and I would assume ours are.

If you agree, my ask would be to move the clamping logic to the normalization function.
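
To make the suggestion concrete, a hedged sketch (hypothetical values, not the reviewed code) of keeping the buffer untouched and applying the epsilon only at normalization time:

import torch

std_epsilon = 1e-2
mean = torch.zeros(2)
std = torch.tensor([0.5, 0.0])  # second feature has zero std in the dataset stats

# buffers keep the true stats...
buffer_std = std.clone()

# ...and the clamping happens only inside the normalization step
x = torch.ones(2)
x_norm = (x - mean) / buffer_std.clamp_min(std_epsilon)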

Member Author

Fine with me. I did the clamping when storing instead of when computing for a small speed-up, but it's probably an epsilon speedup anyway (pun intended).

Member Author

Updated following this suggestion. Feel free to check and merge if ok.

@@ -120,18 +121,23 @@ def __init__(
            not provided, as expected for finetuning or evaluation, the default buffers should to be
            overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
            dataset is not needed to get the stats, since they are already in the policy state_dict.
+           std_epsilon (float, optional): A small minimal value for the standard deviation to avoid division by zero.
Contributor

nit: Would it be more suitable to generalize the variable name (to epsilon maybe) and description? It pertains not just to the standard deviation, but also the min/max. In fact, I think it's very commonplace to use min/max normalization, so it would be nice to make sure std_epsilon is not dismissed up front by the reader.
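
For context, the same floor would apply to the min/max range in min_max mode; a rough sketch under that assumption (names are illustrative):

import torch

def normalize_min_max(x, vmin, vmax, epsilon=1e-2):
    # map x to [-1, 1]; the range (max - min) is floored at epsilon so that
    # a feature which is constant in the dataset does not cause a division by zero
    rng = (vmax - vmin).clamp_min(epsilon)
    return (x - vmin) / rng * 2.0 - 1.0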

std_epsilon = 1e-2
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)

# check that the stats are correctly set including the min capping
Contributor

Is this needed any more now that we have moved the clamping to the normalize function?

"max": torch.ones(1) * 3,
},
"action_test_std_cap": {
"mean": torch.ones(2) * 2,
Contributor

If this is testing both the capped and uncapped versions (if I understand correctly), I'm unclear on why we have the above two: "action_test_std", "action_test_min_max". Aren't we doubling up without adding value?

normalized_output["action_test_std"],
(input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
/ dataset_stats["action_test_std"]["std"],
rtol=0.1,
Contributor

Do you really want this rtol=0.1? Afaik the rtol and atol are added, meaning that you are allowing over 10% deviation.
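
For reference, torch.isclose / torch.allclose check |actual - expected| <= atol + rtol * |expected|, so rtol=0.1 on its own already tolerates roughly 10% relative deviation:

import torch

expected = torch.tensor([1.0])
actual = torch.tensor([1.09])  # 9% relative error

print(torch.allclose(actual, expected, rtol=0.1))  # True: 0.09 <= 1e-8 + 0.1 * 1.0
print(torch.allclose(actual, expected))            # False with the default rtol=1e-5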


if insert_temporal_dim:
assert torch.isclose(
normalized_output["action_test_std_cap"][0, 0, 0],
Contributor

Why are we doing this [0, 0, 0] indexing instead of [..., 0]? Shouldn't we be checking everything? I also think using the ellipsis notation will allow you to avoid conditionally branching on insert_temporal_dim
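
Illustrating the point with made-up tensors: ellipsis indexing covers both the with- and without-temporal-dimension shapes and checks every batch element, so the branch on insert_temporal_dim could be dropped:

import torch

expected = torch.tensor(0.5)
out_no_time = torch.full((4, 3), 0.5)       # (batch, action_dim)
out_with_time = torch.full((4, 2, 3), 0.5)  # (batch, time, action_dim)

# [..., 0] selects the first action dimension regardless of rank
assert torch.allclose(out_no_time[..., 0], expected)
assert torch.allclose(out_with_time[..., 0], expected)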

@imstevenpmwork imstevenpmwork added enhancement Suggestions for new features or improvements policies Items related to robot policies labels Apr 17, 2025