diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8a68beb190..c30e2b6c2e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -202,7 +202,12 @@ def __init__( "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." ) else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if self.is_deepspeed_enabled: + # Read more about the issue in https://github.com/huggingface/trl/pull/687 + self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model) + self.ref_model.eval() + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: """Concatenate the chosen and rejected inputs into a single tensor.