[DPO] fix DPO ref_model=None (huggingface#703)
* fix by @tannonk

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add import

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2 people authored and Andrew Lapp committed May 10, 2024
1 parent 6a2c2c2 commit 2e8c7f7
Showing 1 changed file with 5 additions and 8 deletions.

trl/trainer/dpo_trainer.py (5 additions, 8 deletions)
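For context before the hunks: with `ref_model=None`, the old code probed a `.pretrained_model` attribute that a plain `PeftModel` does not have, and relied on an `is_peft_model` attribute that `get_peft_model` never sets, so the adapter-as-reference path broke. After the fix, a peft-wrapped policy model can serve as its own reference model by disabling its adapters. A minimal usage sketch of that code path (checkpoint name, toy dataset, and hyperparameters are placeholders, following the TRL API of this era):

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Toy preference data: DPO expects prompt/chosen/rejected columns.
train_dataset = Dataset.from_dict({
    "prompt": ["The capital of France is"],
    "chosen": [" Paris."],
    "rejected": [" London."],
})

trainer = DPOTrainer(
    model,
    ref_model=None,  # the code path this commit fixes
    beta=0.1,
    args=TrainingArguments(output_dir="dpo-out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),  # the trainer wraps the model itself
)
trainer.train()
```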
```diff
@@ -29,7 +29,7 @@


 if is_peft_available():
-    from peft import get_peft_model, prepare_model_for_int8_training
+    from peft import PeftModel, get_peft_model, prepare_model_for_int8_training


 class DPOTrainer(Trainer):
```
```diff
@@ -113,7 +113,7 @@ def __init__(
             model = prepare_model_for_int8_training(model)
             model = get_peft_model(model, peft_config)

-        self.is_peft_model = getattr(model, "is_peft_model", False)
+        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

         if ref_model:
             self.ref_model = ref_model
```
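Why the new check (and the `PeftModel` import above): a model wrapped by `get_peft_model` does not carry an `is_peft_model` attribute, so the old `getattr` probe silently returned `False` and the trainer never took the adapter-aware branch. A small sketch, assuming `peft` and `transformers` are installed ("gpt2" is only a convenient placeholder):

```python
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

print(getattr(model, "is_peft_model", False))  # False: the attribute does not exist
print(isinstance(model, PeftModel))            # True: the reliable check
```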
```diff
@@ -197,10 +197,7 @@ def __init__(
             )

         if self.ref_model is None:
-            if not hasattr(
-                self.accelerator.unwrap_model(self.model).pretrained_model,
-                "disable_adapter",
-            ):
+            if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
                 raise ValueError(
                     "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
                 )
```
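The guard now probes the unwrapped model directly: since `DPOTrainer` calls `get_peft_model` itself, the wrapped object is a plain `PeftModel`, not a TRL value-head wrapper, so there is no `.pretrained_model` attribute to reach through. The capability check in isolation (illustrative helper name, not part of TRL):

```python
def supports_adapter_disabling(model) -> bool:
    # `disable_adapter` is the peft context manager used below to recover
    # the frozen base model; very old peft releases did not provide it.
    return hasattr(model, "disable_adapter")
```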
```diff
@@ -347,7 +344,7 @@ def get_batch_metrics(
         ) = self.concatenated_forward(model, batch)
         with torch.no_grad():
             if self.ref_model is None:
-                with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
                     (
                         reference_chosen_logps,
                         reference_rejected_logps,
```
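This hunk is the heart of the trick: with the adapters disabled, the same `PeftModel` computes reference log-probabilities from its frozen base weights, so no separate reference copy has to be kept in memory. A minimal sketch of the idea (helper name and shapes are illustrative, not TRL's internals):

```python
import torch

def base_model_logps(peft_model, input_ids, attention_mask):
    """Per-token log-probs under the frozen base (reference) model."""
    with torch.no_grad(), peft_model.disable_adapter():  # adapters off => base weights
        logits = peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)  # position t predicts token t+1
    targets = input_ids[:, 1:].unsqueeze(-1)
    return torch.gather(logprobs, dim=2, index=targets).squeeze(-1)
```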
```diff
@@ -415,7 +412,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
             )

         if self.ref_model is None:
-            with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
+            with self.accelerator.unwrap_model(self.model).disable_adapter():
                 reference_output = self.model.generate(
                     batch["prompt_input_ids"],
                     attention_mask=batch["prompt_attention_mask"],
```
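The same pattern covers sampling: reference completions are generated with the adapters switched off, policy completions with them on. A sketch under the same assumption (a peft-wrapped causal LM; variable names and generation kwargs are placeholders):

```python
with peft_model.disable_adapter():
    reference_output = peft_model.generate(  # frozen base weights
        prompt_input_ids,
        attention_mask=prompt_attention_mask,
        max_new_tokens=64,
        do_sample=True,
    )
policy_output = peft_model.generate(  # adapters active
    prompt_input_ids,
    attention_mask=prompt_attention_mask,
    max_new_tokens=64,
    do_sample=True,
)
```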
