Commit 642b721

ScaleRL: Add CISPO Loss (#4495)
1 parent 32e9c9f commit 642b721

File tree: 4 files changed, +76 -28 lines

docs/source/paper_index.md

Lines changed: 22 additions & 0 deletions

@@ -232,6 +232,28 @@ trainer = PAPOTrainer(
 )
 ```
 
+### The Art of Scaling Reinforcement Learning
+
+**📜 Paper**: https://huggingface.co/papers/2510.13786
+
+A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.
+
+You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following config:
+
+```python
+from trl import GRPOConfig
+
+config = GRPOConfig(
+    loss_type="cispo",
+    epsilon_high=5.0,
+    num_completions=16,
+    scale_rewards="batch",
+    cast_lm_head_to_fp32=True
+)
+```
+
+
 ## Direct Policy Optimization
 
 Papers relating to the [`DPOTrainer`]
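
To see how this recipe could be wired into an actual training run, here is a minimal sketch (not part of this commit): the model checkpoint, dataset, and toy reward function below are placeholders, and the remaining ScaleRL settings from the snippet above would be passed to `GRPOConfig` in the same way.

```python
# Illustrative sketch only (not part of this commit). The model, dataset, and
# toy reward function are placeholders chosen so the example runs end to end.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 100 characters.
    return [-abs(100 - len(c)) for c in completions]

config = GRPOConfig(
    output_dir="scalerl-cispo-demo",
    loss_type="cispo",   # CISPO loss added in this commit
    epsilon_high=5.0,    # ε_max value recommended by the ScaleRL paper
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```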

tests/test_grpo_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -167,7 +167,7 @@ def test_training(self, config_name):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
-    @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo"])
+    @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"])
     def test_training_loss_types(self, loss_type):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
trl/trainer/grpo_config.py

Lines changed: 14 additions & 1 deletion

@@ -166,6 +166,8 @@ class GRPOConfig(TrainingArguments):
         epsilon_high (`float`, *optional*):
             Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
             specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
+            When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the [ScaleRL
+            paper](https://arxiv.org/pdf/2510.13786), and the recommended value is `5.0`.
         importance_sampling_level (`str`, *optional*, defaults to `"token"`):
             Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
             keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
@@ -201,6 +203,10 @@ class GRPOConfig(TrainingArguments):
               batch. Note that normalization is performed over the local batch only, so results may slightly vary
               depending on the local batch size, despite a constant effective batch size. When using
               `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
+            - `"cispo"`: Clips the importance sampling weights instead of the advantage-scaled importance weights. The
+              clipped weights are then multiplied by the advantages and the policy model's log probs. Individual token
+              losses are aggregated by normalizing with the number of active tokens in the global accumulated batch.
+              This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
         mask_truncated_completions (`bool`, *optional*, defaults to `False`):
             When enabled, truncated completions are excluded from the loss calculation, preventing them from being
             incorrectly penalized and introducing noise during training. According to the
@@ -533,7 +539,9 @@ class GRPOConfig(TrainingArguments):
         default=None,
         metadata={
             "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
-            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
+            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. "
+            "When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the "
+            "[ScaleRL paper](https://huggingface.co/papers/2510.13786), and the recommended value is `5.0`."
         },
     )
     importance_sampling_level: str = field(
@@ -582,6 +590,11 @@ class GRPOConfig(TrainingArguments):
             "Note that normalization is performed over the local batch only, so results may slightly vary depending "
             "on the local batch size, despite a constant effective batch size. When using "
             "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
+            " 'cispo': Clips the importance sampling weights instead of the advantage-scaled importance weights. "
+            "The clipped weights are then multiplied by the advantages and the policy model's log probs. "
+            "Individual token losses are aggregated by normalizing with the number of active tokens in "
+            "the global accumulated batch. This method was introduced in the "
+            "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
         },
     )
     mask_truncated_completions: bool = field(
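
As a side note on the two clipping conventions these docstrings describe, the following toy snippet (not from this commit) contrasts the PPO/DAPO-style interval clip to `[1 - epsilon_low, 1 + epsilon_high]` with CISPO's one-sided bound, where `epsilon_high` is the bound `ε_max` itself.

```python
# Sketch (not part of this commit) contrasting the two uses of `epsilon_high`
# on a toy tensor of importance-sampling ratios.
import torch

ratios = torch.tensor([0.2, 0.9, 1.5, 4.0, 8.0])

# PPO/DAPO-style two-sided clipping: ratios clipped to [1 - eps_low, 1 + eps_high].
eps_low, eps_high_dapo = 0.2, 0.28
ppo_clipped = torch.clamp(ratios, 1 - eps_low, 1 + eps_high_dapo)

# CISPO-style clipping: only an upper bound, and epsilon_high is the bound itself
# (ε_max), so the recommended 5.0 leaves much more of the ratio distribution intact.
eps_max = 5.0
cispo_clipped = torch.clamp(ratios, max=eps_max)

print(ppo_clipped)    # tensor([0.8000, 0.9000, 1.2800, 1.2800, 1.2800])
print(cispo_clipped)  # tensor([0.2000, 0.9000, 1.5000, 4.0000, 5.0000])
```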

trl/trainer/grpo_trainer.py

Lines changed: 39 additions & 26 deletions

@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs):
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
                 "and 'sequence'."
             )
+
+        coef_1 = torch.exp(log_importance_weights)
+
         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        if self.loss_type == "cispo":
+            clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
+            per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
 
-        coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        else:
+            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+            # Two-sided clipping
+            if self.args.delta is not None:
+                coef_1 = torch.clamp(coef_1, max=self.args.delta)
 
-        # Two-sided clipping
-        if self.args.delta is not None:
-            coef_1 = torch.clamp(coef_1, max=self.args.delta)
+            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
 
-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
 
@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs):
         elif self.loss_type == "dr_grpo":
             loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
             loss = loss / self.current_gradient_accumulation_steps
-        elif self.loss_type == "dapo":
+        elif self.loss_type in ["cispo", "dapo"]:
             normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
             loss = (per_token_loss * completion_mask).sum() / normalizer
         else:
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x):
             mean_entropy = masked_batch_mean(entropies)
             self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
 
-        # Compute the clipped probability ratios
-        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
-        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
-        is_region_clipped = is_low_clipped | is_high_clipped
-
-        low_clip = masked_batch_mean(is_low_clipped.float())
-        high_clip = masked_batch_mean(is_high_clipped.float())
-        clip_ratio = masked_batch_mean(is_region_clipped.float())
-
-        gathered_low_clip = self.accelerator.gather(low_clip)
-        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
-        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
-        gathered_high_clip = self.accelerator.gather(high_clip)
-        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
-        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
-        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
-        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        if self.loss_type != "cispo":
+            # Compute the clipped probability ratios
+            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
+            is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+            is_region_clipped = is_low_clipped | is_high_clipped
+
+            low_clip = masked_batch_mean(is_low_clipped.float())
+            high_clip = masked_batch_mean(is_high_clipped.float())
+            clip_ratio = masked_batch_mean(is_region_clipped.float())
+
+            gathered_low_clip = self.accelerator.gather(low_clip)
+            self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
+            gathered_high_clip = self.accelerator.gather(high_clip)
+            self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
+            gathered_clip_ratio = self.accelerator.gather(clip_ratio)
+            self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        elif self.loss_type == "cispo":
+            is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
+            gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
+            self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
+
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
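
For reference, here is a self-contained sketch of the CISPO math implemented above, using made-up tensors rather than the trainer's real inputs (the shapes, the construction of `old_per_token_logps`, and the use of the local batch in place of the globally accumulated `num_items_in_batch` are illustrative assumptions).

```python
# Self-contained sketch of the CISPO per-token loss and its DAPO-style
# normalization, using made-up tensors (not the trainer's actual inputs).
import torch

B, T = 2, 4
per_token_logps = torch.randn(B, T, requires_grad=True)            # current policy log-probs
old_per_token_logps = per_token_logps.detach() + 0.1 * torch.randn(B, T)
advantages = torch.tensor([0.7, -0.3])                              # one advantage per sequence
completion_mask = torch.ones(B, T)                                  # 1 for real completion tokens
epsilon_high = 5.0                                                  # ε_max from ScaleRL

# Importance-sampling ratios, clipped from above and detached: gradients flow
# only through the current policy's log-probs (REINFORCE-style), not the ratio.
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach()
per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps

# Aggregate like the "cispo"/"dapo" branch: divide by the number of active
# completion tokens (here the local batch stands in for the accumulated batch).
num_items_in_batch = completion_mask.sum()
loss = (per_token_loss * completion_mask).sum() / num_items_in_batch
loss.backward()
```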
