22 changes: 22 additions & 0 deletions docs/source/paper_index.md
@@ -232,6 +232,28 @@ trainer = PAPOTrainer(
)
```

### The Art of Scaling Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2510.13786

A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.

You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following configuration (a fuller usage sketch follows the snippet):

```python
from trl import GRPOConfig

config = GRPOConfig(
loss_type="cispo",
epsilon_high=5.0,
    num_generations=16,
scale_rewards="batch",
cast_lm_head_to_fp32=True
)
```
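The snippet above stops at the config. As a rough usage sketch, it can be passed to [`GRPOTrainer`] like any other GRPO run; the model name, dataset, and reward function below are illustrative placeholders and are not part of the ScaleRL recipe:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Placeholder dataset and reward function, only to make the example runnable
dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 100 characters
    return [-abs(100 - len(completion)) for completion in completions]

config = GRPOConfig(
    output_dir="scalerl-grpo",
    loss_type="cispo",
    epsilon_high=5.0,
    num_generations=16,
    per_device_train_batch_size=16,  # effective batch size must be divisible by num_generations
    scale_rewards="batch",
    cast_lm_head_to_fp32=True,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    reward_funcs=reward_len,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```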



## Direct Policy Optimization

Papers relating to the [`DPOTrainer`]
2 changes: 1 addition & 1 deletion tests/test_grpo_trainer.py
@@ -167,7 +167,7 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

15 changes: 14 additions & 1 deletion trl/trainer/grpo_config.py
@@ -166,6 +166,8 @@ class GRPOConfig(TrainingArguments):
epsilon_high (`float`, *optional*):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
When used with `loss_type='cispo'`, this corresponds to the ε_max parameter in the
[ScaleRL paper](https://arxiv.org/pdf/2510.13786); the recommended value is `5.0`.
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
@@ -201,6 +203,10 @@ class GRPOConfig(TrainingArguments):
batch. Note that normalization is performed over the local batch only, so results may slightly vary
depending on the local batch size, despite a constant effective batch size. When using
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
- `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The
clipped weights are then multiplied with the advantages and policy model's log probs. Individual token
losses are aggregated by normalizing with the number of active tokens in the global accumulated batch.
This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
@@ -533,7 +539,9 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. "
"When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the"
"[ScaleRL paper]https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`."
},
)
importance_sampling_level: str = field(
Expand Down Expand Up @@ -582,6 +590,11 @@ class GRPOConfig(TrainingArguments):
"Note that normalization is performed over the local batch only, so results may slightly vary depending "
"on the local batch size, despite a constant effective batch size. When using "
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
"'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. "
"The clipped weights are then multiplied with the advantages and policy model's log probs. "
"Individual token losses are aggregated by normalizing with the number of active tokens in "
"the global accumulated batch. This method was introduced in the "
"[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
},
)
mask_truncated_completions: bool = field(
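To make the two clipping schemes described in the `loss_type` help above concrete, here is a small standalone comparison on toy tensors. It is illustrative only, not code from the trainer, and the numeric values are made up:

```python
import torch

# Toy per-token values: importance ratios pi_theta / pi_old, advantages, and current log-probs
coef = torch.tensor([0.6, 1.0, 1.8, 6.0])
advantages = torch.tensor([1.0, 1.0, -1.0, 1.0])
per_token_logps = torch.tensor([-0.5, -0.2, -1.0, -0.1])
epsilon_low, epsilon_high = 0.2, 5.0  # epsilon_high=5.0 is the CISPO/ScaleRL recommendation

# GRPO/DAPO-style objective: clip the advantage-scaled ratio via the min of two surrogates
coef_clipped = torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
ppo_style_loss = -torch.min(coef * advantages, coef_clipped * advantages)

# CISPO-style objective: clip the importance ratio itself (used as a constant weight),
# then scale the advantage-weighted log-probabilities with it
cispo_weights = torch.clamp(coef, max=epsilon_high)
cispo_loss = -cispo_weights * advantages * per_token_logps

print(ppo_style_loss)
print(cispo_loss)
```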
65 changes: 39 additions & 26 deletions trl/trainer/grpo_trainer.py
@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs):
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
"and 'sequence'."
)

coef_1 = torch.exp(log_importance_weights)

# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
    coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

    # Two-sided clipping
    if self.args.delta is not None:
        coef_1 = torch.clamp(coef_1, max=self.args.delta)

    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
else:
    clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
    per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps

if entropy_mask is not None:
    per_token_loss = per_token_loss * entropy_mask

Review comment (Member): maybe in the documentation of `epsilon_high` we can mention that this is the value used for epsilon_max when used with the CISPO loss, and that the paper recommends `5.0`.

@@ -1847,7 +1853,7 @@
elif self.loss_type == "dr_grpo":
loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
loss = loss / self.current_gradient_accumulation_steps
elif self.loss_type == "dapo":
elif self.loss_type in ["cispo", "dapo"]:
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
else:
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x):
mean_entropy = masked_batch_mean(entropies)
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

if self.loss_type != "cispo":
    # Compute the clipped probability ratios
    is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
    is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
    is_region_clipped = is_low_clipped | is_high_clipped

    low_clip = masked_batch_mean(is_low_clipped.float())
    high_clip = masked_batch_mean(is_high_clipped.float())
    clip_ratio = masked_batch_mean(is_region_clipped.float())

    gathered_low_clip = self.accelerator.gather(low_clip)
    self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
    self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
    gathered_high_clip = self.accelerator.gather(high_clip)
    self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
    self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
    gathered_clip_ratio = self.accelerator.gather(clip_ratio)
    self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
elif self.loss_type == "cispo":
    is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0)
    cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
    gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
    self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())

Review comment (Member): nit, again (explicit better than implicit). Suggested change: replace `if self.loss_type != "cispo":` with `if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:`.

return loss

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
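For readers tracing the `cispo`/`dapo` aggregation branch and the new `cispo_clip_ratio` metric above, here is a rough single-process sketch with toy tensors. In the trainer, `num_items_in_batch` covers the globally accumulated batch and is divided by the number of processes; on a single process it reduces to the completion-mask sum used here:

```python
import torch

# Toy batch: 2 completions x 5 token positions, right-padded
per_token_loss = torch.tensor([[0.50, 0.20, 0.10, 0.00, 0.00],
                               [0.40, 0.30, 0.20, 0.10, 0.00]])
completion_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 0]], dtype=torch.float32)

# "cispo"/"dapo" aggregation: sum of masked token losses divided by the number of active tokens
num_items_in_batch = completion_mask.sum()  # 7 active tokens on a single process
loss = (per_token_loss * completion_mask).sum() / num_items_in_batch

# cispo_clip_ratio: fraction of active tokens whose ratio exceeds epsilon_high with a positive advantage
coef_1 = torch.tensor([[0.9, 1.2, 6.0, 1.0, 1.0],
                       [0.8, 7.5, 1.1, 0.7, 1.0]])
advantages = torch.tensor([1.0, -1.0])
epsilon_high = 5.0
is_cispo_clipped = (coef_1 > epsilon_high) & (advantages.unsqueeze(1) > 0)
cispo_clip_ratio = (is_cispo_clipped.float() * completion_mask).sum() / completion_mask.sum()

print(loss.item(), cispo_clip_ratio.item())
```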