diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index ff4d7a504d..8a68beb190 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -29,7 +29,7 @@
 if is_peft_available():
-    from peft import get_peft_model, prepare_model_for_int8_training
+    from peft import PeftModel, get_peft_model, prepare_model_for_int8_training


 class DPOTrainer(Trainer):
@@ -113,7 +113,7 @@ def __init__(
             model = prepare_model_for_int8_training(model)
             model = get_peft_model(model, peft_config)

-        self.is_peft_model = getattr(model, "is_peft_model", False)
+        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

         if ref_model:
             self.ref_model = ref_model
@@ -197,10 +197,7 @@ def __init__(
         )

         if self.ref_model is None:
-            if not hasattr(
-                self.accelerator.unwrap_model(self.model).pretrained_model,
-                "disable_adapter",
-            ):
+            if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
                 raise ValueError(
                     "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
                 )
@@ -347,7 +344,7 @@ def get_batch_metrics(
         ) = self.concatenated_forward(model, batch)
         with torch.no_grad():
             if self.ref_model is None:
-                with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
                     (
                         reference_chosen_logps,
                         reference_rejected_logps,
@@ -415,7 +412,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
         )

         if self.ref_model is None:
-            with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
+            with self.accelerator.unwrap_model(self.model).disable_adapter():
                 reference_output = self.model.generate(
                     batch["prompt_input_ids"],
                     attention_mask=batch["prompt_attention_mask"],
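
Below is a minimal sketch (not part of the patch) of the pattern the change relies on: `get_peft_model` returns a `PeftModel`, which exposes `disable_adapter()` directly, so the frozen base weights can serve as the DPO reference model without keeping a separate copy. The patch also drops the `.pretrained_model` indirection, since a plain `PeftModel` has no such wrapper attribute. The model name and prompt here are illustrative assumptions.

# Minimal sketch, not part of the patch. Assumes peft and transformers are
# installed; "gpt2" and the prompt are illustrative placeholders.
import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# The detection the patch switches to: an isinstance check instead of a
# getattr on an attribute that plain PeftModel instances do not define.
assert isinstance(model, PeftModel)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    policy_logits = model(**inputs).logits    # adapters active -> policy model
    with model.disable_adapter():             # adapters off -> reference model
        reference_logits = model(**inputs).logits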