Use torch.distributed.checkpoint.state_dict.set_model_state_dict in load_checkpoint_in_model #3432
Conversation
…oad_checkpoint_in_model
load_checkpoint_in_model now supports loading into FSDP2-wrapped models when using device_map=None for large models in a distributed setting, by leveraging broadcast_from_rank0. The reduced file-system reads result in much faster loading (60 seconds vs. 90 seconds when loading a 70B model on a single node of 8 GPUs).
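For context, a minimal sketch of the intended usage (the model construction, checkpoint path, and helper names below are illustrative assumptions, not code from this PR):

# Sketch: load a checkpoint into an FSDP2-wrapped model with device_map=None,
# broadcasting full tensors from rank 0 so only one rank reads from disk.
# Assumes torch.distributed.init_process_group() has already been called.
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from torch.distributed._composable.fsdp import fully_shard

with init_empty_weights():
    model = build_model()  # hypothetical helper returning an nn.Module

fully_shard(model)  # FSDP2 wrapping

load_checkpoint_in_model(
    model,
    checkpoint="path/to/checkpoint",  # placeholder path
    device_map=None,
    full_state_dict=True,             # new option from this PR
    broadcast_from_rank0=True,        # new option from this PR; rank 0 reads, others receive
)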
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the PR! This is a nice piece of functionality! Left a few comments. Can you have a second look @muellerzr? Could you also fix the CI? There are a lot of failing tests currently due to this PR. Also, maybe we can move the tests to the test_fsdp.py test file?
src/accelerate/utils/modeling.py
Outdated
import torch.nn as nn
from torch import distributed as dist
from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
Is this available starting from torch 2.0?
Yeah, torch.distributed.checkpoint is new in 2.0.0.
I think StateDictOptions and set_model_state_dict were first added in 2.2.0 though.
https://github.com/pytorch/pytorch/blob/v2.2.0/torch/distributed/checkpoint/state_dict.py#L81
https://github.com/pytorch/pytorch/blob/v2.2.0/torch/distributed/checkpoint/state_dict.py#L761
I'll move this down into the function and guard with is_torch_version(">=", "2.2.0")
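Something along these lines, for illustration (the helper name is made up; the PR's actual placement may differ):

# Sketch: import lazily and guard on the torch version so module import doesn't
# break on torch < 2.2.0, where these symbols don't exist yet.
from accelerate.utils import is_torch_version

def _set_model_state_dict_available() -> bool:
    if is_torch_version(">=", "2.2.0"):
        from torch.distributed.checkpoint.state_dict import (  # noqa: F401
            StateDictOptions,
            set_model_state_dict,
        )
        return True
    return False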
src/accelerate/utils/modeling.py
Outdated
full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
    returned state_dict will be gathered. No ShardedTensor and DTensor will be in the returned state_dict.
broadcast_from_rank0 (`bool`, *optional*, defaults to `True`): when the option is `True`, rank0 should receive
    a full state_dict and will broadcast the tensors in the state_dict one by one to other ranks. Other ranks
    will receive the tensors and shard according to the local shards in the model. `full_state_dict` must be
    set to `True` when using this option.
Specify that these are for FSDP/TP only.
set_model_state_dict isn't only for FSDP and TP. It handles non-distributed loading also. You can even use this with DDP. I'll add a test demonstrating this.
I will update the note to mention that a ProcessGroup must be initialized if broadcast_from_rank0=True, and I will change its default to False.
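For the process-group note, the guard could look roughly like this (the function name and message are assumptions, not the PR's code):

# Sketch: broadcast_from_rank0 needs an initialized process group, so fail early
# with a clear error if it isn't available.
import torch.distributed as dist

def _validate_broadcast_from_rank0(broadcast_from_rank0: bool) -> None:
    if broadcast_from_rank0 and not (dist.is_available() and dist.is_initialized()):
        raise ValueError(
            "broadcast_from_rank0=True requires an initialized process group; "
            "call torch.distributed.init_process_group() first."
        )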
src/accelerate/utils/modeling.py
Outdated
loaded_checkpoint = (
    load_state_dict(checkpoint_file, device_map=device_map)
    if (not broadcast_from_rank0 or dist.is_initialized() and dist.get_rank() == 0)
    else {}
)
set_model_state_dict(
    model,
    loaded_checkpoint,
    options=StateDictOptions(
        full_state_dict=full_state_dict,
        strict=strict,
        broadcast_from_rank0=broadcast_from_rank0,
    ),
)
Let's only do the following if distributed is initialized! I'm not sure set_model_state_dict will work if it is not initialized, especially since we set broadcast_from_rank0 to True by default. If a device_map is passed when loading in a distributed env, we can raise a warning/error, for instance.
Also, we can maybe initialize a PartialState here instead of calling dist.is_initialized() and dist.get_rank().
set_model_state_dict does work when distributed is not initialized, but broadcast_from_rank0=True doesn't work when distributed isn't initialized. To your point, I think the safest thing to do may be to default broadcast_from_rank0 to False instead.
Also, I'll add a test demonstrating that set_model_state_dict does work in a non-distributed context (when broadcast_from_rank0=False).
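A minimal sketch of such a non-distributed test (the toy modules are illustrative; the PR's actual test differs). It assumes a torch version new enough to have the broadcast_from_rank0 field (the PR's tests require >= 2.4.0):

# Sketch: set_model_state_dict in a non-distributed context, no process group initialized.
import torch
from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

source = nn.Linear(4, 4)
target = nn.Linear(4, 4)

set_model_state_dict(
    target,
    source.state_dict(),
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False),
)

# The target module now holds the source's weights.
for p_src, p_tgt in zip(source.parameters(), target.parameters()):
    torch.testing.assert_close(p_src, p_tgt)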
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    worker = re.sub(r"^gw", "", worker, 0, re.M)
    return int(worker)


def get_torch_dist_unique_port():
    """
    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.

    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
    port at once.
    """
    port = 29500
    uniq_delta = pytest_xdist_worker_id()
    return port + uniq_delta
Is there a way to not have this, @muellerzr?
BTW this is just taken from transformers here:
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import itertools
import unittest
from typing import Any, Callable

import torch
from huggingface_hub import hf_hub_download
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp.wrap import _recursive_wrap, transformer_auto_wrap_policy
from transformers import AutoConfig, AutoModel
Can we also add a tensor parallel test, since you talk about it in the PR?
…>=', '2.2.0') This should fix issues with slow import and also fixes versioning issues huggingface#3432 (comment) huggingface#3432 (comment)
…patch(device_map=None) using set_model_state_dict huggingface#3432 (comment) huggingface#3432 (comment)
…n version of torch to test as 2.4.0
I think everything should be passing now if you want to give it another go! @SunMarc
Thanks for iterating! Just a few nits! Can you have a second look @muellerzr, for the tests in particular? Also, the CI is red because of your changes, can you have a quick look at these tests?
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_pytorch - AssertionError
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_safetensors - AssertionError
class TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase):
    @require_transformers
    @require_multi_gpu
You can put the decorators above the class name!
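i.e. roughly this (a sketch; the decorator import path is assumed to match the PR's test helpers):

# Sketch: apply the requirement decorators once, at class level.
import unittest

# require_transformers / require_multi_gpu are the PR's test helpers (import path assumed)
from accelerate.test_utils import require_multi_gpu, require_transformers


@require_transformers
@require_multi_gpu
class TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase):
    def test_load_checkpoint_and_dispatch_fsdp2(self):
        ...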
        else:
            torch.testing.assert_close(tensor, tp_tensor, msg=tp_name)

    @require_torch_min_version(version="2.4.0")
Let's put the require_torch_min_version decorator above TestLoadCheckpointAndDispatchWithBroadcast.
if is_transformers_available():
    from transformers import AutoConfig, AutoModel, PreTrainedModel
    from transformers.models.gpt2.modeling_gpt2 import GPT2Block

    def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
        """Manage the creation and destruction of the distributed process group for the wrapped function."""

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            dist.init_process_group(world_size=torch.cuda.device_count())
            try:
                return func(*args, **kwargs)
            finally:
                dist.destroy_process_group()

        return wrapped

    @require_torch_min_version(version="2.4.0")
    @manage_process_group
    def load_checkpoint_and_dispatch_fsdp2():
        torch.cuda.set_device(device := torch.device(dist.get_rank()))
You don't need to put all these functions inside the if is_transformers_available(): condition.
set_model_state_dict will fail if the model state_dict is not on at most one device
So weird, I remember solving these failures earlier but I guess I didn't push it...
Thanks a lot!
src/accelerate/utils/modeling.py
Outdated
if device_map is None:
    model.load_state_dict(loaded_checkpoint, strict=strict)
if is_torch_version(">=", "2.2.0") and len(model_devices) <= 1:
Why is len(model_devices) <= 1 needed? Usually the model is on the meta device, but the non-persistent buffers are usually not on the meta device.
I added a comment explaining this. This is just an explicit restriction that set_model_state_dict has. Starting in v2.7.0 (not yet released), they actually do support one physical device + the meta device. I've updated the condition to account for this.
https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563
https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587
Check it out: f5555fb!
Also, non-persistent buffers are not included in the state_dict by definition, so they wouldn't affect this check.
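To make the version-dependent restriction concrete, here is a sketch of the shape of that check (assuming model_devices holds the set of devices the model's parameters and persistent buffers live on; not the PR's exact code):

# Sketch: set_model_state_dict requires the model to live on at most one device;
# from torch 2.7.0 onward, a meta device is additionally tolerated next to one
# physical device.
import torch
from accelerate.utils import is_torch_version

def _compatible_with_set_model_state_dict(model_devices: set[torch.device]) -> bool:
    if is_torch_version(">=", "2.7.0"):
        physical_devices = {d for d in model_devices if d.type != "meta"}
        return len(physical_devices) <= 1
    return len(model_devices) <= 1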
@muellerzr @SunMarc any progress here? How is this looking?
The tensor parallel example shouldn't work anymore, unfortunately, due to changes in transformers.
with device, init_empty_weights():
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    tp_model = AutoModel.from_config(config)
    tp_model.tie_weights()
assert isinstance(tp_model, nn.Module)

mesh = init_device_mesh(device.type, (dist.get_world_size(),))
assert tp_model.supports_tp_plan
assert callable(tp_model.tensor_parallel)
tp_model.tensor_parallel(mesh)
The API for tensor_parallel changed a bit in transformers. Not sure we need this example anymore.
Sure, I'll remove this test.
I will merge it soon, but after today's release, as I prefer not to include it in this one yet.
Sorry for the wait, merging.
What does this PR do?
load_checkpoint_in_model now supports loading into FSDP2-wrapped or Tensor Parallelized models when using device_map=None for large models in a distributed setting, by leveraging broadcast_from_rank0. The reduced file-system reads result in much faster loading (60 seconds vs. 90 seconds for loading a 70B model on a single node of 8 GPUs).
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@SunMarc @muellerzr @BenjaminBossan