diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md
index 6467548d8ea..b5ac187f744 100644
--- a/docs/source/paper_index.md
+++ b/docs/source/paper_index.md
@@ -232,6 +232,28 @@ trainer = PAPOTrainer(
 )
 ```
 
+### The Art of Scaling Reinforcement Learning
+
+**📜 Paper**: https://huggingface.co/papers/2510.13786
+
+A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.
+
+You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following config:
+
+```python
+from trl import GRPOConfig
+
+config = GRPOConfig(
+    loss_type="cispo",
+    epsilon_high=5.0,
+    num_generations=16,
+    scale_rewards="batch",
+    cast_lm_head_to_fp32=True
+)
+```
+
+
 ## Direct Policy Optimization
 
 Papers relating to the [`DPOTrainer`]
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 88d2579a69d..a3bf80a641b 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -167,7 +167,7 @@ def test_training(self, config_name):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
-    @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo"])
+    @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"])
     def test_training_loss_types(self, loss_type):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
 
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 6001a8dc524..260a73ab164 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -166,6 +166,8 @@ class GRPOConfig(TrainingArguments):
         epsilon_high (`float`, *optional*):
             Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
             specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
+            When used with `loss_type='cispo'`, this corresponds to the ε_max parameter from the
+            [ScaleRL paper](https://huggingface.co/papers/2510.13786), and the recommended value is `5.0`.
         importance_sampling_level (`str`, *optional*, defaults to `"token"`):
             Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
             keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
@@ -201,6 +203,10 @@ class GRPOConfig(TrainingArguments):
              batch. Note that normalization is performed over the local batch only, so results may slightly vary
              depending on the local batch size, despite a constant effective batch size. When using
              `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
+            - `"cispo"`: Clips the importance sampling weights instead of the advantage-scaled importance weights.
+              The clipped weights are then multiplied with the advantages and the policy model's log probs. Individual
+              token losses are aggregated by normalizing with the number of active tokens in the global accumulated
+              batch. This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
         mask_truncated_completions (`bool`, *optional*, defaults to `False`):
             When enabled, truncated completions are excluded from the loss calculation, preventing them from being
             incorrectly penalized and introducing noise during training. According to the
@@ -533,7 +539,9 @@ class GRPOConfig(TrainingArguments):
         default=None,
         metadata={
             "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
-            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
+            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. "
+            "When used with `loss_type='cispo'`, this corresponds to the ε_max parameter from the "
+            "[ScaleRL paper](https://huggingface.co/papers/2510.13786), and the recommended value is `5.0`."
         },
     )
     importance_sampling_level: str = field(
@@ -582,6 +590,11 @@ class GRPOConfig(TrainingArguments):
             "Note that normalization is performed over the local batch only, so results may slightly vary depending "
            "on the local batch size, despite a constant effective batch size. When using "
             "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
+            " 'cispo': Clips the importance sampling weights instead of the advantage-scaled importance weights. "
+            "The clipped weights are then multiplied with the advantages and the policy model's log probs. "
+            "Individual token losses are aggregated by normalizing with the number of active tokens in "
+            "the global accumulated batch. This method was introduced in the "
+            "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
         },
     )
     mask_truncated_completions: bool = field(
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index fe29c78b92a..bddef14adfc 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs):
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
                 "and 'sequence'."
             )
 
+
+        coef_1 = torch.exp(log_importance_weights)
+
         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        if self.loss_type == "cispo":
+            clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
+            per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
 
-        coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        else:
+            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+            # Two-sided clipping
+            if self.args.delta is not None:
+                coef_1 = torch.clamp(coef_1, max=self.args.delta)
 
-        # Two-sided clipping
-        if self.args.delta is not None:
-            coef_1 = torch.clamp(coef_1, max=self.args.delta)
+            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
 
-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs):
         elif self.loss_type == "dr_grpo":
             loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
             loss = loss / self.current_gradient_accumulation_steps
-        elif self.loss_type == "dapo":
+        elif self.loss_type in ["cispo", "dapo"]:
             normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
             loss = (per_token_loss * completion_mask).sum() / normalizer
         else:
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x):
             mean_entropy = masked_batch_mean(entropies)
             self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
 
-        # Compute the clipped probability ratios
-        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
-        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
-        is_region_clipped = is_low_clipped | is_high_clipped
-
-        low_clip = masked_batch_mean(is_low_clipped.float())
-        high_clip = masked_batch_mean(is_high_clipped.float())
-        clip_ratio = masked_batch_mean(is_region_clipped.float())
-
-        gathered_low_clip = self.accelerator.gather(low_clip)
-        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
-        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
-        gathered_high_clip = self.accelerator.gather(high_clip)
-        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
-        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
-        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
-        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        if self.loss_type != "cispo":
+            # Compute the clipped probability ratios
+            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
+            is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+            is_region_clipped = is_low_clipped | is_high_clipped
+
+            low_clip = masked_batch_mean(is_low_clipped.float())
+            high_clip = masked_batch_mean(is_high_clipped.float())
+            clip_ratio = masked_batch_mean(is_region_clipped.float())
+
+            gathered_low_clip = self.accelerator.gather(low_clip)
+            self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
+            gathered_high_clip = self.accelerator.gather(high_clip)
+            self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
+            gathered_clip_ratio = self.accelerator.gather(clip_ratio)
+            self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        else:
+            is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
+            gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
+            self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
+
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):