Description
In slime/ray/rollout.py, the _post_process_rewards method incorrectly normalizes rewards across the entire batch instead of per-group when sample counts are unequal across prompt groups.
Affected Code
# slime/ray/rollout.py, lines 310-324 (before fix)
def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
    ...
    if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
        rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
    else:
        # BUG: this branch does not normalize per-group!
        rewards = rewards.view(-1, rewards.shape[-1])
    mean = rewards.mean(dim=-1, keepdim=True)
    rewards = rewards - mean
    ...
Root Cause
When the number of samples per prompt is not uniform (e.g., due to early termination or aborted samples), the code falls into the else branch. The problematic line:
rewards = rewards.view(-1, rewards.shape[-1])
This reshapes the tensor to shape (1, total_samples), so mean(dim=-1) computes a single mean across all samples rather than one mean per group. This violates the intended GRPO/GSPO normalization semantics, where rewards should be normalized independently within each prompt group.
Example
Consider 3 prompt groups with varying sample counts:
- Group 0: 4 samples, rewards = [1, 2, 3, 4]
- Group 1: 2 samples, rewards = [10, 12]
- Group 2: 3 samples, rewards = [5, 6, 7]
Expected behavior (per-group normalization):
- Group 0: mean = 2.5 → normalized = [-1.5, -0.5, 0.5, 1.5]
- Group 1: mean = 11 → normalized = [-1, 1]
- Group 2: mean = 6 → normalized = [-1, 0, 1]
Actual behavior (batch normalization):
- Global mean = (1 + 2 + 3 + 4 + 10 + 12 + 5 + 6 + 7) / 9 ≈ 5.56
- All rewards are shifted by this single mean, incorrectly mixing statistics across groups.
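The numbers above can be verified with a few lines of plain Python (using statistics.mean in place of torch operations, on the same made-up reward values):

```python
# Standalone check of the example above, using plain Python lists
# in place of torch tensors (illustrative values only).
from statistics import mean

groups = [[1, 2, 3, 4], [10, 12], [5, 6, 7]]  # unequal samples per prompt
flat = [r for g in groups for r in g]

# Buggy path: a single mean over the flattened batch
global_mean = mean(flat)                       # 50 / 9 ≈ 5.56
buggy = [r - global_mean for r in flat]

# Intended per-group normalization
expected = [r - mean(g) for g in groups for r in g]

print(round(global_mean, 2))                   # 5.56
print([float(x) for x in expected])
# → [-1.5, -0.5, 0.5, 1.5, -1.0, 1.0, -1.0, 0.0, 1.0]
```

Note that under the buggy path every reward in Group 1 stays far above zero (10 − 5.56 and 12 − 5.56), so the advantage signal leaks the group's absolute reward scale instead of its within-group ranking.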
Fix
Use sample.group_index to compute per-group statistics:
    else:
        # Normalize within each prompt group using sample.group_index
        group_ids = torch.tensor([sample.group_index for sample in samples])
        # Per-group means via bincount + scatter_add_
        group_counts = torch.bincount(group_ids).float()
        group_sums = torch.zeros_like(group_counts).scatter_add_(0, group_ids, rewards)
        group_means = group_sums / group_counts
        rewards = rewards - group_means[group_ids]
        if use_std:
            # Per-group stds of the already-centered rewards
            group_sq_sums = torch.zeros_like(group_counts).scatter_add_(0, group_ids, rewards**2)
            group_stds = (group_sq_sums / group_counts).sqrt()
            rewards = rewards / (group_stds[group_ids] + 1e-6)
    return raw_rewards, rewards.tolist()
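As a sanity check, the same grouped mean/std logic can be mirrored in dependency-free Python. This is a sketch only: it assumes rewards is a flat 1-D list, group indices are contiguous 0..G-1, and it reuses the patch's eps = 1e-6; normalize_per_group is a hypothetical helper, not part of slime.

```python
# Pure-Python mirror of the scatter-based per-group normalization above.
import math

def normalize_per_group(rewards, group_ids, use_std=False, eps=1e-6):
    n_groups = max(group_ids) + 1
    counts = [0] * n_groups
    sums = [0.0] * n_groups
    for g, r in zip(group_ids, rewards):        # per-group sums/counts
        counts[g] += 1
        sums[g] += r
    means = [s / c for s, c in zip(sums, counts)]
    centered = [r - means[g] for g, r in zip(group_ids, rewards)]
    if not use_std:
        return centered
    sq_sums = [0.0] * n_groups                  # variance of centered rewards
    for g, r in zip(group_ids, centered):
        sq_sums[g] += r * r
    stds = [math.sqrt(s / c) for s, c in zip(sq_sums, counts)]
    return [r / (stds[g] + eps) for g, r in zip(group_ids, centered)]

rewards = [1, 2, 3, 4, 10, 12, 5, 6, 7]
group_ids = [0, 0, 0, 0, 1, 1, 2, 2, 2]
print(normalize_per_group(rewards, group_ids))
# → [-1.5, -0.5, 0.5, 1.5, -1.0, 1.0, -1.0, 0.0, 1.0]
```

With uneven group sizes this reproduces exactly the per-group means from the example (2.5, 11, 6), which is what the scatter_add_-based fix computes on tensors.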