Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DPO] fix DPO ref_model=None #703

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if is_peft_available():
from peft import get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training


class DPOTrainer(Trainer):
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, peft_config)

self.is_peft_model = getattr(model, "is_peft_model", False)
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

if ref_model:
self.ref_model = ref_model
Expand Down Expand Up @@ -197,10 +197,7 @@ def __init__(
)

if self.ref_model is None:
if not hasattr(
self.accelerator.unwrap_model(self.model).pretrained_model,
"disable_adapter",
):
if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
raise ValueError(
"You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
)
Expand Down Expand Up @@ -347,7 +344,7 @@ def get_batch_metrics(
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
reference_chosen_logps,
reference_rejected_logps,
Expand Down Expand Up @@ -415,7 +412,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
)

if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
with self.accelerator.unwrap_model(self.model).disable_adapter():
reference_output = self.model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
Expand Down