DeepSpeed Revamp by pacman100 · Pull Request #405 · huggingface/accelerate

DeepSpeed Revamp #405

Merged: 27 commits merged into huggingface:main on Jun 6, 2022

Conversation

pacman100 (Contributor) commented on May 27, 2022

What does this PR do?

  1. Support for a DeepSpeed config file (a usage sketch follows the ToDo list below).
  2. Support for gradient clipping and offloading parameters without a config file.
  3. Fixing the scheduler without a config file.
  4. Simplifying the model, optimizer, and scheduler wrappers by removing stale code from the current wrappers and relying directly on DeepSpeed.
  5. Using HfDeepSpeedConfig in ZeRO Stage-3 (with user consent) to handle DeepSpeed ZeRO-3 parameter gathering and to automatically partition the model across multiple GPUs during the from_pretrained call.

ToDo:

  • Write tests
  • Redo the experiments from "Testing Trainer and Accelerate Integration of DeepSpeed" to check whether the gaps between the Accelerate and Trainer integrations are fixed.
  • ZeRO Stage-3 inference
  • Add an example leveraging the config file, with the related saving and loading of the model when using DeepSpeed.
  • Documentation
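
A minimal usage sketch of the features above, driving the revamped integration from code; the exact DeepSpeedPlugin parameter names used here (hf_ds_config, zero_stage, gradient_clipping, offload_optimizer_device, offload_param_device) are assumptions based on this description, not a definitive API:

# Sketch only: parameter names are assumptions based on this PR's description;
# check the merged docs for the exact DeepSpeedPlugin signature.
from accelerate import Accelerator, DeepSpeedPlugin

# Option A (assumed): point the plugin at a DeepSpeed config file.
ds_plugin = DeepSpeedPlugin(hf_ds_config="ds_config.json")

# Option B (assumed): configure without a config file, including the new
# gradient-clipping and offloading knobs.
# ds_plugin = DeepSpeedPlugin(
#     zero_stage=2,
#     gradient_clipping=1.0,
#     offload_optimizer_device="cpu",
#     offload_param_device="cpu",
# )

accelerator = Accelerator(deepspeed_plugin=ds_plugin)
# model, optimizer, dataloader, scheduler = accelerator.prepare(
#     model, optimizer, dataloader, scheduler
# )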

@HuggingFaceDocBuilderDev commented on May 27, 2022

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment

Thanks a lot for all the work and very nice new tests!
Can you confirm whether this is fully backward compatible? If there are any breaking changes, could you document them?

"When `zero3_init_flag` is set, it requires Transformers to be installed. "
"Please run `pip3 install transformers`."
)
from transformers.deepspeed import HfDeepSpeedConfig
Collaborator:

Ultimately, we will want this object to live in Accelerate, not in Transformers. I don't know when the best time to move it is; just putting this here as a general comment :-)

pacman100 (Contributor, Author) commented on Jun 2, 2022

I thought about this. The thing is, a weakref to HfDeepSpeedConfig is created in that file (transformers.deepspeed). This matters only when using ZeRO Stage-3, where we don't want to load Transformers models fully on CPU/GPU but instead want to directly partition the model parameters across GPUs. This weakref, _hf_deepspeed_config_weak_ref, is used in transformers' modeling_utils.py to check whether DeepSpeed ZeRO Stage-3 is enabled. If it is enabled, DeepSpeed's zero.Init functionality is used to directly partition the model parameters across GPUs. It is used by the from_config (when training from scratch) and from_pretrained (when fine-tuning) methods.

Snippet in modeling_utils.py:

if is_deepspeed_zero3_enabled():
    import deepspeed
    
    logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
    # this immediately partitions the model across all gpus, to avoid the overhead in time
    # and memory copying it on CPU or each GPU first
    with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
        model = cls(config, **kwargs)
else:
    model = cls(config, **kwargs)

is_deepspeed_zero3_enabled in the snippet above refers directly to the weakref in transformers.deepspeed:

def is_deepspeed_zero3_enabled():
    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
        return _hf_deepspeed_config_weak_ref().is_zero3()
    else:
        return False

For the above reasons, I thought it would be good to let this remain part of the transformers repo, as it is used only in ZeRO Stage-3 for efficiently loading models that are part of the transformers repo.
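
To make the weakref mechanism concrete, here is a minimal sketch of the pattern described above, with an illustrative (not complete) ZeRO Stage-3 config: instantiating HfDeepSpeedConfig and keeping it alive before from_pretrained is what sets the weakref, so is_deepspeed_zero3_enabled() returns True and zero.Init() partitions the model at load time.

from transformers import AutoModel
from transformers.deepspeed import HfDeepSpeedConfig

# Illustrative ZeRO Stage-3 config; real configs carry many more fields.
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
}

# Creating this object sets _hf_deepspeed_config_weak_ref inside
# transformers.deepspeed; it must stay alive (not be garbage collected)
# so that is_deepspeed_zero3_enabled() returns True below.
dschf = HfDeepSpeedConfig(ds_config)

# Because ZeRO-3 is now detected, from_pretrained wraps model construction
# in deepspeed.zero.Init(), partitioning parameters across GPUs directly.
model = AutoModel.from_pretrained("bert-base-uncased")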

Collaborator:

With the move, the weakref will disappear and we will rely on the AcceleratorState to know whether ZeRO-3 is enabled inside Transformers. Again, I am not sure when the right time to do the move is (it will make Accelerate a hard dependency of Transformers), but I want to flag that this is the final destination :-)
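
A purely hypothetical sketch of what that future check might look like (not part of this PR), assuming AcceleratorState exposes a deepspeed_plugin with a zero_stage attribute:

# Hypothetical sketch of the proposed future check; not what this PR implements.
from accelerate.state import AcceleratorState


def is_deepspeed_zero3_enabled():
    # Assumes AcceleratorState has been initialized by an Accelerator and
    # carries the DeepSpeed plugin configuration.
    state = AcceleratorState()
    plugin = getattr(state, "deepspeed_plugin", None)
    return plugin is not None and plugin.zero_stage == 3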

pacman100 and others added 7 commits June 2, 2022 12:56
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1. Add an example to show the usage of a config file with the revamped DeepSpeed support.
2. Update the required DeepSpeed version to 0.6.5.
2. Revert the `reinit` change as it is not required.
3. Raise an exception when using `clip_grad_value` with DeepSpeed/FSDP.
1. Changes to support ZeRO Stage-3 inference.
2. Minor bug fixes.
3. Documentation.
pacman100 marked this pull request as ready for review on June 3, 2022, 11:28
pacman100 changed the title from "[DO NOT MERGE] deepspeed revamp" to "DeepSpeed revamp" on Jun 3, 2022
pacman100 changed the title from "DeepSpeed revamp" to "DeepSpeed Revamp" on Jun 3, 2022
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
pacman100 added 2 commits June 6, 2022 16:15
1. Update tests and add a new one testing the autofill functionality of the `prepare` method.
2. Fix a ZeRO-3 init bug related to HfDeepSpeedConfig.
3. Update the documentation to address review comments.
sgugger (Collaborator) left a comment

LGTM, thanks a lot for working on this revamp!

pacman100 merged commit 1703b79 into huggingface:main on Jun 6, 2022