From 051c53205f5ce23ec10e4b30eef75c02a1f06a14 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Aug 2023 12:19:39 +0200 Subject: [PATCH 1/3] fix by @tannonk --- trl/trainer/dpo_trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index ff4d7a504d..5c5e35de59 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -113,7 +113,7 @@ def __init__( model = prepare_model_for_int8_training(model) model = get_peft_model(model, peft_config) - self.is_peft_model = getattr(model, "is_peft_model", False) + self.is_peft_model = getattr(model, "active_adapter", False) if ref_model: self.ref_model = ref_model @@ -197,10 +197,7 @@ def __init__( ) if self.ref_model is None: - if not hasattr( - self.accelerator.unwrap_model(self.model).pretrained_model, - "disable_adapter", - ): + if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): raise ValueError( "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." ) @@ -347,7 +344,7 @@ def get_batch_metrics( ) = self.concatenated_forward(model, batch) with torch.no_grad(): if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter(): + with self.accelerator.unwrap_model(self.model).disable_adapter(): ( reference_chosen_logps, reference_rejected_logps, @@ -415,7 +412,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[ ) if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter(): + with self.accelerator.unwrap_model(self.model).disable_adapter(): reference_output = self.model.generate( batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], From 30e75bc7d14229b080842ded29022656f46992b5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Aug 2023 12:31:32 +0200 Subject: [PATCH 2/3] Update trl/trainer/dpo_trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 5c5e35de59..1a737d4e9c 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -113,7 +113,7 @@ def __init__( model = prepare_model_for_int8_training(model) model = get_peft_model(model, peft_config) - self.is_peft_model = getattr(model, "active_adapter", False) + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) if ref_model: self.ref_model = ref_model From 5552f806bbba9bbd64fd4a2972a8538d6e942176 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Aug 2023 12:45:27 +0200 Subject: [PATCH 3/3] add import --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 1a737d4e9c..8a68beb190 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -29,7 +29,7 @@ if is_peft_available(): - from peft import get_peft_model, prepare_model_for_int8_training + from peft import PeftModel, get_peft_model, prepare_model_for_int8_training class DPOTrainer(Trainer):