(Part 1) fix: make TP training compatible with new transformers by kmehant · Pull Request #3457 · huggingface/accelerate

(Part 1) fix: make TP training compatible with new transformers #3457


Merged
7 commits merged into huggingface:main on Apr 11, 2025

Conversation

kmehant
Contributor
@kmehant kmehant commented Mar 25, 2025

What does this PR do?

Fixes #3456

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Thanks to @SunMarc for the valuable discussion on #3456

@muellerzr or @SunMarc

Member
@SunMarc SunMarc left a comment


Thanks! Left a comment

@kmehant kmehant requested a review from SunMarc March 25, 2025 14:38
Contributor
@muellerzr muellerzr left a comment


Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kmehant kmehant changed the title fix: make TP training compatible with new transformers (Part 1) fix: make TP training compatible with new transformers Mar 27, 2025
Member
@SunMarc SunMarc left a comment


I'll review that after merging the transformers PR! From a quick look, though, it looks nice.

Member
@S1ro1 S1ro1 left a comment


Have tested locally and ran a few things; seems to work! LGTM

@kmehant kmehant requested a review from SunMarc April 10, 2025 16:20
@kmehant
Contributor Author
kmehant commented Apr 10, 2025

Failing test is unrelated. Thanks

Member
@SunMarc SunMarc left a comment


Thanks! Now that we've merged the PR about tp_size in transformers, maybe we can use that to automatically infer the tp_size so that we create the plugin accordingly.
I'm not sure how well this will integrate with the current code, as we don't have access to the model when creating the accelerator.
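
For illustration only, a minimal sketch of the idea floated here, assuming the model exposes a tp_size attribute (as in newer transformers) and that TorchTensorParallelPlugin accepts a tp_size argument; this is not the code in this PR:

from accelerate.utils import TorchTensorParallelPlugin

def maybe_build_tp_plugin(model):
    # If the model was sharded with tensor parallelism, newer transformers
    # record the degree on the model itself; otherwise create no plugin.
    tp_size = getattr(model, "tp_size", None)
    if tp_size is None or tp_size <= 1:
        return None
    return TorchTensorParallelPlugin(tp_size=tp_size)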

@@ -49,14 +49,15 @@ def setUp(self):
     def test_working_of_tp(self):
         self.test_file_path = self.test_scripts_folder / "test_performance.py"
         cmd = get_launch_command(
-            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
+            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, tp_size=self.test_tp_size
Member


Where is tp_size used?

Comment on lines +397 to +398
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
    raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
Member


Shouldn't we create torch_tp_plugin automatically if the model is sharded with TP? It's fine with me that the user has to specify this for 4.51, but in 4.52 we have tp_size now.

Contributor Author
@kmehant kmehant Apr 10, 2025


@SunMarc
Approach 1: Should we allow passing the model while creating the accelerator, and note in the docstring that it would only be used for TP? This approach would also come in handy for any parallelism that modifies the model, such as context parallel. It would let users skip creating the plugin themselves and simply pass the model when creating the accelerator.

Approach 2: Modify TorchTensorParallelPlugin to take the model as input rather than tp_size. This makes it explicit to users that, to enable TP, they must pass an already TP-sharded model; we then extract tp_size from the model to create the device mesh used by the data loader (a rough sketch of this follows below).

Approach 3: Use tp_size only to validate that it matches what the model has been sharded to, plus to create the device mesh for the data loader (which we already have). However, it would still feel redundant from the user's point of view, since they would need to pass the same tp_size both while sharding in transformers and while using the plugin.

Let me know which one sounds good for this PR. Thanks

cc: @S1ro1
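
For what it's worth, a rough sketch of what Approach 2 could look like; the class and attribute names are illustrative, not an accelerate API, and it assumes the sharded model carries a tp_size attribute:

from torch.distributed.device_mesh import init_device_mesh

class ModelAwareTPPlugin:
    """Illustrative only: derive tp_size from an already TP-sharded model."""

    def __init__(self, model, device_type: str = "cuda"):
        tp_size = getattr(model, "tp_size", None)
        if tp_size is None:
            raise ValueError("Model must be TP-sharded before constructing the plugin.")
        self.tp_size = tp_size
        # 1D mesh that the dataloader sharding logic would later consume
        self.torch_device_mesh = init_device_mesh(device_type, (tp_size,), mesh_dim_names=("tp",))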

Member


I'm up for approach 3, though we shouldn't use it to validate but as a fallback when the model has no tp_size attribute; this way there's no redundancy. Also, you'll notice in #3498 I do the same for the DP/FSDP sizes, as those will be needed as well to recreate the same device mesh that was created for TP (the GPU order differs between the 1D and 2D case), so my proposition is:

an extra dataclass, i.e. ParallelismConfig, which would hold all parallelism sizes and would only be used as a fallback.
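
A rough sketch of what such a dataclass could hold; the field names are illustrative and not the actual #3498 implementation:

from dataclasses import dataclass

@dataclass
class ParallelismConfig:
    # Fallback sizes, consulted only when they cannot be read off the prepared model.
    dp_size: int = 1    # data-parallel replicas
    fsdp_size: int = 1  # FSDP shard groups
    tp_size: int = 1    # tensor-parallel ranks

    @property
    def world_size(self) -> int:
        # total ranks implied by the configured mesh
        return self.dp_size * self.fsdp_size * self.tp_size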

Contributor Author


Thanks @S1ro1

@SunMarc your thoughts?

Member


Hmmm, I think approach 3 will be better. I don't want to force users to create the model before initializing the accelerator or TorchTensorParallelPlugin. Maybe in the future accelerate will also take care of sharding any model for TP.
We can just check, when preparing the model, that if the model is TP-sharded, its tp size matches the one passed in TorchTensorParallelPlugin. If the model is not TP-sharded but the user passed a TorchTensorParallelPlugin, we return an error. WDYT?
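
A hedged sketch of the check described above, as it might run when the model is prepared; the helper name and error messages are illustrative, not the merged code:

def check_tp_consistency(model, torch_tp_plugin):
    # transformers >= 4.52 sets `tp_size` on models sharded with tensor parallelism
    model_tp_size = getattr(model, "tp_size", None)
    if torch_tp_plugin is None:
        return
    if model_tp_size is None:
        raise ValueError("A TorchTensorParallelPlugin was passed but the model is not TP-sharded.")
    if model_tp_size != torch_tp_plugin.tp_size:
        raise ValueError(
            f"Model is sharded with tp_size={model_tp_size}, but the plugin was "
            f"created with tp_size={torch_tp_plugin.tp_size}."
        )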

Contributor Author


Thanks @SunMarc, I have gone with this approach and updated the PR. The performance test for TP passes as well. Can we merge this PR as a self-contained piece for now and revisit the discussion in #3498 for the nd-parallel optimizations? Thanks

[Screenshot attached: 2025-04-11 at 4:03 PM]

cc: @S1ro1

@S1ro1
Member
S1ro1 commented Apr 10, 2025

> Thanks! Now that we've merged the PR about tp_size in transformers, maybe we can use that to automatically infer the tp_size so that we create the plugin accordingly. I'm not sure how well this will integrate with the current code, as we don't have access to the model when creating the accelerator.

I think we can also cover this in #3498, as for example Trainer already creates the TensorParallelPlugin for us if the model was sharded. I'm all for removing it and leaving it to accelerate to create the plugins based on the provided tp_size.

@SunMarc
Member
SunMarc commented Apr 11, 2025

> I think we can also cover this in #3498, as for example Trainer already creates the TensorParallelPlugin for us if the model was sharded. I'm all for removing it and leaving it to accelerate to create the plugins based on the provided tp_size.

I don't mind, if you think it can make things simpler for you. What are the issues currently with TensorParallelPlugin?

kmehant added a commit to kmehant/accelerate that referenced this pull request Apr 11, 2025
see huggingface#3457 (comment) for more details

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
@S1ro1
Member
S1ro1 commented Apr 11, 2025

> I don't mind, if you think it can make things simpler for you. What are the issues currently with TensorParallelPlugin?

The biggest issue is TensorParallelPlugin creating its own device mesh; we would like to create the device mesh based on all the parallelisms provided, not just one (TensorParallelPlugin only has access to tp_size).
Two options that solve this come to mind:

  1. Remove device_mesh from TensorParallelPlugin completely and make it a first-class citizen of Accelerator / any of its states.
  2. Add another method that resets the device mesh on the plugin.

I'm heavily leaning towards number 1; #3498 already does so (it also exposes the method on the plugin, but that's to be removed). This will allow us to receive any number of ?p_size arguments on Accelerator and construct the device mesh so that it is available to all the plugins (see the sketch after the next list).

With this comes a question of how to handle errors:

  1. We use these arguments only as a fact-check against what is saved on the model; they would be required.
  2. We use these arguments as a fallback: if we have no info on how the model was sharded, we use these.
  3. We infer what we can (i.e. in 2D sharding only one size needs to be specified) and throw an error if we can't. I haven't really thought this approach through, but I'm not a fan of it as it's pretty messy.
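
To make the central-mesh idea concrete, a small sketch under assumed dimension names (this is not the #3498 code):

from torch.distributed.device_mesh import init_device_mesh

def build_device_mesh(dp_size: int, tp_size: int, device_type: str = "cuda"):
    # One shared 2D mesh; each plugin reads the sub-mesh for its own dimension.
    return init_device_mesh(device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

# Example on 8 GPUs: 4-way DP composed with 2-way TP
# mesh = build_device_mesh(dp_size=4, tp_size=2)
# tp_mesh = mesh["tp"]   # handed to the TP machinery
# dp_mesh = mesh["dp"]   # handed to DDP/FSDP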

@SunMarc
Member
SunMarc commented Apr 11, 2025

Sounds good for 1. I'm all for removing TensorParallelPlugin and using tp_size instead.
Note that right now, for deepspeed, we also create a separate device_mesh. Check the _prepare_device_mesh method.

> We use these arguments only as a fact-check against what is saved on the model; they would be required.

I think it makes more sense to require this rather than use it only as a fallback. Think of a situation where the user prepares the dataloader before the model.

Comment on lines +1593 to +1595
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if not hasattr(model, "tp_size"):
Member


This was added only recently, so we have to update BETA_TP_AVAILABLE_TRANSFORMERS_VERSION to 4.52.0 or the dev version.
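
For reference, a hedged sketch of the bumped gate; the constant value follows this thread, while the import path and placement are assumptions:

from accelerate.utils import compare_versions

BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"  # first release exposing model.tp_size

def require_tp_capable_transformers():
    if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
        raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")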

Contributor Author


Apologies for missing that; I've updated it to 4.52.0. Thanks.

Member
@SunMarc SunMarc left a comment


Thanks, minor nit on the version check, but other than that LGTM! You can merge it @S1ro1 if you are fine with that.

kmehant added 7 commits April 11, 2025 20:32 (all signed off by Mehant Kammakomati <mehant.kammakomati2@ibm.com>)
@S1ro1
Member
S1ro1 commented Apr 11, 2025

> Sounds good for 1. I'm all for removing TensorParallelPlugin and using tp_size instead. Note that right now, for deepspeed, we also create a separate device_mesh. Check the _prepare_device_mesh method.

As for the device mesh, that is fine with DeepSpeedPlugin, as it encapsulates all levels of parallelism (dp, fsdp, tp). In our case, though, we want TP to be composable with DDP, FSDP, or both, where we don't have a clear "leader" responsible for the device mesh. Therefore the best option is to create it centrally and make it accessible from all FSDP/DDP plugins, and possibly others such as PP if we decide to support that. Later we can move to a central device mesh for DeepSpeed as well, but that's potentially breaking and a bigger refactor.

However, this is to be discussed later; merging now, and let's move the discussion to #3498.

@S1ro1 S1ro1 merged commit 67adb47 into huggingface:main Apr 11, 2025
25 checks passed
Development

Successfully merging this pull request may close these issues.

bug: broken TP training since tensor_parallel public API is removed