From d47d1f2689d5d70531f465be36dd4ea8d71563d1 Mon Sep 17 00:00:00 2001 From: Karel D'Oosterlinck Date: Tue, 13 Aug 2024 22:27:26 +0000 Subject: [PATCH 1/8] feat: anchored pref optimization --- trl/trainer/dpo_config.py | 2 +- trl/trainer/dpo_trainer.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 8167d913a1..d97108c029 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -95,7 +95,7 @@ class DPOConfig(TrainingArguments): beta: float = 0.1 label_smoothing: float = 0 loss_type: Literal[ - "sigmoid", "hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "robust", "aot", "aot_pair", "exo_pair" + "sigmoid", "hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "robust", "aot", "aot_pair", "exo_pair", "apo_zero", "apo_down" ] = "sigmoid" label_pad_token_id: int = -100 padding_value: Optional[int] = None diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3603a5a9e7..23de9aa233 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -140,7 +140,7 @@ def __init__( ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, beta: float = 0.1, label_smoothing: float = 0, - loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair"] = "sigmoid", + loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair", "apo_zero", "apo_down"] = "sigmoid", args: Optional[DPOConfig] = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, @@ -491,7 +491,7 @@ def make_inputs_require_grad(module, input, output): "You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.label_smoothing = label_smoothing - if args.loss_type in ["hinge", "ipo", "bco_pair"] and args.label_smoothing > 0: + if args.loss_type in ["hinge", "ipo", "bco_pair", "apo_zero", "apo_down"] and args.label_smoothing > 0: warnings.warn( "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." ) @@ -1191,9 +1191,34 @@ def dpo_loss( - F.logsigmoid(-self.beta * delta) * self.label_smoothing ) + elif self.loss_type == "apo_zero": + # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + + losses = ( + losses_chosen + + losses_rejected + ) + + elif self.loss_type == "apo_down": + # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are worse than your model's default output + + losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood + losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) # Decrease rejected likelihood more + + losses = ( + losses_chosen + + losses_rejected + ) + + else: raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust', 'exo_pair']" + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust', 'exo_pair', 'apo_zero', 'apo_down']" ) chosen_rewards = ( From e125e7c0780b58b07f6bbdf8c3c3dd1f7c21fcc3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 14 Aug 2024 16:58:49 +0200 Subject: [PATCH 2/8] Update trl/trainer/dpo_trainer.py --- trl/trainer/dpo_trainer.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 23de9aa233..0743c0b2bd 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1195,26 +1195,21 @@ def dpo_loss( # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are better than your model's default output - losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood - losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood - losses = ( - losses_chosen + - losses_rejected - ) + losses = losses_chosen + losses_rejected elif self.loss_type == "apo_down": # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are worse than your model's default output - losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood - losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) # Decrease rejected likelihood more - - losses = ( - losses_chosen + - losses_rejected - ) + losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood + losses_rejected = 1 - F.sigmoid( + self.beta * (chosen_logratios - rejected_logratios) + ) # Decrease rejected likelihood more + losses = losses_chosen + losses_rejected else: raise ValueError( From 314db55b51fca2c42ce65374d7177412edfbd0e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 14:59:26 +0000 Subject: [PATCH 3/8] format and properly deprecate loss_type --- trl/trainer/dpo_config.py | 13 ++++++++++++- trl/trainer/dpo_trainer.py | 25 ++++++++++--------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index a5ed187c44..71be97539a 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -105,7 +105,18 @@ class DPOConfig(TrainingArguments): beta: float = 0.1 label_smoothing: float = 0 loss_type: Literal[ - "sigmoid", "hinge", "ipo", "exo_pair", "nca_pair", "robust", "bco_pair", "sppo_hard", "aot", "aot_pair", "apo_zero", "apo_down" + "sigmoid", + "hinge", + "ipo", + "exo_pair", + "nca_pair", + "robust", + "bco_pair", + "sppo_hard", + "aot", + "aot_pair", + "apo_zero", + "apo_down", ] = "sigmoid" label_pad_token_id: int = -100 padding_value: Optional[int] = None diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 23de9aa233..c3de2272a9 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -140,7 +140,7 @@ def __init__( ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, beta: float = 0.1, label_smoothing: float = 0, - loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair", "apo_zero", "apo_down"] = "sigmoid", + loss_type: Optional[str] = None, args: Optional[DPOConfig] = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, @@ -481,7 +481,7 @@ def make_inputs_require_grad(module, input, output): self._precomputed_train_ref_log_probs = False self._precomputed_eval_ref_log_probs = False - if loss_type != "sigmoid": + if loss_type is not None: warnings.warn( "You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) @@ -1195,26 +1195,21 @@ def dpo_loss( # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are better than your model's default output - losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood - losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood - losses = ( - losses_chosen + - losses_rejected - ) + losses = losses_chosen + losses_rejected elif self.loss_type == "apo_down": # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are worse than your model's default output - losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood - losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) # Decrease rejected likelihood more - - losses = ( - losses_chosen + - losses_rejected - ) + losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood + losses_rejected = 1 - F.sigmoid( + self.beta * (chosen_logratios - rejected_logratios) + ) # Decrease rejected likelihood more + losses = losses_chosen + losses_rejected else: raise ValueError( From 5a3c73ede40522f2672eb7d32b223201b6855509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 15:06:09 +0000 Subject: [PATCH 4/8] add aot in error message and reorder --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index c3de2272a9..924b885e01 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1213,7 +1213,7 @@ def dpo_loss( else: raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust', 'exo_pair', 'apo_zero', 'apo_down']" + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']" ) chosen_rewards = ( From fac99bbab2028dad6b34669e9892ec45d1b3cb34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 15:12:50 +0000 Subject: [PATCH 5/8] add "sppo_hard", "nca_pair" in label_smoothing warning warning --- trl/trainer/dpo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 924b885e01..20fc059140 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -491,7 +491,10 @@ def make_inputs_require_grad(module, input, output): "You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.label_smoothing = label_smoothing - if args.loss_type in ["hinge", "ipo", "bco_pair", "apo_zero", "apo_down"] and args.label_smoothing > 0: + if ( + args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] + and args.label_smoothing > 0 + ): warnings.warn( "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." ) From 9c6f4dab764a93b6255c002c3cbf771fe4be9371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 15:14:04 +0000 Subject: [PATCH 6/8] add tests --- tests/test_dpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index a0ae7ffaa7..dccd39c27a 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -163,6 +163,8 @@ def _init_dummy_image_dataset(self): ["gpt2", "robust", True], ["gpt2", "exo_pair", False], ["t5", "exo_pair", True], + ["gpt2", "apo_zero", True], + ["t5", "apo_down", False], ] ) def test_dpo_trainer(self, name, loss_type, pre_compute): From 023c425f184c99128c4c8f575813953adfc160ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 15:30:56 +0000 Subject: [PATCH 7/8] doc --- docs/source/dpo_trainer.mdx | 2 ++ trl/trainer/dpo_config.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index b86e498da1..01e5285d33 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -150,6 +150,8 @@ The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. +The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. To use these losses, set `loss_type="apo_zero"` or `loss_type="apo_down"` in the [`DPOConfig`]. + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 71be97539a..67acb14ad5 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -52,6 +52,8 @@ class DPOConfig(TrainingArguments): - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) paper. - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. label_pad_token_id (`int`, *optional*, defaults to `-100`): The label pad token id. This argument is required if you want to use the default data collator. From c6ff3d8f8a196b9990f114e4c445e759c0d7ea74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 15:37:11 +0000 Subject: [PATCH 8/8] doc fixes --- docs/source/dpo_trainer.mdx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 01e5285d33..2f86c851c0 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -134,13 +134,13 @@ The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper the The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0). -The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. To use the loss set the `loss_type="exo"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. +The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. To use the loss set the `loss_type="exo_pair"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. -The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use the loss set the `loss_type="nca"` in the [`DPOConfig`]. +The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use the loss set the `loss_type="nca_pair"` in the [`DPOConfig`]. -The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set the `loss_type="robust_dpo"` in the [`DPOConfig`]. +The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set the `loss_type="robust"` in the [`DPOConfig`]. -The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set the `loss_type="bco"` in the [`DPOConfig`]. +The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set the `loss_type="bco_pair"` in the [`DPOConfig`]. The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].