🤫 TR-DPO implementation #1593

Merged · 20 commits · May 23, 2024
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -113,6 +113,8 @@ The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier wh

The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation using loss_type="sppo_hard" approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser.

The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights with the policy after every `ref_model_sync_steps` steps of SGD, using a mixing weight of `ref_model_mixup_alpha`, during DPO training. To enable this callback, set `sync_ref_model=True` in the `DPOConfig`.
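
For example, a minimal configuration sketch (the values below are illustrative, not recommendations from this PR):

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="tr_dpo_output",     # illustrative path
    sync_ref_model=True,            # enable the TR-DPO reference-model sync callback
    ref_model_mixup_alpha=0.9,      # alpha: weight of the current policy in the soft update
    ref_model_sync_steps=64,        # tau: sync the reference model every 64 training steps
)
```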

## Logging

While training and evaluating we record the following reward metrics:
39 changes: 39 additions & 0 deletions tests/test_dpo_trainer.py
@@ -303,6 +303,45 @@ def test_dpo_trainer_w_dataset_num_proc(self):

trainer.train()

def test_tr_dpo_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
precompute_ref_log_probs=False,
sync_ref_model=True,
)

dummy_dataset = self._init_dummy_dataset()

trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.equal(param, new_param)

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
9 changes: 9 additions & 0 deletions trl/trainer/dpo_config.py
@@ -65,6 +65,12 @@ class DPOConfig(TrainingArguments):
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
force_use_ref_model (`bool`, defaults to `False`):
If you pass a PEFT model as the active model and want to use a different model for the `ref_model`, set this flag to `True`.
sync_ref_model (`bool`, defaults to `False`):
Whether to synchronize the reference model with the active model during training, as proposed in the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
ref_model_mixup_alpha (`float`, defaults to `1.0`):
The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper, which controls the mix between the current policy and the previous reference model during each sync: ref_model <- alpha * policy + (1 - alpha) * ref_model.
ref_model_sync_steps (`int`, defaults to `2`):
The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper, which determines how frequently (in training steps) the reference model is synchronized with the current policy.
"""

beta: float = 0.1
@@ -87,3 +93,6 @@ class DPOConfig(TrainingArguments):
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 1.0
ref_model_sync_steps: int = 2
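
For intuition, the sync these two parameters control is a parameter-wise soft copy of the policy into the reference model. Below is a minimal standalone sketch of that rule (my own illustration of what `SyncRefModelCallback._sync_target_model` in `trl/trainer/utils.py` does further down, not code from this PR):

```python
import torch

def soft_update(policy: torch.nn.Module, ref: torch.nn.Module, alpha: float) -> None:
    # ref <- alpha * policy + (1 - alpha) * ref, applied parameter-wise
    with torch.no_grad():
        for ref_param, policy_param in zip(ref.parameters(), policy.parameters()):
            ref_param.data.mul_(1.0 - alpha).add_(policy_param.data, alpha=alpha)

# With alpha=1.0 (the default shown above) this is a hard copy of the policy weights;
# the callback applies the update every `ref_model_sync_steps` optimizer steps.
```
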
19 changes: 19 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -46,6 +46,7 @@
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
SyncRefModelCallback,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
@@ -528,12 +529,30 @@ def make_inputs_require_grad(module, input, output):
raise ValueError(
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
)
if args.sync_ref_model:
raise ValueError(
"You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if args.sync_ref_model:
if precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
)

self.add_callback(
SyncRefModelCallback(
ref_model=self.ref_model,
accelerator=self.accelerator,
ref_model_mixup_alpha=args.ref_model_mixup_alpha,
ref_model_sync_steps=args.ref_model_sync_steps,
)
)
if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

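Putting the pieces together, here is a hedged end-to-end sketch of how a user would exercise this code path; the model name and the tiny in-memory dataset are placeholders and the config values are illustrative, none of them taken from this PR:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")        # placeholder policy model
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")    # an explicit ref_model is required for TR-DPO
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token                   # gpt2 has no pad token by default

# toy preference dataset with the columns DPOTrainer expects
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": [" Hi, nice to meet you!", " I'm doing well, thanks."],
        "rejected": [" Leave me alone.", " None of your business."],
    }
)

training_args = DPOConfig(
    output_dir="tr_dpo_output",
    per_device_train_batch_size=2,
    max_steps=10,
    remove_unused_columns=False,
    sync_ref_model=True,             # attaches SyncRefModelCallback
    ref_model_mixup_alpha=0.9,       # illustrative value
    ref_model_sync_steps=2,          # illustrative value
    precompute_ref_log_probs=False,  # must remain False when sync_ref_model=True
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,             # ref_model=None is rejected when sync_ref_model=True
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```
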
48 changes: 47 additions & 1 deletion trl/trainer/utils.py
@@ -20,7 +20,9 @@
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import is_deepspeed_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
@@ -31,6 +33,7 @@
from transformers import (
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.trainer import TrainerCallback
@@ -44,6 +47,10 @@
from peft import LoraConfig, PeftConfig


if is_deepspeed_available():
import deepspeed


class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
@@ -62,6 +69,45 @@ def update(self, current, n_steps):
self.value *= mult


class SyncRefModelCallback(TrainerCallback):
def __init__(
self,
ref_model: Union[PreTrainedModel, torch.nn.Module],
accelerator: Optional[Accelerator],
ref_model_mixup_alpha: float,
ref_model_sync_steps: int,
):
self.ref_model_mixup_alpha = ref_model_mixup_alpha
self.ref_model_sync_steps = ref_model_sync_steps
self.accelerator = accelerator
self.ref_model = ref_model

@staticmethod
def _sync_target_model(model, target_model, alpha):
# soft update: target <- alpha * source + (1 - alpha) * target, applied parameter-wise
for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)

@classmethod
def sync_target_model(cls, model, target_model, alpha):
deepspeed_plugin = AcceleratorState().deepspeed_plugin
if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
Reviewer comment: Not a DS expert. Is this only required for zero stage3?
# under ZeRO stage 3 the parameters are sharded across ranks, so gather them and copy on rank 0 only
with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
if deepspeed.comm.get_rank() == 0:
cls._sync_target_model(model, target_model, alpha)
else:
cls._sync_target_model(model, target_model, alpha)

def on_step_end(self, args, state, control, **kwargs):
model: PreTrainedModel = kwargs["model"]

# re-sync the reference model every `ref_model_sync_steps` global steps
if self.ref_model is not None and state.global_step % self.ref_model_sync_steps == 0:
if self.accelerator:
unwrapped_model = self.accelerator.unwrap_model(model)
self.sync_target_model(unwrapped_model, self.ref_model, self.ref_model_mixup_alpha)
else:
self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
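
A small self-contained sanity check of the soft update performed by this callback (the toy linear modules are illustrative; `_sync_target_model` is a private helper, used here only to demonstrate the arithmetic):

```python
import torch
from trl.trainer.utils import SyncRefModelCallback

policy = torch.nn.Linear(4, 4, bias=False)
ref = torch.nn.Linear(4, 4, bias=False)

# alpha=1.0 turns the soft update into a hard copy of the policy weights
SyncRefModelCallback._sync_target_model(policy, ref, alpha=1.0)
assert torch.allclose(ref.weight, policy.weight)

# alpha=0.5 mixes the new policy and the previous reference model equally
old_ref = ref.weight.detach().clone()
new_policy = torch.nn.Linear(4, 4, bias=False)
SyncRefModelCallback._sync_target_model(new_policy, ref, alpha=0.5)
assert torch.allclose(ref.weight, 0.5 * new_policy.weight + 0.5 * old_ref)
```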


class FixedKLController:
"""Fixed KL controller."""
