From e7b350351b727bb2e190a92da3af7e7639632053 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 24 Nov 2023 15:52:40 +0100 Subject: [PATCH] [DPO] IPO Training loss (#1022) * initial IPO loss * fix loss * fixed comments * added docs * fix doc-strings * add tests * Update trl/trainer/dpo_trainer.py Co-authored-by: Leandro von Werra * fixes for review * Added doc about beta in the Trainer's docstring --------- Co-authored-by: Leandro von Werra --- docs/source/dpo_trainer.mdx | 2 ++ tests/test_dpo_trainer.py | 2 +- trl/trainer/dpo_trainer.py | 52 ++++++++++++++++++++++++++++++++----- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index b817a75358..0e9e637114 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -85,6 +85,8 @@ Note that the `beta` is the temperature parameter for the DPO loss, typically so Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin. +The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. + ## Logging While training and evaluating we record the following reward metrics: diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index b1ba16711b..061f827200 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -74,7 +74,7 @@ def _init_dummy_dataset(self): # fmt: on return Dataset.from_dict(dummy_dataset_dict) - @parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"]]) + @parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"]]) def test_dpo_trainer(self, name, loss_type): with tempfile.TemporaryDirectory() as tmp_dir: training_args = TrainingArguments( diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index f0ab08f6c1..bf83bda8b2 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -63,9 +63,9 @@ class DPOTrainer(Trainer): Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. beta (`float`, defaults to 0.1): - The beta factor in DPO loss. Higher beta means less divergence from the initial policy. + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper. loss_type (`str`, defaults to `"sigmoid"`): - The type of DPO loss to use. Either `"sigmoid"` the default DPO loss or `"hinge"` loss from SLiC paper. + The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from SLiC paper or `"ipo"` from IPO paper. args (`transformers.TrainingArguments`): The arguments to use for training. data_collator (`transformers.DataCollator`): @@ -120,7 +120,7 @@ def __init__( model: Union[PreTrainedModel, nn.Module, str] = None, ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, beta: float = 0.1, - loss_type: Literal["sigmoid", "hinge"] = "sigmoid", + loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid", args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, @@ -428,7 +428,6 @@ def dpo_loss( policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) - beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. Returns: @@ -437,13 +436,15 @@ def dpo_loss( The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. """ pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - if reference_free: ref_logratios = 0 + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps logits = pi_logratios - ref_logratios + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. if self.loss_type == "sigmoid": losses = -F.logsigmoid(self.beta * logits) elif self.loss_type == "hinge": @@ -456,6 +457,38 @@ def dpo_loss( return losses, chosen_rewards, rejected_rewards + def ipo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the IPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the IPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps + reference_rejected_logps + ref_logratios = policy_rejected_logps + reference_chosen_logps + + logits = pi_logratios - ref_logratios + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + def _get_batch_logps( self, logits: torch.FloatTensor, @@ -560,7 +593,12 @@ def get_batch_metrics( _, ) = self.concatenated_forward(self.ref_model, batch) - losses, chosen_rewards, rejected_rewards = self.dpo_loss( + if self.loss_type == "ipo": + loss_fn = self.ipo_loss + else: + loss_fn = self.dpo_loss + + losses, chosen_rewards, rejected_rewards = loss_fn( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,