Deepspeed integration for 7B models #19
Conversation
```python
deepspeed_states = AcceleratorState().deepspeed_plugin
deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size
deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True}
off_load_device = "cpu"
```
Note that this will slow down your code significantly. I would make offloading an option that's inferred from the accelerate config, as I did here: huggingface/trl#758
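For example, a minimal sketch of inferring it (assuming CPU offload, when requested, shows up under `zero_optimization.offload_param.device` in the plugin's DeepSpeed config; the helper name is hypothetical):

```python
from accelerate.state import AcceleratorState

def get_offload_device(default: str = "none") -> str:
    # Hypothetical helper: read the param-offload device from the accelerate
    # DeepSpeed plugin instead of hard-coding "cpu".
    ds_config = AcceleratorState().deepspeed_plugin.deepspeed_config
    zero_config = ds_config.get("zero_optimization") or {}
    return (zero_config.get("offload_param") or {}).get("device", default)

off_load_device = get_offload_device()  # "cpu" only if the accelerate config asks for it
```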
```python
deepspeed_states = AcceleratorState().deepspeed_plugin
deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size
deepspeed_states.deepspeed_config['checkpoint'] = {'use_node_local_storage': True}
```
I think this flag is only needed if each node has a separate local filesystem. For the HFC case, you probably don't need it
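One way to keep that configurable rather than hard-coded (the `args.node_local_storage` flag is hypothetical, not part of this PR):

```python
# Hypothetical flag: only enable node-local checkpoint storage when each node
# writes checkpoints to its own filesystem (e.g. local scratch disks).
if args.node_local_storage:
    deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True}
```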
```python
import deepspeed

deepspeed_states = AcceleratorState().deepspeed_plugin
deepspeed_states.deepspeed_config['train_micro_batch_size_per_gpu'] = args.ppo.local_micro_batch_size
```
If I'm not mistaken, these config values are set automatically by the accelerator and don't need to be overridden
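A sketch of what that automatic behaviour looks like in practice (assuming an accelerate config whose DeepSpeed section leaves `train_micro_batch_size_per_gpu` as `"auto"`; `dataset`, `model`, and `optimizer` are placeholders from the surrounding script):

```python
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=args.ppo.local_micro_batch_size)

# With the DeepSpeed plugin active, accelerate resolves the "auto" entries
# (micro batch size, gradient accumulation steps) from the objects passed to
# prepare(), so the manual deepspeed_config override should be redundant.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```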
"bf16": { | ||
"enabled": True | ||
}, | ||
"prescale_gradients": False, |
I think this flag and the one below are false by default, so probably don't need to be set either
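For illustration, a trimmed version relying on those defaults (a sketch, not the exact config dict from this PR):

```python
# prescale_gradients defaults to False in DeepSpeed, so only the settings that
# actually deviate from the defaults need to appear in the config.
ds_config = {
    "train_micro_batch_size_per_gpu": args.ppo.local_micro_batch_size,
    "bf16": {"enabled": True},
}
```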
```diff
@@ -755,7 +790,8 @@ def train(args: Args):
             )

             with torch.no_grad():
-                writer.add_histogram("ppo/val/ratio_hist", ratio, update)
+                if not args.deepspeed:  # for some reason there is a OOM with the `writer.add_histogram`
+                    writer.add_histogram("ppo/val/ratio_hist", ratio, update)
```
FYI I was able to train 7B models in TRL with ZeRO-2 and didn't need to remove the histogram. On the other hand that was for sentiment tuning, which is less memory intensive than your application here
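If the histogram is worth keeping, one alternative to gating it on `args.deepspeed` (untested here, just a suggestion) would be to log a detached CPU copy so the histogram computation doesn't hold extra GPU memory:

```python
import torch

with torch.no_grad():
    # Hand TensorBoard a CPU copy of the values so add_histogram does not keep
    # another tensor on the GPU while binning.
    writer.add_histogram("ppo/val/ratio_hist", ratio.detach().float().cpu(), update)
```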
Confirmed that it can reasonably run 7B models (no benchmark results yet):
https://wandb.ai/costa-huang/cleanRL/runs/hn9wtka9?workspace=user-costa-huang
This PR attempts to bring deepspeed integration to empower tuning with 7B models. In the summarize from human feedback paper, they experimented with 1.3B, 2.7B, and 6.7B models, so this PR would in principle allow us to replicate that work.
Some of the notable changes needed to make things work:

- `mixed_precision: 'bf16'` turns out to be important, otherwise OOM.
- `accelerator.prepare` and `deepspeed.initialize`, otherwise OOM (see the sketch below).
- `bf16` for `reward_model` and `ref_policy`, otherwise OOM.
- `critic_model`, for which they eventually offload `reward_model`, `critic_model`, and `ref_policy` to CPU, but it is not necessary in our case.

Here is a training run: https://wandb.ai/costa-huang/cleanRL/runs/kve7tu43/overview
Training results were pretty bad, but I think this is probably some issue related to model compatibility. To replicate the summarize from human feedback paper, we should probably use the OPT models, which come in 1.3B, 2.7B, and 6.7B sizes.
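To make the list of changes above concrete, here is a minimal sketch of the setup it describes (assumed shapes of the calls, not the exact code in this PR): the trained policy goes through `accelerator.prepare`, while `reward_model` and `ref_policy` are cast to bf16 and wrapped with `deepspeed.initialize` directly.

```python
import deepspeed
import torch

# Trained model: handled by accelerate (with mixed_precision: 'bf16' in the config).
policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader)

# Frozen models: cast to bf16 and initialize with DeepSpeed directly, using a
# small config that only sets the micro batch size and bf16 (no optimizer).
eval_ds_config = {
    "train_micro_batch_size_per_gpu": args.ppo.local_micro_batch_size,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 0},
}
ref_policy = ref_policy.to(torch.bfloat16)
reward_model = reward_model.to(torch.bfloat16)
ref_policy, *_ = deepspeed.initialize(model=ref_policy, config=eval_ds_config)
reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config)
```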
CC @lewtun