diff --git a/docs/source/bco_trainer.md b/docs/source/bco_trainer.md index 7bc4ccb60a5..ac105cf5513 100644 --- a/docs/source/bco_trainer.md +++ b/docs/source/bco_trainer.md @@ -9,7 +9,7 @@ For a full example have a look at [`examples/scripts/bco.py`]. ## Expected dataset type The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference). -The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Expected model format The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. diff --git a/docs/source/cpo_trainer.md b/docs/source/cpo_trainer.md index 5d6961f0c8f..ab182395d42 100644 --- a/docs/source/cpo_trainer.md +++ b/docs/source/cpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. +Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. @@ -44,7 +44,7 @@ accelerate launch train_cpo.py ## Expected dataset type -CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. 
+CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Example script @@ -62,7 +62,7 @@ accelerate launch examples/scripts/cpo.py \ ## Logged metrics -While training and evaluating we record the following reward metrics: +While training and evaluating, we record the following reward metrics: * `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta * `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta @@ -74,28 +74,46 @@ While training and evaluating we record the following reward metrics: ### Simple Preference Optimization (SimPO) -The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`]. +[Simple Preference Optimization](https://huggingface.co/papers/2405.14734) (SimPO) by [Yu Meng](https://huggingface.co/yumeng5), [Mengzhou Xia](https://huggingface.co/mengzhouxia), and [Danqi Chen](https://huggingface.co/cdq10131) proposes a simpler and more effective preference optimization algorithm than DPO without using a reference model. The key designs in SimPO are (1) using length-normalized log likelihood as the implicit reward, and (2) incorporating a target reward margin in the Bradley-Terry ranking objective. The official code can be found at [princeton-nlp/SimPO](https://github.com/princeton-nlp/SimPO). + +The abstract from the paper is the following: + +> Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model. 
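As a minimal sketch of the objective summarized above (not the trainer's actual code path; the function name and default values are illustrative), SimPO applies a logistic loss to the difference of length-normalized log probabilities minus a target margin:

```python
import torch.nn.functional as F

def simpo_loss(avg_chosen_logps, avg_rejected_logps, beta=2.0, gamma=0.5):
    # beta scales the length-normalized (average per-token) log probabilities,
    # which act as the implicit reward; gamma is the target reward margin.
    margin = beta * (avg_chosen_logps - avg_rejected_logps) - gamma
    return -F.logsigmoid(margin).mean()
```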
+ +The SimPO loss is integrated into the [`CPOTrainer`] as an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`], and set `simpo_gamma` to a recommended value. ### CPO-SimPO We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`]. +### AlphaPO + +The AlphaPO method, introduced in [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi, is also implemented in the [`CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of the SimPO loss. The abstract from the paper is the following: + +> Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance. + +To use this loss as described in the paper, set `loss_type="alphapo"` in the [`CPOConfig`], which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, and set `alpha` and `simpo_gamma` to the recommended values. Alternatively, you can manually set `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma`. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value. + ## Loss functions The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`].
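For concreteness, here is a minimal sketch of the two configurations described above. The hyperparameter values are illustrative only (the AlphaPO values loosely follow the `paper_index.md` example later in this diff), and `output_dir` is a placeholder:

```python
from trl import CPOConfig

# SimPO: length-normalized implicit reward with a target margin, no BC regularization
simpo_args = CPOConfig(
    output_dir="cpo-simpo-sketch",
    loss_type="simpo",
    cpo_alpha=0.0,
    simpo_gamma=0.5,
)

# AlphaPO: syntactic sugar that resolves to loss_type="simpo" and cpo_alpha=0.0,
# plus a non-zero alpha that reshapes the reward
alphapo_args = CPOConfig(
    output_dir="cpo-alphapo-sketch",
    loss_type="alphapo",
    alpha=0.25,
    simpo_gamma=0.1,
)
```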
The following loss functions are supported: | `loss_type=` | Description | | -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | +| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | | `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | -| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). | +| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). | +| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. 
| +| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. | + + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. -This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +This option is enabled by setting `output_router_logits=True` in the model config (e.g., [`~transformers.MixtralConfig`]). To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. ## CPOTrainer diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 40987bbee2b..10e298027f9 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -213,7 +213,7 @@ For more detailed information on tool calling, refer to the [Tool Calling sectio The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include: -- **Developer role** – Provides high-level instructions (similar to a system prompt) and lists available tools. +- **Developer role** – Provides high level instructions (similar to a system prompt) and lists available tools. - **Channels** – Separate types of assistant output into distinct streams: - `analysis` – for internal reasoning, from the key `"thinking"` diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 5753b33695b..c2275a29608 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -118,7 +118,7 @@ accelerate launch trl/scripts/dpo.py \ ## Logged metrics -While training and evaluating we record the following reward metrics: +While training and evaluating, we record the following reward metrics: - `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta - `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 5db1b545471..a1637649ef7 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -149,6 +149,8 @@ This constant is recommended to be the maximum completion length. To use this fo ## Logged metrics +While training and evaluating, we record the following reward metrics: + - `num_tokens`: The total number of tokens processed so far, including both prompts and completions. - `completions/mean_length`: The average length of generated completions. - `completions/min_length`: The minimum length of generated completions. 
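Picking up the Mixture of Experts note from the `cpo_trainer.md` hunk above, here is a minimal sketch of how those two config fields are typically passed; the checkpoint name is only a placeholder for any Mixtral-style MoE model:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder MoE checkpoint
    output_router_logits=True,   # add the load-balancing auxiliary loss to the total loss
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss (the default value)
)
```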
diff --git a/docs/source/kto_trainer.md b/docs/source/kto_trainer.md index 1186e48f76f..16e2a7f012b 100644 --- a/docs/source/kto_trainer.md +++ b/docs/source/kto_trainer.md @@ -74,7 +74,7 @@ Here are some other factors to consider when choosing a programming language for KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. -The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. @@ -118,7 +118,7 @@ By default, they are both 1. However, if you have more of one or the other, then ## Logged metrics -While training and evaluating we record the following reward metrics: +While training and evaluating, we record the following reward metrics: - `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta - `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md index 3050af33073..82b1ab24a27 100644 --- a/docs/source/nash_md_trainer.md +++ b/docs/source/nash_md_trainer.md @@ -63,7 +63,7 @@ The best programming language depends on personal preference, the complexity of ## Expected dataset type -Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Usage tips @@ -132,7 +132,7 @@ python examples/scripts/nash_md.py \ ## Logged metrics -The logged metrics are as follows: +While training and evaluating, we record the following reward metrics: * `loss/kl`: The mean KL divergence between the model and reference data. * `objective/entropy`: The mean entropy of the model and reference data. 
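As a side note on the `kto_trainer.md` hunk above: the automatic paired-to-unpaired conversion it describes amounts to the following reshaping of dataset rows (a hand-written sketch with made-up strings, not the trainer's internal code):

```python
# One paired preference row ...
paired = {
    "prompt": "What color is the sky?",
    "chosen": "It appears blue because of Rayleigh scattering.",
    "rejected": "It is green.",
}

# ... becomes two unpaired rows with boolean labels.
unpaired = [
    {"prompt": paired["prompt"], "completion": paired["chosen"], "label": True},
    {"prompt": paired["prompt"], "completion": paired["rejected"], "label": False},
]
```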
diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 7081807b809..eac8ca550cc 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -65,7 +65,7 @@ The best programming language depends on your specific needs and priorities. Som ## Expected dataset type -Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Usage tips @@ -132,7 +132,7 @@ python examples/scripts/dpo_online.py \ ## Logged metrics -The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9) +While training and evaluating, we record the following metrics. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9). * `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model. * `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model. diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index cb2e2b886b0..ab9ef411642 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -109,7 +109,7 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## Logged metrics -While training and evaluating we record the following reward metrics: +While training and evaluating, we record the following reward metrics: - `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta - `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 448aea288cb..86955d53903 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -26,6 +26,26 @@ training_args = GRPOConfig( ) ``` +## AlphaPO -- Reward shape matters for LLM alignment + +**📜 Paper**: https://huggingface.co/papers/2501.03884 + +AlphaPO is a new Direct Alignment Algorithm (DAA) that leverages an alpha parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and over-optimization. To reproduce the paper's setting, use this configuration: + +```python +from trl import CPOConfig + +# Mistral-Instruct from Table 3 of the paper +training_args = CPOConfig( + loss_type="alphapo", + alpha=0.25, + beta=2.5, + simpo_gamma=0.1, + learning_rate=7e-7, + ...
+) +``` + ## EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes **📜 Paper**: https://huggingface.co/papers/2508.00180 diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index b2ceef53bd6..94a52b9b9f1 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -112,6 +112,8 @@ Padding tokens (if present) are ignored in the loss computation by applying an i ## Logged metrics +While training and evaluating we record the following reward metrics: + * `global_step`: The total number of optimizer steps taken so far. * `epoch`: The current epoch number, based on dataset iteration. * `num_tokens`: The total number of tokens processed so far. diff --git a/docs/source/xpo_trainer.md b/docs/source/xpo_trainer.md index 72aab3a016c..61499d768a4 100644 --- a/docs/source/xpo_trainer.md +++ b/docs/source/xpo_trainer.md @@ -131,7 +131,7 @@ python examples/scripts/xpo.py \ ## Logged metrics -The logged metrics are as follows: +While training and evaluating we record the following reward metrics: * `loss/xpo`: The mean xpo part of the full loss. * `loss/dpo`: The mean dpo part of the full loss. diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index 7ae92169dbb..cc3e394846d 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -153,10 +153,6 @@ def test_cpo_trainer_with_lora(self, config_name): self.assertFalse(torch.equal(param, new_param)) def test_compute_metrics(self): - model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") - tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") - tokenizer.pad_token = tokenizer.eos_token - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") def dummy_compute_metrics(*args, **kwargs): @@ -174,9 +170,9 @@ def dummy_compute_metrics(*args, **kwargs): ) trainer = CPOTrainer( - model=model, + model=self.model, args=training_args, - processing_class=tokenizer, + processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], compute_metrics=dummy_compute_metrics, @@ -185,3 +181,40 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + + def test_alphapo_trainer(self): + training_args = CPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + loss_type="alphapo", + alpha=0.5, + simpo_gamma=0.5, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = CPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py index 97dd314245a..70f5c554f21 100644 --- a/trl/trainer/cpo_config.py +++ b/trl/trainer/cpo_config.py @@ -54,6 +54,8 @@ class CPOConfig(TrainingArguments): [SLiC](https://huggingface.co/papers/2305.10425) paper. 
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper. + - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This + automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model. @@ -61,6 +63,11 @@ class CPOConfig(TrainingArguments): Weight of the BC regularizer in CPO training. simpo_gamma (`float`, *optional*, defaults to `0.5`): Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`. + alpha (`float`, *optional*, defaults to `0.0`): + Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses + standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha)) + / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all + loss types. label_pad_token_id (`int`, *optional*, defaults to `-100`): Label pad token id. This argument is required if you want to use the default data collator. padding_value (`int` or `None`, *optional*, defaults to `None`): @@ -142,7 +149,7 @@ class CPOConfig(TrainingArguments): default="sigmoid", metadata={ "help": "Type of loss to use.", - "choices": ["sigmoid", "hinge", "ipo", "simpo"], + "choices": ["sigmoid", "hinge", "ipo", "simpo", "alphapo"], }, ) disable_dropout: bool = field( @@ -157,6 +164,14 @@ class CPOConfig(TrainingArguments): default=0.5, metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."}, ) + alpha: float = field( + default=0.0, + metadata={ + "help": "Alpha parameter that controls reward function shape across all loss types. When alpha=0 " + "(default), uses standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: " + "`r = (1 - p^(-alpha)) / alpha` from the AlphaPO paper. This parameter works with all loss types." + }, + ) label_pad_token_id: int = field( default=-100, metadata={"help": "Label pad token id."}, @@ -195,4 +210,9 @@ class CPOConfig(TrainingArguments): def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + # Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0 + if self.loss_type == "alphapo": + self.loss_type = "simpo" + self.cpo_alpha = 0.0 + super().__post_init__() diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 6ce4ada6116..2f0a0847331 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -319,6 +319,9 @@ def make_inputs_require_grad(module, input, output): if args.loss_type == "simpo": self.simpo_gamma = args.simpo_gamma + # AlphaPO parameter for reward shaping + self.alpha = args.alpha + self._stored_metrics = defaultdict(lambda: defaultdict(list)) # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the @@ -659,7 +662,20 @@ def cpo_loss( loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 
""" - logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device) + # Apply AlphaPO reward transformation if alpha != 0 + if self.alpha != 0.0: + # Compute probabilities + chosen_probs = torch.exp(policy_chosen_logps) + rejected_probs = torch.exp(policy_rejected_logps) + + # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha + policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha + policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha + + logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device) + else: + # Standard log probability rewards when alpha = 0 + logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device) # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and @@ -689,8 +705,15 @@ def cpo_loss( f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']" ) - chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() - rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + # Calculate rewards for logging + if self.alpha != 0.0: + # When using AlphaPO transformation, use the transformed rewards + chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach() + rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach() + else: + # Standard log probability rewards + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() return losses, chosen_rewards, rejected_rewards