Use `torch.distributed.checkpoint.state_dict.set_model_state_dict` in `load_checkpoint_in_model` by ringohoffman · Pull Request #3432 · huggingface/accelerate

Conversation

@ringohoffman ringohoffman commented Mar 10, 2025

What does this PR do?

`load_checkpoint_in_model` now supports loading into FSDP2-wrapped or tensor-parallelized models when using `device_map=None`.

For large models in a distributed setting, leveraging `broadcast_from_rank0` reduces file system reads and results in much faster loading (60 seconds vs. 90 seconds when loading a 70B model on a single node of 8 GPUs).
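
A rough usage sketch of the new path (the checkpoint path is a placeholder and `model` is assumed to be an already FSDP2-wrapped module; `full_state_dict` and `broadcast_from_rank0` are the options added by this PR, passed explicitly here):

from accelerate.utils import load_checkpoint_in_model

# Every rank calls this; with broadcast_from_rank0=True only rank 0 reads the
# checkpoint from disk, and the tensors are broadcast to the other ranks.
load_checkpoint_in_model(
    model,
    "/path/to/checkpoint",  # placeholder
    device_map=None,
    full_state_dict=True,
    broadcast_from_rank0=True,
)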

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc @muellerzr @BenjaminBossan


Member
@SunMarc SunMarc left a comment


Thanks for the PR! This is a nice functionality! Left a few comments. Can you have a second look, @muellerzr? Could you also fix the CI? There are a lot of failing tests currently due to this PR. Also, maybe we can move the tests to the test_fsdp.py test file?

import torch.nn as nn
from torch import distributed as dist
from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
Member


Is this available starting from torch 2.0?

Author


Yeah, torch.distributed.checkpoint is new in 2.0.0.

Author

I'll move this down into the function and guard with is_torch_version(">=", "2.2.0")
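
A minimal sketch of what that guard could look like, assuming the import is moved inside the function body as described (`is_torch_version` is accelerate's existing version helper):

from accelerate.utils import is_torch_version

# Import lazily so accelerate still imports on older torch versions that lack these APIs.
if is_torch_version(">=", "2.2.0"):
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict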


Comment on lines 1823 to 1828
full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
    returned state_dict will be gathered. No ShardedTensor and DTensor will be in the returned state_dict.
broadcast_from_rank0 (`bool`, *optional*, defaults to `True`): when the option is `True`, rank0 should receive
    a full state_dict and will broadcast the tensors in the state_dict one by one to other ranks. Other ranks
    will receive the tensors and shard according to the local shards in the model. `full_state_dict` must be
    set to `True` when using this option.
Member

Specify that these are for FSDP/TP only.

Author

set_model_state_dict isn't only for FSDP and TP. It also handles non-distributed loading. You can even use it with DDP. I'll add a test demonstrating this.

I will update the note to mention that a ProcessGroup must be initialized if broadcast_from_rank0=True, and I will change its default to False.
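
A minimal single-process sketch of that point, assuming a toy model and a hand-built state dict (no process group initialized, `broadcast_from_rank0` left at `False`):

import torch
from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

# Plain, non-distributed loading through the same API path.
model = nn.Linear(4, 4)
full_state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
set_model_state_dict(
    model,
    full_state_dict,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=False),
)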


Comment on lines 1911 to 1924
loaded_checkpoint = (
    load_state_dict(checkpoint_file, device_map=device_map)
    if (not broadcast_from_rank0 or dist.is_initialized() and dist.get_rank() == 0)
    else {}
)
set_model_state_dict(
    model,
    loaded_checkpoint,
    options=StateDictOptions(
        full_state_dict=full_state_dict,
        strict=strict,
        broadcast_from_rank0=broadcast_from_rank0,
    ),
)
Member

Let's only do the following if distributed is initialized! I'm not sure set_model_state_dict will work if it is not initialized, especially since we set broadcast_from_rank0 to True by default. If a device_map is passed when loading in a distributed env, we could raise a warning/error, for instance.

Also, we can maybe initialize a PartialState here instead of calling dist.is_initialized() and dist.get_rank().
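
For context, a sketch of that alternative, assuming `PartialState` from accelerate is used in place of the raw `dist` calls:

from accelerate import PartialState

state = PartialState()
# state.is_main_process plays the same role as dist.get_rank() == 0,
# and it also works when no process group has been launched.
load_from_disk = state.is_main_process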

Author
@ringohoffman ringohoffman Mar 11, 2025


set_model_state_dict does work when distributed is not initialized, but broadcast_from_rank0=True doesn't work when distributed isn't initialized. To your point, I think the safest thing to do may be to default broadcast_from_rank0 to False instead.

Also, I'll add a test demonstrating that set_model_state_dict does work in a non-distributed context (when broadcast_from_rank0=False).
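
And a sketch of the distributed case described above, assuming a single node launched with torchrun on GPUs with the NCCL backend and a toy model standing in for a real sharded one; broadcast_from_rank0=True needs an initialized process group, with rank 0 supplying the full state dict and the other ranks passing an empty dict:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = nn.Linear(4, 4, device="cuda")
# Only rank 0 provides the tensors; they are broadcast to the other ranks.
state_dict = {"weight": torch.ones(4, 4), "bias": torch.ones(4)} if dist.get_rank() == 0 else {}
set_model_state_dict(
    model,
    state_dict,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)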

Comment on lines +644 to +665
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    worker = re.sub(r"^gw", "", worker, 0, re.M)
    return int(worker)


def get_torch_dist_unique_port():
    """
    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.

    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
    port at once.
    """
    port = 29500
    uniq_delta = pytest_xdist_worker_id()
    return port + uniq_delta
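
A usage sketch for the port helper (the script path and process count are placeholders): each pytest-xdist worker gets its own rendezvous port so concurrently running distributed tests don't collide on port 29500:

cmd = [
    "torchrun",
    f"--master_port={get_torch_dist_unique_port()}",
    "--nproc_per_node=2",  # placeholder process count
    "tests/some_distributed_test.py",  # placeholder script
]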


Member

Is there a way to not have this, @muellerzr?


Comment on lines 1 to 29
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import itertools
import unittest
from typing import Any, Callable

import torch
from huggingface_hub import hf_hub_download
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp.wrap import _recursive_wrap, transformer_auto_wrap_policy
from transformers import AutoConfig, AutoModel
Member

Can we also add a tensor parallel test, since you talk about it in the PR?


@ringohoffman
Author

I think everything should be passing now if you want to give it another go! @SunMarc


Member
@SunMarc SunMarc left a comment


Thanks for iterating! Just a few nits! Can you have a second look, @muellerzr, for the tests in particular? Also, the CI is red because of your changes; can you have a quick look at these tests?

FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_pytorch - AssertionError
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_safetensors - AssertionError

Comment on lines 201 to 204

class TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase):
    @require_transformers
    @require_multi_gpu
Member

You can put the decorators above the class name!


            else:
                torch.testing.assert_close(tensor, tp_tensor, msg=tp_name)

    @require_torch_min_version(version="2.4.0")
Member

Let's put the require_torch_min_version decorator above TestLoadCheckpointAndDispatchWithBroadcast.


Comment on lines 38 to 58
if is_transformers_available():
    from transformers import AutoConfig, AutoModel, PreTrainedModel
    from transformers.models.gpt2.modeling_gpt2 import GPT2Block

    def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
        """Manage the creation and destruction of the distributed process group for the wrapped function."""

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            dist.init_process_group(world_size=torch.cuda.device_count())
            try:
                return func(*args, **kwargs)
            finally:
                dist.destroy_process_group()

        return wrapped

    @require_torch_min_version(version="2.4.0")
    @manage_process_group
    def load_checkpoint_and_dispatch_fsdp2():
        torch.cuda.set_device(device := torch.device(dist.get_rank()))

Member

You don't need to put all these functions inside the `if is_transformers_available():` condition.


@SunMarc SunMarc requested a review from muellerzr March 13, 2025 11:06
@ringohoffman
Author

> Thanks for iterating! Just a few nits! Can you have a second look, @muellerzr, for the tests in particular? Also, the CI is red because of your changes; can you have a quick look at these tests?
>
> FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_pytorch - AssertionError
> FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_safetensors - AssertionError

So weird, I remember solving these failures earlier but I guess I didn't push it...

cdf321b

Member
@SunMarc SunMarc left a comment


Thanks a lot!

if device_map is None:
    model.load_state_dict(loaded_checkpoint, strict=strict)
    if is_torch_version(">=", "2.2.0") and len(model_devices) <= 1:
Member

Why is len(model_devices) <= 1 needed? Usually the model is on the meta device, but the non-persistent buffers are usually not on the meta device.

Author

I added a comment explaining this. This is just an explicit restriction that set_model_state_dict has. Starting in v2.7.0 (not yet released), they actually do support one physical device + the meta device.

I've updated the condition to account for this.

https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563
https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587

Check it out: f5555fb!
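
A sketch of what such a condition could look like, assuming `model_devices` is the set of `torch.device`s the model's tensors live on, as in the snippet above (illustrative only, not necessarily the PR's exact code; `single_device_or_meta` is a hypothetical helper name):

import torch
from accelerate.utils import is_torch_version

def single_device_or_meta(model_devices: set[torch.device]) -> bool:
    # torch >= 2.7 also accepts one physical device plus the meta device;
    # earlier versions require all tensors on a single device.
    if is_torch_version(">=", "2.7.0"):
        return len(model_devices - {torch.device("meta")}) <= 1
    return len(model_devices) <= 1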

Author

Also, non-persistent buffers are not included in the state_dict by definition, so they wouldn't affect this check.
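
A quick illustration of that point with a toy module (standard PyTorch behavior):

import torch
from torch import nn

m = nn.Module()
# Buffers registered with persistent=False never appear in the state_dict.
m.register_buffer("running_stat", torch.zeros(3), persistent=False)
assert "running_stat" not in m.state_dict()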

@ringohoffman
Author

@muellerzr @SunMarc any progress here? How is this looking?

Member
@SunMarc SunMarc left a comment


The tensor parallel example shouldn't work anymore, unfortunately, due to changes in transformers.

Comment on lines 146 to 155
with device, init_empty_weights():
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    tp_model = AutoModel.from_config(config)
    tp_model.tie_weights()
    assert isinstance(tp_model, nn.Module)

mesh = init_device_mesh(device.type, (dist.get_world_size(),))
assert tp_model.supports_tp_plan
assert callable(tp_model.tensor_parallel)
tp_model.tensor_parallel(mesh)
Member

The API for tensor_parallel changed a bit in transformers. Not sure we need this example anymore.

Author

Sure, I'll remove this test.


@SunMarc
Member
SunMarc commented Apr 1, 2025

I will merge it soon, but after today's release, as I prefer not to include it in that release yet.

@SunMarc SunMarc merged commit 73c2378 into huggingface:main Apr 11, 2025
25 checks passed
@SunMarc
Member
SunMarc commented Apr 11, 2025

Sorry for the wait, merging.
