
Support modules_to_save config option when using DeepSpeed ZeRO-3 with ZeRO init enabled. #1450

Merged: pacman100 merged 4 commits into main from smangrul/fix-modules-to-save-for-ds-z3-init on Feb 9, 2024

Conversation

@pacman100 (Contributor) commented on Feb 9, 2024

What does this PR do?

  1. When using DeepSpeed ZeRO Stage-3 with ZeRO init enabled, the deepcopy performed in the ModulesToSaveWrapper class doesn't work as expected and creates a new module with 0 parameters. As a result, training fails with the following error (traceback abridged, the two ranks' outputs de-interleaved; a minimal illustration of the failure mode appears after the launch command below):

      return self.modules_to_save[self.active_adapter](*args, **kwargs)
    File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    ...
      assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
  AssertionError: {'id': 249, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {451}, 'ds_tensor.shape': torch.Size([0])}

  2. This PR resolves the issue: the parameters of the modules specified in the modules_to_save config option are gathered across processes using deepspeed.zero.GatheredParameters before the copy is made (a sketch of this pattern appears at the end of this description).
  3. Example tested: https://github.com/huggingface/accelerate/blob/main/examples/nlp_example.py with the following changes (task_type="SEQ_CLS" makes get_peft_model wrap the classification head via modules_to_save, which is exactly the code path this PR fixes):
...
+ from peft import get_peft_model, LoraConfig
...
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
+ config = LoraConfig(r=8, lora_alpha=16, task_type="SEQ_CLS")
+ model = get_peft_model(model, config)
+ print(model)
...

- config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
+ config = {"lr": 2e-4, "num_epochs": 10, "seed": 42, "batch_size": 16}

Without this PR, the error above is raised; with this PR, fine-tuning completes successfully:

epoch 7: {'accuracy': 0.8480392156862745, 'f1': 0.8945578231292517}
[2024-02-09 12:26:16,228] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, reducing to 32768
epoch 8: {'accuracy': 0.8504901960784313, 'f1': 0.8957264957264958}
/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/autograd/__init__.py:266: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
epoch 9: {'accuracy': 0.8602941176470589, 'f1': 0.902229845626072}

launch command:

accelerate launch --use_deepspeed --num_processes=2 --zero_stage=3 --zero3_init_flag=True --zero3_save_16bit_model=True --gradient_accumulation_steps=1 --gradient_clipping=1 --mixed_precision=fp16 nlp_example.py --mixed_precision fp16
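
For context on why the plain deepcopy fails (point 1 above): under deepspeed.zero.Init, ZeRO-3 partitions each parameter at construction time and replaces the local tensor with an empty placeholder, keeping the real size in DeepSpeed-specific attributes such as ds_numel. The snippet below is a hypothetical minimal illustration, not code from this PR; it assumes a working DeepSpeed install and a distributed launcher such as the accelerate command above, and the printed values are what one would expect rather than captured output.

import copy

import deepspeed
import torch

# Under zero.Init, parameters are partitioned as soon as the module is
# constructed; the local tensor becomes a 0-element placeholder.
with deepspeed.zero.Init():
    linear = torch.nn.Linear(768, 2)

print(linear.weight.shape)                    # torch.Size([0]) -- placeholder only
print(getattr(linear.weight, "ds_numel", 0))  # 1536 -- true element count (768 * 2)

# deepcopy clones the empty placeholders without registering the copy with
# the DeepSpeed runtime, so its parameters can never be gathered -- hence
# the AssertionError with 'numel': 0, 'ds_numel': 0 shown above.
clone = copy.deepcopy(linear)
print(clone.weight.shape)                     # torch.Size([0])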

Fixes: huggingface/transformers#24445 (comment)
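
For reference, the gathering pattern described in point 2 can be sketched as follows. This is a simplified illustration of the idea, not the exact diff merged into ModulesToSaveWrapper (the real change lives in peft's other.py, per the commit list below); the numel/ds_numel check used here to detect ZeRO-3 partitioned parameters is an assumption on my part.

import copy
from contextlib import nullcontext

import torch

def copy_module_zero3_safe(module: torch.nn.Module) -> torch.nn.Module:
    # Sketch: deepcopy a module, first gathering ZeRO-3 partitioned
    # parameters so the copy receives real weights on every rank.
    context = nullcontext()
    for param in module.parameters():
        # Partitioned parameters appear locally as 0-element tensors that
        # carry a nonzero ds_numel attribute (assumed detection heuristic).
        if param.numel() == 0 and getattr(param, "ds_numel", 0) > 0:
            import deepspeed

            # modifier_rank=None: gather for read-only access; no rank's
            # modifications are broadcast back to the partitions.
            context = deepspeed.zero.GatheredParameters(
                list(module.parameters()), modifier_rank=None
            )
            break
    with context:
        # Inside GatheredParameters the full parameters are materialized,
        # so deepcopy produces a module with real (nonzero) weights.
        return copy.deepcopy(module)

With this in place, modules listed in modules_to_save get fully materialized copies on every rank instead of 0-parameter shells.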


@BenjaminBossan (Member) left a comment:

Thanks for making modules_to_save work with DeepSpeed ZeRO-3. LGTM.

pacman100 marked this pull request as ready for review on February 9, 2024 12:12
pacman100 merged commit a1c472f into main on Feb 9, 2024
14 checks passed
pacman100 deleted the smangrul/fix-modules-to-save-for-ds-z3-init branch on February 20, 2024 05:46
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
Support modules_to_save config option when using DeepSpeed ZeRO-3 with ZeRO init enabled. (huggingface#1450)

* Update other.py

* Update other.py

* fix quality

* Update other.py
Successfully merging this pull request may close these issues:

LoRA is incompatible with DeepSpeed ZeRO3