[core] Fix DeepSpeed zero-3 issue #182

Merged · 8 commits · Mar 28, 2023
13 changes: 13 additions & 0 deletions docs/source/customization.mdx
@@ -149,4 +149,17 @@ When training large models, you should better handle the CUDA cache by iterative

```python
config = PPOConfig(..., optimize_cuda_cache=True)
```

## Use DeepSpeed stage 3 correctly

A small tweak needs to be added to your training script to use DeepSpeed stage 3 correctly: the reward model has to be initialized on the correct device, which is done by temporarily disabling ZeRO-3 weight partitioning through the `zero3_init_context_manager(enable=False)` context manager. Here is an example adapted from the `gpt2-sentiment` script:

```python
from transformers import pipeline

# `device` is the device assigned to the current process in the script
# (e.g. derived from `ppo_trainer.accelerator.device`).
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
    # Disable ZeRO-3 weight partitioning while instantiating the reward model,
    # so that its weights are fully materialized on this device.
    with ds_plugin.zero3_init_context_manager(enable=False):
        sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```
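
If the same guard is needed in more than one place, it can be factored into a small helper. The function below is a minimal sketch for illustration only and is not part of TRL; `build_reward_pipeline` is a hypothetical name, and it assumes the same `accelerator` object the trainer uses:

```python
from contextlib import nullcontext

from transformers import pipeline


def build_reward_pipeline(accelerator, model_name, task="sentiment-analysis"):
    # Hypothetical helper: create a pipeline with ZeRO-3 init disabled, so the
    # reward model's weights are materialized on this process's device.
    ds_plugin = accelerator.state.deepspeed_plugin
    ctx = (
        ds_plugin.zero3_init_context_manager(enable=False)
        if ds_plugin is not None and ds_plugin.is_zero3_init_enabled()
        else nullcontext()  # no-op when DeepSpeed ZeRO-3 init is not active
    )
    with ctx:
        return pipeline(task, model=model_name, device=accelerator.device)
```

With such a helper, both branches of the example above collapse into a single call, e.g. `sentiment_pipe = build_reward_pipeline(ppo_trainer.accelerator, "lvwerra/distilbert-imdb")`.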
21 changes: 19 additions & 2 deletions trl/trainer/ppo_trainer.py
@@ -265,16 +265,33 @@ def __init__(
else:
self.kl_ctl = FixedKLController(self.config.init_kl_coef)

# Safety checkers for DS integration
is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)

(
self.model,
self.ref_model,
self.optimizer,
self.data_collator,
self.dataloader,
self.lr_scheduler,
) = self.accelerator.prepare(
self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
self.model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
)
if is_deepspeed_used:
# 8 bit models are already set on the correct device
if not getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False):
# DS integration only allows for a single model, and since `ref_model` is only used for the
# KL divergence loss, i.e. in eval mode, it can simply be placed on the respective device;
# there is no need to pass it to the `accelerator.prepare` call
self.ref_model = self.ref_model.to(self.accelerator.device)

# this hack seems to be needed for DS stage 3 to work
if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
self.model.train()
Comment on lines +291 to +292
@younesbelkada (Contributor Author) commented on Mar 27, 2023:

Based on the offline discussion I had with @pacman100, I confirm this hack is needed to make DS3 work
else:
self.ref_model = self.accelerator.prepare(self.ref_model)

# In a distributed setup, only logging needs to be performed on the main process
# check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
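
For reference, the `ref_model` handling that this diff adds to `PPOTrainer.__init__` boils down to the following. This is a condensed, illustrative sketch rather than the exact trainer code; `model`, `ref_model`, and `accelerator` stand in for the corresponding `self.*` attributes:

```python
# Condensed sketch of the DeepSpeed handling added to `PPOTrainer.__init__`;
# `model`, `ref_model`, and `accelerator` stand in for the trainer's `self.*` attributes.
is_deepspeed_used = accelerator.distributed_type == "DEEPSPEED" and hasattr(
    accelerator.state, "deepspeed_plugin"
)

if is_deepspeed_used:
    # Only the trained model goes through `accelerator.prepare`; the frozen
    # reference model (used solely for the KL penalty, in eval mode) is just
    # moved to the local device, unless it is an 8-bit model that is already
    # placed correctly at load time.
    if not getattr(ref_model.pretrained_model, "is_loaded_in_8bit", False):
        ref_model = ref_model.to(accelerator.device)

    # ZeRO stage 3 additionally needs the trained model switched to train mode.
    if accelerator.state.deepspeed_plugin.zero_stage == 3:
        model.train()
else:
    # Without DeepSpeed, the reference model is prepared like everything else.
    ref_model = accelerator.prepare(ref_model)
```

The key point is that under DeepSpeed only the trained model is passed to `accelerator.prepare`, while the frozen reference model is simply moved to the local device.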