From 1909aa71a21d4ab4355e8f49376174cbd2b6e4d0 Mon Sep 17 00:00:00 2001
From: younesbelakda
Date: Tue, 28 Feb 2023 16:37:00 +0000
Subject: [PATCH 1/7] fix zero-3 issue

---
 trl/trainer/ppo_trainer.py | 43 +++++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 87975d05ad..5d98818743 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -244,17 +244,41 @@ def __init__(
         else:
             self.kl_ctl = FixedKLController(self.config.init_kl_coef)

-        (
-            self.model,
-            self.ref_model,
-            self.optimizer,
-            self.data_collator,
-            self.dataloader,
-            self.lr_scheduler,
-        ) = self.accelerator.prepare(
-            self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
+        # Safety checkers for DS integration
+        is_deepspeed_zero_3 = (
+            self.accelerator.distributed_type == "DEEPSPEED"
+            and hasattr(self.accelerator.state, "deepspeed_plugin")
+            and self.accelerator.state.deepspeed_plugin.zero_stage == 3
         )

+        if is_deepspeed_zero_3:
+            (
+                self.model,
+                self.optimizer,
+                self.data_collator,
+                self.dataloader,
+                self.lr_scheduler,
+            ) = self.accelerator.prepare(
+                self.model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
+            )
+            # 8 bit models are already set on the correct device
+            if not getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False):
+                # DS integration only allows for a single model and as `ref_model` is only used for
+                # the `KL divergence loss`, i.e. in eval mode, just have it be on the respective device and
+                # there is no need to pass it to the `accelerator.prepare` call
+                self.ref_model = self.ref_model.to(self.accelerator.device)
+        else:
+            (
+                self.model,
+                self.ref_model,
+                self.optimizer,
+                self.data_collator,
+                self.dataloader,
+                self.lr_scheduler,
+            ) = self.accelerator.prepare(
+                self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
+            )
+
         # In a distributed setup, only logging needs to be performed on the main process
         # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
         # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
@@ -677,7 +701,6 @@ def train_minibatch(
             train_stats (dict[str, `torch.Tensor`]):
                 Dictionary of training statistics
         """
-
         loss_p, loss_v, train_stats = self.loss(old_logprobs, values, rewards, logits, vpreds, logprobs, mask)
         loss = loss_p + loss_v
         self.optimizer.zero_grad()

From fbbe9ebbbd3009b7728f2cfa9f4e6e0be08d8d0b Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Mon, 27 Mar 2023 12:13:02 +0200
Subject: [PATCH 2/7] Update trl/trainer/ppo_trainer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
---
 trl/trainer/ppo_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index f635cf696c..2efa09cd2b 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -269,7 +269,6 @@ def __init__(
         is_deepspeed_zero_3 = (
             self.accelerator.distributed_type == "DEEPSPEED"
             and hasattr(self.accelerator.state, "deepspeed_plugin")
-            and self.accelerator.state.deepspeed_plugin.zero_stage == 3
         )

         if is_deepspeed_zero_3:
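The check introduced above relies on `accelerate` exposing the DeepSpeed state on the `Accelerator` object. A minimal standalone sketch of that detection logic, assuming an `Accelerator` has already been constructed in the training script (this mirrors the diff and is not part of TRL's public API):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# DeepSpeed is considered "in use" when accelerate reports the DEEPSPEED distributed
# type and a `deepspeed_plugin` is attached to the accelerator state (same check as
# in the patch above; `distributed_type` is a string-valued enum, so the string
# comparison works).
is_deepspeed_used = accelerator.distributed_type == "DEEPSPEED" and hasattr(
    accelerator.state, "deepspeed_plugin"
)

# When DeepSpeed is active, the ZeRO stage can be read from the plugin.
zero_stage = accelerator.state.deepspeed_plugin.zero_stage if is_deepspeed_used else None
print(f"DeepSpeed used: {is_deepspeed_used}, ZeRO stage: {zero_stage}")
```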
From bd06f788d3c5768de96418e9ec81ff57450c23b9 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 27 Mar 2023 10:20:27 +0000
Subject: [PATCH 3/7] adapt

---
 trl/trainer/ppo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 2efa09cd2b..e874720193 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -266,12 +266,12 @@ def __init__(
             self.kl_ctl = FixedKLController(self.config.init_kl_coef)

         # Safety checkers for DS integration
-        is_deepspeed_zero_3 = (
+        is_deepspeed_used = (
             self.accelerator.distributed_type == "DEEPSPEED"
             and hasattr(self.accelerator.state, "deepspeed_plugin")
         )

-        if is_deepspeed_zero_3:
+        if is_deepspeed_used:
             (
                 self.model,
                 self.optimizer,

From bb13034ec48c77d64b65e72f9c167195e6cf53f4 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 27 Mar 2023 11:00:53 +0000
Subject: [PATCH 4/7] make style

---
 trl/trainer/ppo_trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index e874720193..4684a32669 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -266,9 +266,8 @@ def __init__(
             self.kl_ctl = FixedKLController(self.config.init_kl_coef)

         # Safety checkers for DS integration
-        is_deepspeed_used = (
-            self.accelerator.distributed_type == "DEEPSPEED"
-            and hasattr(self.accelerator.state, "deepspeed_plugin")
+        is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+            self.accelerator.state, "deepspeed_plugin"
         )

         if is_deepspeed_used:

From 71424b45a184ba1d13d8a0d05a02cbc06d572ef4 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 27 Mar 2023 15:16:15 +0000
Subject: [PATCH 5/7] fix

---
 trl/trainer/ppo_trainer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 4684a32669..d3d4afd1ba 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -286,6 +286,10 @@ def __init__(
                 # the `KL divergence loss`, i.e. in eval mode, just have it be on the respective device and
                 # there is no need to pass it to the `accelerator.prepare` call
                 self.ref_model = self.ref_model.to(self.accelerator.device)
+
+            # this hack seems to be needed for DS stage 3 to work
+            if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
+                self.model.train()
         else:
             (
                 self.model,
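The easiest line to miss in the series is the workaround added in patch 5: when ZeRO stage 3 is active, the trained model has to be switched back into train mode right after preparation. In isolation, with bare `model`/`accelerator` names assumed purely for illustration, it amounts to:

```python
# Workaround noted in the diff ("this hack seems to be needed for DS stage 3 to work"):
# after accelerator.prepare, put the policy model back into train mode under ZeRO stage 3.
if accelerator.state.deepspeed_plugin.zero_stage == 3:
    model.train()
```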
From f87dd85c5b0998b7748fe79dc3344458263b9dad Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 27 Mar 2023 15:19:57 +0000
Subject: [PATCH 6/7] add docs

---
 docs/source/customization.mdx | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/docs/source/customization.mdx b/docs/source/customization.mdx
index 4b915c1140..bad6226af5 100644
--- a/docs/source/customization.mdx
+++ b/docs/source/customization.mdx
@@ -149,4 +149,17 @@ When training large models, you should better handle the CUDA cache by iterative
 
 ```python
 config = PPOConfig(..., optimize_cuda_cache=True)
 ```
+
+## Use DeepSpeed stage 3 correctly
+
+A small tweak needs to be added to your training script to use DeepSpeed stage 3 correctly. You need to properly initialize your reward model on the correct device using the `zero3_init_context_manager` context manager. Here is an example adapted for the `gpt2-sentiment` script:
+
+```python
+ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
+if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
+    with ds_plugin.zero3_init_context_manager(enable=False):
+        sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
+else:
+    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
+```
\ No newline at end of file

From 59e6d5df3dac6090041bce9787e71dfcd081afcb Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 28 Mar 2023 11:31:51 +0000
Subject: [PATCH 7/7] fix

---
 trl/trainer/ppo_trainer.py | 29 ++++++++++-------------------
 1 file changed, 10 insertions(+), 19 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index d3d4afd1ba..3cd885fd32 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -270,16 +270,16 @@ def __init__(
             self.accelerator.state, "deepspeed_plugin"
         )

+        (
+            self.model,
+            self.optimizer,
+            self.data_collator,
+            self.dataloader,
+            self.lr_scheduler,
+        ) = self.accelerator.prepare(
+            self.model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
+        )
         if is_deepspeed_used:
-            (
-                self.model,
-                self.optimizer,
-                self.data_collator,
-                self.dataloader,
-                self.lr_scheduler,
-            ) = self.accelerator.prepare(
-                self.model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
-            )
             # 8 bit models are already set on the correct device
             if not getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False):
                 # DS integration only allows for a single model and as `ref_model` is only used for
@@ -291,16 +291,7 @@ def __init__(
             if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
                 self.model.train()
         else:
-            (
-                self.model,
-                self.ref_model,
-                self.optimizer,
-                self.data_collator,
-                self.dataloader,
-                self.lr_scheduler,
-            ) = self.accelerator.prepare(
-                self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader, self.lr_scheduler
-            )
+            self.ref_model = self.accelerator.prepare(self.ref_model)

         # In a distributed setup, only logging needs to be performed on the main process
         # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
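Taken together, the series removes the duplicated `accelerator.prepare` call: the trained model and the training utilities are always prepared once, and only the frozen reference model is handled differently depending on whether DeepSpeed is active. A condensed paraphrase of the resulting flow, using bare names instead of `self.*` (a sketch of the diffs above, not the exact TRL source):

```python
# Same detection logic as in the patches above.
is_deepspeed_used = accelerator.distributed_type == "DEEPSPEED" and hasattr(
    accelerator.state, "deepspeed_plugin"
)

# The policy model and the training utilities always go through accelerator.prepare once.
model, optimizer, data_collator, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, data_collator, dataloader, lr_scheduler
)

if is_deepspeed_used:
    # DeepSpeed manages a single model, so the frozen reference model (used only for
    # the KL penalty, in eval mode) is placed on the device directly; 8-bit models
    # are already on the correct device when loaded.
    if not getattr(ref_model.pretrained_model, "is_loaded_in_8bit", False):
        ref_model = ref_model.to(accelerator.device)
    # ZeRO stage 3 workaround from patch 5: keep the policy model in train mode.
    if accelerator.state.deepspeed_plugin.zero_stage == 3:
        model.train()
else:
    # Without DeepSpeed the reference model is prepared like everything else.
    ref_model = accelerator.prepare(ref_model)
```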