
Commit fe44806

🪶 [GRPO] PPO Lite: Scale rewards by Std of Batch (#3935)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 251c048 commit fe44806

File tree: 6 files changed, +94 −26 lines


docs/source/grpo_trainer.md

Lines changed: 10 additions & 1 deletion
@@ -80,6 +80,12 @@ It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Persp
 
 </Tip>
 
+<Tip>
+
+[Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221) showed that calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
+
+</Tip>
+
 ### Estimating the KL divergence
 
 KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
@@ -138,6 +144,7 @@
 \mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
 $$
 
+To use this formulation, set `loss_type="bnpo"` in [`GRPOConfig`]. Note that we do not reproduce the DAPO formulation exactly: when using gradient accumulation, the loss is computed over the total number of tokens in each batch, not over the accumulated batches. `loss_type="bnpo"` is equivalent to the original DAPO formulation only when `gradient_accumulation_steps=1`.
 
 Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
 
@@ -162,7 +169,9 @@ While training and evaluating, we record the following reward metrics:
 - `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
 - `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
 - `reward`: The overall average reward after applying reward weights.
-- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
+- `reward_std`: The standard deviation of rewards after applying reward weights.
+  - If `scale_rewards` is `"group"` or `"none"`, this is the average of the per-group standard deviations.
+  - If `scale_rewards` is `"batch"`, this is the standard deviation computed over all rewards in the batch (ignoring groups).
 - `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
 - `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
 - `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
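
To make the two `reward_std` conventions above concrete, here is a small sketch (not part of the commit; the toy `rewards` tensor and group size are made up) showing what gets logged under each `scale_rewards` setting:

```python
import torch

# Toy batch: 2 prompts x 4 completions each (num_generations = 4), flattened group-by-group.
rewards = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])
num_generations = 4

group_std = rewards.view(-1, num_generations).std(dim=1)  # one std per group
batch_std = rewards.std()                                 # single std over the whole batch

# scale_rewards="group" or "none": `reward_std` is the average of the per-group stds.
print(group_std.mean().item())
# scale_rewards="batch": `reward_std` is the batch-level std (groups ignored).
print(batch_std.item())
```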

docs/source/logging.md

Lines changed: 7 additions & 3 deletions
@@ -63,7 +63,8 @@ Here's a brief explanation for the logged metrics provided in the data for the G
 
 * `num_tokens`: Total number of input tokens processed during training so far.
 
-**Completions:**
+#### Completions
+
 * `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
 * `completions/min_length`: Minimum length among all generated completions.
 * `completions/max_length`: Maximum length among all generated completions.
@@ -72,13 +73,15 @@ Here's a brief explanation for the logged metrics provided in the data for the G
 * `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
 * `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
 
-**Rewards:**
+#### Rewards
+
 * `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
 * `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
 * `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
 * `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
 
-**Policy and Loss Metrics:**
+#### Policy and Loss Metrics
+
 * `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
 * `entropy`: Average entropy of token predictions across generated completions.
 * If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
@@ -91,6 +94,7 @@ Here's a brief explanation for the logged metrics provided in the data for the G
 * `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
 
 ### Crucial GRPO values
+
 During GRPO training, monitor these values for insights into performance and stability:
 
 1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.

docs/source/paper_index.md

Lines changed: 35 additions & 0 deletions
@@ -60,3 +60,38 @@ trainer = SFTTrainer(
     callbacks=[BEMACallback()],
 )
 ```
+
+## Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)
+
+**📜 Paper**: https://huggingface.co/papers/2508.08221
+
+The authors of this paper find that the combination of:
+
+1. scaling rewards by the standard deviation computed over the entire batch and
+2. aggregating loss over the total number of tokens
+
+can unlock the learning capability of critic-free policies using vanilla PPO loss. Their results demonstrate that this simple combination consistently improves performance, surpassing strategies like GRPO and [DAPO](https://huggingface.co/papers/2503.14476).
+
+TRL supports using these learnings to train a GRPO model by:
+
+```python
+from trl import GRPOConfig
+
+training_args = GRPOConfig(
+    ...
+    scale_rewards="batch",
+    loss_type="bnpo",
+    # Other parameters used
+    beta=0.0,  # = init_kl_coef in the paper
+    top_p=0.99,
+    top_k=100,
+    temperature=0.99,
+    num_generations=8,  # = num_return_sequences in the paper
+    num_iterations=1,  # = ppo_epochs in the paper
+    per_device_train_batch_size=4,
+    gradient_accumulation_steps=32,
+    steps_per_generation=8,  # (rollout_batch_size * num_return_sequences) / (per_device_train_batch_size * gradient_accumulation_steps)
+)
+```
+
+Note that when using gradient accumulation, the loss is aggregated over the total number of tokens in the batch, but not over the accumulated batches. For more details, see the [GRPO Trainer - Loss types](grpo_trainer#loss_types).
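
The `steps_per_generation` comment in the example relies on a rollout batch size that the diff does not spell out. As a sanity check only, the sketch below plugs a hypothetical `rollout_batch_size = 128` into the quoted formula together with the other values from the example:

```python
# Hypothetical rollout_batch_size; the remaining values come from the GRPOConfig example above.
rollout_batch_size = 128          # assumed, not stated in the commit
num_return_sequences = 8          # = num_generations in TRL
per_device_train_batch_size = 4
gradient_accumulation_steps = 32

steps_per_generation = (rollout_batch_size * num_return_sequences) // (
    per_device_train_batch_size * gradient_accumulation_steps
)
print(steps_per_generation)  # 8 with these assumed numbers
```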

tests/test_grpo_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -1267,7 +1267,8 @@ def test_training_vllm_with_additional_generation_kwargs(self):
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
-    def test_training_no_scale_rewards(self):
+    @parameterized.expand([(False,), ("group",), ("batch",), (True,), ("none",)])
+    def test_training_scale_rewards(self, scale_rewards):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
 
         training_args = GRPOConfig(
@@ -1276,7 +1277,7 @@ def test_training_no_scale_rewards(self):
             per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
-            scale_rewards=False,
+            scale_rewards=scale_rewards,
             report_to="none",
         )
         trainer = GRPOTrainer(

trl/trainer/grpo_config.py

Lines changed: 19 additions & 11 deletions
@@ -166,11 +166,16 @@ class GRPOConfig(TrainingArguments):
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
-        scale_rewards (`bool`, *optional*, defaults to `True`):
-            Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards
-            are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is
-            applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) recommends not scaling the rewards,
-            as scaling by the standard deviation introduces a question-level difficulty bias.
+        scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`):
+            Specifies the scaling strategy for rewards. Supported values are:
+
+            - `True` or `"group"` (default): rewards are scaled by the standard deviation within each group, ensuring
+              unit variance within a group.
+            - `"batch"`: rewards are scaled by the standard deviation across the entire batch, as recommended in the
+              [PPO Lite paper](https://huggingface.co/papers/2508.08221).
+            - `False` or `"none"`: no scaling is applied. The [Dr. GRPO
+              paper](https://huggingface.co/papers/2503.20783) recommends not scaling rewards, as scaling by the
+              standard deviation introduces a question-level difficulty bias.
         loss_type (`str`, *optional*, defaults to `"bnpo"`):
             Specifies the loss formulation to use. Supported values are:
 
@@ -496,13 +501,16 @@
             "rewards are weighted equally with weight `1.0`."
         },
     )
-    scale_rewards: bool = field(
-        default=True,
+    scale_rewards: str = field(
+        default="group",
         metadata={
-            "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
-            "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
-            "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
-            "deviation introduces a question-level difficulty bias."
+            "help": "Specifies the scaling strategy for rewards. Supported values are: "
+            "`True` or `'group'` (default): rewards are scaled by the standard deviation within each group, ensuring "
+            "unit variance within a group. "
+            "`'batch'`: rewards are scaled by the standard deviation across the entire batch, as recommended in the "
+            "PPO Lite paper. "
+            "`False` or `'none'`: no scaling is applied. The Dr. GRPO paper recommends not scaling rewards, as "
+            "scaling by the standard deviation introduces a question-level difficulty bias."
         },
     )
     loss_type: str = field(
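
For illustration only (not code from the diff): the docstring above implies that boolean values remain accepted as aliases of the new string values. A minimal sketch of that mapping, mirroring the normalization applied in the trainer change below:

```python
# `True`/`False` are kept for backward compatibility and normalized to the new string values.
def normalize_scale_rewards(scale_rewards):
    return {True: "group", False: "none"}.get(scale_rewards, scale_rewards)

assert normalize_scale_rewards(True) == "group"     # old scale_rewards=True
assert normalize_scale_rewards(False) == "none"     # old scale_rewards=False
assert normalize_scale_rewards("batch") == "batch"  # new batch-level scaling passes through
```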

trl/trainer/grpo_trainer.py

Lines changed: 20 additions & 9 deletions
@@ -644,7 +644,7 @@ def __init__(
         self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size  # only applies to colocation mode
         self.use_liger_loss = args.use_liger_loss
         self.loss_type = args.loss_type
-        self.scale_rewards = args.scale_rewards
+        self.scale_rewards = {True: "group", False: "none"}.get(args.scale_rewards, args.scale_rewards)
         self.importance_sampling_level = args.importance_sampling_level
         self.mask_truncated_completions = args.mask_truncated_completions
         self.top_entropy_quantile = args.top_entropy_quantile
@@ -1677,15 +1677,26 @@ def _generate_and_score_completions(
 
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
-        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-        is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))
 
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
-        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
-        if self.scale_rewards:
-            advantages = advantages / (std_grouped_rewards + 1e-4)
+
+        if self.scale_rewards in ["group", "none"]:
+            # If self.scale_rewards = "none", we'll still log group level std
+            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
+        elif self.scale_rewards == "batch":
+            # Compute global std
+            std_rewards = rewards.std().expand_as(rewards)
+        else:
+            raise ValueError(
+                f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
+            )
+
+        is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
+        if self.scale_rewards != "none":
+            advantages = advantages / (std_rewards + 1e-4)
 
         # Slice to keep only the local part of the data
         process_slice = slice(
@@ -1721,10 +1732,10 @@ def _generate_and_score_completions(
         for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
             self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
-            std_rewards = nanstd(rewards_per_func[:, i]).item()
-            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
+            std_func_rewards = nanstd(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
         self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
-        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+        self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
         self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
 
         # Log prompt and completion texts
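
Pulled out of the trainer for readability, here is a self-contained sketch of the advantage computation the hunk above implements (assuming a flat 1-D `rewards` tensor ordered group-by-group); it is an illustration, not the trainer code itself:

```python
import torch

def compute_advantages(rewards: torch.Tensor, num_generations: int, scale_rewards: str = "group"):
    # Center each reward on its group mean.
    mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
    mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
    advantages = rewards - mean_grouped

    if scale_rewards in ["group", "none"]:
        # Per-group std; with "none" it is only used for logging, not for scaling.
        std_rewards = rewards.view(-1, num_generations).std(dim=1)
        std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
    elif scale_rewards == "batch":
        # One std over the whole batch, broadcast to every sample.
        std_rewards = rewards.std().expand_as(rewards)
    else:
        raise ValueError(f"Invalid value for scale_rewards: {scale_rewards}")

    if scale_rewards != "none":
        advantages = advantages / (std_rewards + 1e-4)
    return advantages, std_rewards
```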
