
Extend DeepSpeed integration to ZeRO-{1,2,3} #758

Merged (13 commits, Sep 12, 2023)

Conversation

@lewtun (Member) commented Sep 12, 2023

This PR extends the DeepSpeed initialization of the reference model to work with all stages of DeepSpeed ZeRO.

I'll share some plots of the GPT-2 runs on sentiment tuning shortly, but the code should be good for a review.

Tested with:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --mini_batch_size 32 --log_with wandb

Here are the screenshots of the various runs on wandb: https://wandb.ai/huggingface/trl?workspace=user-lewtun

Overall, I'm getting good agreement between the baseline (no DeepSpeed) and stages 1 & 2, while stage 3 has a noticeable discrepancy in the value loss that is worth digging into in a separate issue, IMO.

[Screenshots of the wandb runs]

@HuggingFaceDocBuilderDev commented Sep 12, 2023

The documentation is not available anymore as the PR was closed or merged.

@lewtun changed the title from "[WIP] Extend DeepSpeed integration to ZeRO-{1,2,3}" to "Extend DeepSpeed integration to ZeRO-{1,2,3}" on Sep 12, 2023
@lewtun marked this pull request as ready for review on September 12, 2023 at 10:09
@@ -8,7 +8,7 @@ distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
-mixed_precision: 'no'
+mixed_precision: 'bf16'
lewtun (Member Author) commented:

We can now set this as the default since we initialise both the reference and active models with DeepSpeed.

@@ -38,6 +38,9 @@ class ScriptArguments:
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
model_name: Optional[str] = field(default="lvwerra/gpt2-imdb", metadata={"help": "the model name"})
reward_model_name: Optional[str] = field(
lewtun (Member Author) commented:

I've added this arg to make it easier to configure how this script is run.
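
For context, here's a sketch of what the new argument might look like in ScriptArguments (the default of lvwerra/distilbert-imdb is inferred from the discussion below, not copied from this diff):

reward_model_name: Optional[str] = field(
    default="lvwerra/distilbert-imdb", metadata={"help": "the reward model name"}
)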

deepspeed_plugin = self.accelerator.state.deepspeed_plugin
batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"]
# See DeepSpeed docs for definition of these parameters: https://deepspeed.readthedocs.io/en/latest/zero3.html
config_kwargs = {
lewtun (Member Author) commented Sep 12, 2023:

All these parameters are set automatically by accelerate and thus don't need duplicating. One check I need to make is whether gradient accumulation is included.

Update: yes, train_batch_size does reflect the size of gradient accumulation as well, so this is fine to be removed IMO
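
For readers following along, here's a rough sketch of the overall idea (not the exact code from this PR; it assumes deepspeed_plugin is accelerate's DeepSpeed plugin state and ref_model is the plain reference model):

import copy

import deepspeed

# Reuse the config accelerate already built for the active model so that
# train_micro_batch_size_per_gpu and gradient accumulation stay consistent.
config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)

# An inference-only reference model has no optimizer states or gradients,
# so ZeRO-1/2 sharding buys nothing; only stage 3 (parameter sharding) is kept.
if config_kwargs["zero_optimization"]["stage"] != 3:
    config_kwargs["zero_optimization"]["stage"] = 0

ref_model_engine, *_ = deepspeed.initialize(model=ref_model, config=config_kwargs)
ref_model_engine.eval()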

@vwxyzjn (Contributor) commented Sep 12, 2023

@lewtun nice work! I love that there are gpt2 runs across different zero stages. Could you also test zero 2 + 3 on larger models such as falcon 7b or cerebras GPT 6.7B?

@lewtun (Member Author) commented Sep 12, 2023

> @lewtun nice work! I love that there are gpt2 runs across different zero stages. Could you also test zero 2 + 3 on larger models such as falcon 7b or cerebras GPT 6.7B?

Yes, I'm running the Cerebras models as we speak and will report back when the runs are done :)

@lewtun (Member Author) commented Sep 12, 2023

Update on running 3 x 6.7B models with DeepSpeed on sentiment_tuning.py:

  • ZeRO-2 works fine and doesn't OOM
  • ZeRO-3 works fine and doesn't OOM

Here's the command I used to test:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{2,3}.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model_name cerebras/Cerebras-GPT-6.7B

Interestingly, although ZeRO-3 is less memory intensive, the savings aren't as high as I would have expected on a single node:

[Screenshot of the memory usage comparison]

@younesbelkada (Contributor) left a comment

Looks really great to me. Thanks a lot for the thorough investigation and for spending time on benchmarking and testing to make sure we now support DS ZeRO 1, 2, 3!

Comment on lines +187 to +192
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
if sentiment_pipe.tokenizer.pad_token_id is None:
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id

if sentiment_pipe.model.config.pad_token_id is None:
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id
younesbelkada (Contributor):

Do you know why this was not needed before?

lewtun (Member Author):

It's usually not needed if you've already trained a proper reward model, because that comes with a proper padding token. However, if you want to plug and play with any causal LM on the Hub, this is typically needed to avoid throwing errors in the pipeline.
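
For example, the usual workaround for a Hub causal LM that ships without a padding token looks something like this (just an illustration, not code from this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    # GPT-2 has no pad token by default, so fall back to the EOS token
    tokenizer.pad_token = tokenizer.eos_token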

@uahmad235 commented:
Great work @lewtun!
I am trying to replicate your experiments but I'm going OOM on 2 x A6000s.
It would be really helpful if you could provide the following info:

  1. What GPUs are you using?
  2. How much RAM does your training machine have?
  3. Are you using the default deepspeed config (from the repo) or have you modified anything like the offload params?

Thanks in advance!

@lewtun (Member Author) commented Sep 13, 2023

Hi @uahmad235 !

Here are the answers to your questions:

  1. I'm using a single node of 8 x A100s with 80GB vRAM each
  2. I have 1TB CPU RAM
  3. I am using the default deepspeed configs from the repo with no changes (i.e. no offloading)

It will likely be tight to fit 3 x 7B models on 2 x A6000s, so one possibility would be to quantize the reward model by passing load_in_8bit=True. Alternatively, you could try shrinking the batch size & increasing the number of gradient accumulation steps.
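
For example, something along these lines (just a sketch, not code from this PR; it assumes the reward model is loaded with transformers and that bitsandbytes is installed):

from transformers import AutoModelForSequenceClassification

reward_model = AutoModelForSequenceClassification.from_pretrained(
    "lvwerra/distilbert-imdb",  # example checkpoint; use whichever reward model you have
    load_in_8bit=True,          # requires bitsandbytes
    device_map="auto",
)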

Hope that helps!

@uahmad235 commented:
Thanks for the info @lewtun.
I am actually using the default reward model (i.e. lvwerra/distilbert-imdb), so it is not too heavy on memory.
And yeah, I tried using load_in_8bit with batch_size=4. It runs into a runtime error at ppo_trainer.step(). Here's the error I get:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

However, using torch_dtype=torch.bfloat16 does help a little with memory, but not enough to allow me to make gradient updates, even with batch_size=2.

Seems like I might have to go for a pair of A100s.

kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* Generalise deepspeed

* Refactor

* Add reward model arg

* Fix pipeline tokenizer

* Fix deprecation

* Pin deepspeed lower

* Fix docs

* Revert top_k change

* Add ZeRO-3 context manager

* Revert docs change

* Fix docs

* Polish docs

* Update docs/source/customization.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024