Description
For PyTorch training, it is very common to checkpoint state dicts. However, after wrapping your model with `ray.train.torch.prepare_model(...)` (which wraps it in `DistributedDataParallel`), the keys in the state dict will contain a `module.` prefix, which will prevent the checkpoint from being used in `TorchPredictor`.
We have documented that users should use `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")` to remove the prefix from the state dict keys before saving it.
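For reference, a minimal sketch of that documented workaround (the toy `nn.Linear` model and the manually constructed prefix stand in for a real DDP-wrapped model):

```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

model = nn.Linear(4, 2)

# Simulate the state dict produced by a DistributedDataParallel-wrapped model,
# whose keys carry a "module." prefix (e.g. "module.weight", "module.bias").
state_dict = {f"module.{k}": v for k, v in model.state_dict().items()}

# Documented workaround: strip the prefix in place before saving, so the
# checkpoint can later be loaded into the bare model / TorchPredictor.
consume_prefix_in_state_dict_if_present(state_dict, "module.")
assert set(state_dict) == set(model.state_dict())  # keys are back to "weight", "bias"
```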
However, with Ray AIR, since we know exactly which key in the checkpoint dict the model state dict will be saved in, we can do this prefix extraction automatically.
We can also move the logic for overriding serialization when using AMP into the checkpoint encoding/decoding logic, rather than modifying the model directly: https://github.com/ray-project/ray/pull/25335/files#diff-ae1969f579e6412ce1fb189760f540a377dc4bb9853e0dc6f6dcdb6011dbc3adR124
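A rough sketch of what the automatic handling could look like on the encode side (the function name and the `MODEL_KEY` constant below are illustrative assumptions, not Ray's actual internals):

```python
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

MODEL_KEY = "model"  # assumed checkpoint-dict key for the model state dict

def encode_checkpoint(checkpoint_dict: dict) -> dict:
    """Hypothetical encode hook applied before a checkpoint dict is saved."""
    state_dict = checkpoint_dict.get(MODEL_KEY)
    if isinstance(state_dict, dict):
        # Strip the "module." prefix added by DistributedDataParallel so the
        # saved checkpoint loads cleanly into TorchPredictor, with no manual
        # step required from the user.
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
    return checkpoint_dict
```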
Use case
No response