[AIR/Train] Torch: Automatically unpack model when checkpointing state dicts #24975
Open
@amogkam

Description

For PyTorch training, it is very common to checkpoint state dicts. However, after wrapping your model with ray.train.torch.prepare_model(...) (which wraps it in DistributedDataParallel), the keys in the state dict will contain a "module." prefix, which prevents the checkpoint from being used in TorchPredictor.
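
As a minimal illustration of the mismatch (using a toy nn.Linear model rather than a real training setup):

```python
import torch.nn as nn

model = nn.Linear(4, 2)
print(list(model.state_dict().keys()))
# ['weight', 'bias']

# After ray.train.torch.prepare_model(model), the model is wrapped in
# DistributedDataParallel, so the same state dict would instead contain:
# ['module.weight', 'module.bias']
# TorchPredictor expects the unprefixed keys, so loading the checkpoint fails.
```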

We have documented that users should use torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.") to remove the prefix from the state dict keys before saving it.
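
Concretely, the documented workaround looks like the following (the stand-in state dict and the torch.save call are illustrative; the same applies when reporting the state dict through the Train checkpoint API):

```python
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Stand-in for a DDP-wrapped model's state dict (keys carry the "module." prefix).
state_dict = {"module.weight": torch.zeros(2, 4), "module.bias": torch.zeros(2)}

consume_prefix_in_state_dict_if_present(state_dict, "module.")  # mutates in place
print(list(state_dict.keys()))  # ['weight', 'bias']

# The cleaned state dict now loads into an unwrapped model:
nn.Linear(4, 2).load_state_dict(state_dict)
torch.save(state_dict, "checkpoint.pt")
```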

However, with Ray AIR, since we know exactly which key in the checkpoint dict the model state dict is saved under, we can strip this prefix automatically.
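
A sketch of what this could look like in the checkpoint-encoding path; the hook name encode_checkpoint and the "model" key are assumptions for illustration, not Ray's actual internals:

```python
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

MODEL_KEY = "model"  # assumed: the key under which Train stores the model state dict

def encode_checkpoint(checkpoint: dict) -> dict:
    """Hypothetical hook run on the checkpoint dict before it is serialized."""
    state_dict = checkpoint.get(MODEL_KEY)
    if isinstance(state_dict, dict):
        # Strip the DistributedDataParallel "module." prefix in place so the
        # saved state dict loads into an unwrapped model (e.g. in TorchPredictor).
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
    return checkpoint
```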

We can also move the logic for overriding serialization when using AMP into this encoding/decoding logic, rather than modifying the model directly: https://github.com/ray-project/ray/pull/25335/files#diff-ae1969f579e6412ce1fb189760f540a377dc4bb9853e0dc6f6dcdb6011dbc3adR124

Use case

No response

Metadata

Assignees

No one assigned

    Labels

    P2: Important issue, but not time-critical
    enhancement: Request for new feature and/or capability
    pending-cleanup: This issue is pending cleanup. It will be removed in 2 weeks after being assigned.
    ray-team-created: Ray Team created
    train: Ray Train Related Issue
