2 changes: 1 addition & 1 deletion docs/source/bco_trainer.md
@@ -9,7 +9,7 @@ For a full example have a look at [`examples/scripts/bco.py`].
## Expected dataset type

The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Expected model format
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
32 changes: 25 additions & 7 deletions docs/source/cpo_trainer.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -213,7 +213,7 @@ For more detailed information on tool calling, refer to the [Tool Calling sectio

The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include:

- **Developer role** – Provides high-level instructions (similar to a system prompt) and lists available tools.
- **Developer role** – Provides high level instructions (similar to a system prompt) and lists available tools.
- **Channels** – Separate types of assistant output into distinct streams:

- `analysis` – for internal reasoning, from the key `"thinking"`
2 changes: 1 addition & 1 deletion docs/source/dpo_trainer.md
@@ -118,7 +118,7 @@ accelerate launch trl/scripts/dpo.py \

## Logged metrics

While training and evaluating we record the following reward metrics:
While training and evaluating, we record the following reward metrics:

- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
2 changes: 2 additions & 0 deletions docs/source/grpo_trainer.md
@@ -149,6 +149,8 @@ This constant is recommended to be the maximum completion length. To use this fo

## Logged metrics

While training and evaluating, we record the following reward metrics:

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
4 changes: 2 additions & 2 deletions docs/source/kto_trainer.md
@@ -74,7 +74,7 @@ Here are some other factors to consider when choosing a programming language for

KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.

The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.

@@ -118,7 +118,7 @@ By default, they are both 1. However, if you have more of one or the other, then

## Logged metrics

While training and evaluating we record the following reward metrics:
While training and evaluating, we record the following reward metrics:

- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
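As an aside on the paired-to-unpaired conversion mentioned in the KTO hunk above: each preference row is split into two unpaired rows. A minimal sketch of that mapping (plain Python dicts with illustrative values, not the trainer's internal code):

```python
# One paired preference example ...
paired = {"prompt": "What color is the sky?", "chosen": "Blue.", "rejected": "Green."}

# ... becomes two unpaired examples, with `label=True` for the chosen
# completion and `label=False` for the rejected one.
unpaired = [
    {"prompt": paired["prompt"], "completion": paired["chosen"], "label": True},
    {"prompt": paired["prompt"], "completion": paired["rejected"], "label": False},
]
print(unpaired)
```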
4 changes: 2 additions & 2 deletions docs/source/nash_md_trainer.md
@@ -63,7 +63,7 @@ The best programming language depends on personal preference, the complexity of

## Expected dataset type

Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

@@ -132,7 +132,7 @@ python examples/scripts/nash_md.py \

## Logged metrics

The logged metrics are as follows:
While training and evaluating, we record the following reward metrics:

* `loss/kl`: The mean KL divergence between the model and reference data.
* `objective/entropy`: The mean entropy of the model and reference data.
4 changes: 2 additions & 2 deletions docs/source/online_dpo_trainer.md
@@ -65,7 +65,7 @@ The best programming language depends on your specific needs and priorities. Som

## Expected dataset type

Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

@@ -132,7 +132,7 @@ python examples/scripts/dpo_online.py \

## Logged metrics

The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
While training and evaluating, we record the following reward metrics. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)

* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
2 changes: 1 addition & 1 deletion docs/source/orpo_trainer.md
@@ -109,7 +109,7 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype

## Logged metrics

While training and evaluating we record the following reward metrics:
While training and evaluating, we record the following reward metrics:

- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
20 changes: 20 additions & 0 deletions docs/source/paper_index.md
@@ -26,6 +26,26 @@ training_args = GRPOConfig(
)
```

## AlphaPO -- Reward shape matters for LLM alignment

**📜 Paper**: https://huggingface.co/papers/2501.03884

AlphaPO is a new Direct Alignment Algorithm (DAA) that introduces an alpha parameter to reshape the reward function beyond the standard log reward, helping maintain fine-grained control over likelihood displacement and over-optimization. To reproduce the paper's setting, use this configuration:

```python
from trl import CPOConfig

# Mistral-Instruct from Table 3 of the paper
training_args = CPOConfig(
loss_type="alphapo",
alpha=0.25,
beta=2.5,
simpo_gamma=0.1,
learning_rate=7e-7,
...
)
```
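
To make the role of `alpha` concrete, here is a minimal standalone sketch of the reward shaping in plain PyTorch (the helper name `alphapo_reward` is invented for this illustration and is not a TRL API):

```python
import torch

def alphapo_reward(logps: torch.Tensor, alpha: float) -> torch.Tensor:
    """AlphaPO reward r = (1 - p^(-alpha)) / alpha.

    The standard log-probability reward is recovered in the limit alpha -> 0.
    """
    if alpha == 0.0:
        return logps  # log p is the alpha -> 0 limit of the expression below
    probs = torch.exp(logps)
    return (1 - probs.pow(-alpha)) / alpha

seq_logps = torch.log(torch.tensor([0.2, 0.5, 0.9]))
print(alphapo_reward(seq_logps, alpha=0.25))  # shaped rewards
print(alphapo_reward(seq_logps, alpha=0.0))   # plain log-probabilities
```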

## EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes

**📜 Paper**: https://huggingface.co/papers/2508.00180
2 changes: 2 additions & 0 deletions docs/source/sft_trainer.md
@@ -112,6 +112,8 @@ Padding tokens (if present) are ignored in the loss computation by applying an i

## Logged metrics

While training and evaluating, we record the following metrics:

* `global_step`: The total number of optimizer steps taken so far.
* `epoch`: The current epoch number, based on dataset iteration.
* `num_tokens`: The total number of tokens processed so far.
2 changes: 1 addition & 1 deletion docs/source/xpo_trainer.md
@@ -131,7 +131,7 @@ python examples/scripts/xpo.py \

## Logged metrics

The logged metrics are as follows:
While training and evaluating, we record the following reward metrics:

* `loss/xpo`: The mean xpo part of the full loss.
* `loss/dpo`: The mean dpo part of the full loss.
45 changes: 39 additions & 6 deletions tests/test_cpo_trainer.py
@@ -153,10 +153,6 @@ def test_cpo_trainer_with_lora(self, config_name):
self.assertFalse(torch.equal(param, new_param))

def test_compute_metrics(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

def dummy_compute_metrics(*args, **kwargs):
@@ -174,9 +170,9 @@ def dummy_compute_metrics(*args, **kwargs):
)

trainer = CPOTrainer(
model=model,
model=self.model,
args=training_args,
processing_class=tokenizer,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
compute_metrics=dummy_compute_metrics,
@@ -185,3 +181,40 @@ def dummy_compute_metrics(*args, **kwargs):
trainer.train()

self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

def test_alphapo_trainer(self):
training_args = CPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
loss_type="alphapo",
alpha=0.5,
simpo_gamma=0.5,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = CPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
22 changes: 21 additions & 1 deletion trl/trainer/cpo_config.py
@@ -54,13 +54,20 @@ class CPOConfig(TrainingArguments):
[SLiC](https://huggingface.co/papers/2305.10425) paper.
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
- `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.

disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
cpo_alpha (`float`, *optional*, defaults to `1.0`):
Weight of the BC regularizer in CPO training.
simpo_gamma (`float`, *optional*, defaults to `0.5`):
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
alpha (`float`, *optional*, defaults to `0.0`):
Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
/ alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
loss types.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int` or `None`, *optional*, defaults to `None`):
@@ -142,7 +149,7 @@ class CPOConfig(TrainingArguments):
default="sigmoid",
metadata={
"help": "Type of loss to use.",
"choices": ["sigmoid", "hinge", "ipo", "simpo"],
"choices": ["sigmoid", "hinge", "ipo", "simpo", "alphapo"],
},
)
disable_dropout: bool = field(
Expand All @@ -157,6 +164,14 @@ class CPOConfig(TrainingArguments):
default=0.5,
metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."},
)
alpha: float = field(
default=0.0,
metadata={
"help": "Alpha parameter that controls reward function shape across all loss types. When alpha=0 "
"(default), uses standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: "
"`r = (1 - p^(-alpha)) / alpha` from the AlphaPO paper. This parameter works with all loss types."
},
)
label_pad_token_id: int = field(
default=-100,
metadata={"help": "Label pad token id."},
@@ -195,4 +210,9 @@ class CPOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0
if self.loss_type == "alphapo":
self.loss_type = "simpo"
self.cpo_alpha = 0.0

super().__post_init__()
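
A minimal usage sketch of this syntactic sugar, assuming a TRL build that includes this change (the hyperparameter values are taken from the Mistral-Instruct example in `paper_index.md` above):

```python
from trl import CPOConfig

# "alphapo" is resolved in __post_init__: loss_type becomes "simpo" and
# cpo_alpha is forced to 0.0, while alpha and simpo_gamma keep their values.
cfg = CPOConfig(
    output_dir="cpo-alphapo-demo",
    loss_type="alphapo",
    alpha=0.25,
    beta=2.5,
    simpo_gamma=0.1,
)
print(cfg.loss_type, cfg.cpo_alpha, cfg.alpha)  # expected: simpo 0.0 0.25
```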
29 changes: 26 additions & 3 deletions trl/trainer/cpo_trainer.py
@@ -319,6 +319,9 @@ def make_inputs_require_grad(module, input, output):
if args.loss_type == "simpo":
self.simpo_gamma = args.simpo_gamma

# AlphaPO parameter for reward shaping
self.alpha = args.alpha

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
@@ -659,7 +662,20 @@ def cpo_loss(
loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
the chosen and rejected responses, respectively.
"""
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
# Apply AlphaPO reward transformation if alpha != 0
if self.alpha != 0.0:
# Compute probabilities
chosen_probs = torch.exp(policy_chosen_logps)
rejected_probs = torch.exp(policy_rejected_logps)

# Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha

logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
else:
# Standard log probability rewards when alpha = 0
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)

# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
@@ -689,8 +705,15 @@ def cpo_loss(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
)

chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
# Calculate rewards for logging
if self.alpha != 0.0:
# When using AlphaPO transformation, use the transformed rewards
chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
else:
# Standard log probability rewards
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()

return losses, chosen_rewards, rejected_rewards

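Viewed in isolation, the reward-shaping branch added to `cpo_loss` reduces to the following (a minimal sketch on made-up log-probability tensors, outside the trainer):

```python
import torch

beta, alpha = 0.1, 0.25
policy_chosen_logps = torch.tensor([-4.0, -6.0])     # dummy sequence log-probs
policy_rejected_logps = torch.tensor([-7.0, -9.0])

if alpha != 0.0:
    # AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
    chosen_rewards = (1 - torch.exp(policy_chosen_logps).pow(-alpha)) / alpha
    rejected_rewards = (1 - torch.exp(policy_rejected_logps).pow(-alpha)) / alpha
else:
    # Standard log-probability rewards when alpha = 0
    chosen_rewards = policy_chosen_logps
    rejected_rewards = policy_rejected_logps

# The preference logits fed to the sigmoid/hinge/IPO/SimPO losses
logits = chosen_rewards - rejected_rewards
print(beta * logits)
```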