diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 66c287b7ac..633d5d27b9 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -115,6 +115,8 @@ The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable
 
 The [NCA](https://arxiv.org/abs/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood.
 
+The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights, with mixing weight `ref_model_mixup_alpha`, after every `ref_model_sync_steps` steps of SGD during DPO training. To enable this callback, set the `sync_ref_model` flag in the `DPOConfig`.
+
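+For illustration, here is a minimal sketch of a TR-DPO setup, assuming `model`, `ref_model`, `tokenizer`, and `train_dataset` are already defined as in the examples above:
+
+```python
+from trl import DPOConfig, DPOTrainer
+
+training_args = DPOConfig(
+    output_dir="./tr_dpo_model",
+    sync_ref_model=True,        # enable the TR-DPO reference model update callback
+    ref_model_mixup_alpha=0.9,  # mixing weight: ref <- (1 - alpha) * ref + alpha * policy
+    ref_model_sync_steps=64,    # update the reference model every 64 optimization steps
+)
+
+dpo_trainer = DPOTrainer(
+    model=model,
+    ref_model=ref_model,  # TR-DPO requires an explicit reference model
+    args=training_args,
+    tokenizer=tokenizer,
+    train_dataset=train_dataset,
+)
+dpo_trainer.train()
+```
+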
""" beta: float = 0.1 @@ -87,3 +93,6 @@ class DPOConfig(TrainingArguments): ref_adapter_name: Optional[str] = None reference_free: bool = False force_use_ref_model: bool = False + sync_ref_model: bool = False + ref_model_mixup_alpha: float = 0.9 + ref_model_sync_steps: int = 64 diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 36ce79d3a4..5e900acf6e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -46,6 +46,7 @@ from .utils import ( DPODataCollatorWithPadding, RunningMoments, + SyncRefModelCallback, disable_dropout_in_model, pad_to_length, peft_module_casting_to_bf16, @@ -528,12 +529,23 @@ def make_inputs_require_grad(module, input, output): raise ValueError( "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" ) + if args.sync_ref_model: + raise ValueError( + "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." + ) else: if self.is_deepspeed_enabled: self.ref_model = self._prepare_deepspeed(self.ref_model) else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if args.sync_ref_model: + if precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." + ) + + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) if self.loss_type == "bco_pair": self.running = RunningMoments(self.accelerator) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index addcb10128..c0197336e7 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -20,8 +20,9 @@ import numpy as np import pandas as pd import torch -from accelerate import PartialState -from accelerate.state import AcceleratorState +from accelerate import Accelerator +from accelerate.state import AcceleratorState, PartialState +from accelerate.utils import is_deepspeed_available from rich.console import Console, Group from rich.live import Live from rich.panel import Panel @@ -32,6 +33,7 @@ from transformers import ( BitsAndBytesConfig, DataCollatorForLanguageModeling, + PreTrainedModel, PreTrainedTokenizerBase, ) from transformers.trainer import TrainerCallback @@ -45,6 +47,10 @@ from peft import LoraConfig, PeftConfig +if is_deepspeed_available(): + import deepspeed + + class AdaptiveKLController: """ Adaptive KL controller described in the paper: @@ -63,6 +69,39 @@ def update(self, current, n_steps): self.value *= mult +class SyncRefModelCallback(TrainerCallback): + def __init__( + self, + ref_model: Union[PreTrainedModel, torch.nn.Module], + accelerator: Optional[Accelerator], + ): + self.accelerator = accelerator + self.ref_model = ref_model + + @staticmethod + def _sync_target_model(model, target_model, alpha): + for target_param, copy_param in zip(target_model.parameters(), model.parameters()): + target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha) + + @staticmethod + def sync_target_model(model, target_model, alpha): + deepspeed_plugin = AcceleratorState().deepspeed_plugin + if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: + with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + else: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + + def on_step_end(self, args, state, control, **kwargs): + model: PreTrainedModel = 
kwargs["model"] + + if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha) + + class FixedKLController: """Fixed KL controller."""