
Bug: Group normalization in _post_process_rewards normalizes across entire batch when sample counts are unequal #1414

@ccggddmm

Description

In slime/ray/rollout.py, the _post_process_rewards method incorrectly normalizes rewards across the entire batch instead of per-group when sample counts are unequal across prompt groups.

Affected Code

# slime/ray/rollout.py, lines 310-324 (before fix)
def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
    ...
    if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
        rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
    else:
        # BUG: This doesn't normalize per-group!
        rewards = rewards.view(-1, rewards.shape[-1])
    mean = rewards.mean(dim=-1, keepdim=True)
    rewards = rewards - mean
    ...

Root Cause

When the number of samples per prompt is not uniform (e.g., due to early termination or aborted rollouts), the code falls into the else branch. The problematic line is:

rewards = rewards.view(-1, rewards.shape[-1])

This reshapes the tensor to (1, total_samples), causing mean(dim=-1) to compute a single mean across all samples rather than per-group means. This violates the intended GRPO/GSPO normalization semantics where rewards should be normalized within each prompt group independently.
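A minimal standalone sketch (with made-up reward values, not the actual slime code path) illustrates the effect: the reshape collapses the whole batch into a single row, so mean(dim=-1) produces one batch-wide mean.

import torch

# Hypothetical flat reward tensor for 9 samples drawn from 3 prompt groups.
rewards = torch.tensor([1., 2., 3., 4., 10., 12., 5., 6., 7.])

# The buggy reshape: (9,) -> (1, 9), i.e. a single "group" containing every sample.
reshaped = rewards.view(-1, rewards.shape[-1])
print(reshaped.shape)                        # torch.Size([1, 9])

# mean(dim=-1) therefore yields one global mean instead of one mean per group.
print(reshaped.mean(dim=-1, keepdim=True))   # tensor([[5.5556]])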

Example

Consider 3 prompt groups with varying sample counts:

  • Group 0: 4 samples, rewards = [1, 2, 3, 4]
  • Group 1: 2 samples, rewards = [10, 12]
  • Group 2: 3 samples, rewards = [5, 6, 7]

Expected behavior (per-group normalization):

  • Group 0: mean=2.5 → normalized = [-1.5, -0.5, 0.5, 1.5]
  • Group 1: mean=11 → normalized = [-1, 1]
  • Group 2: mean=6 → normalized = [-1, 0, 1]

Actual behavior (batch normalization):

  • Global mean = (1+2+3+4+10+12+5+6+7) / 9 ≈ 5.56
  • All rewards are shifted by this single mean, mixing group statistics incorrectly
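The numbers above can be reproduced with a short sketch; the group boundaries are hard-coded purely for illustration.

import torch

groups = [
    torch.tensor([1., 2., 3., 4.]),   # Group 0
    torch.tensor([10., 12.]),         # Group 1
    torch.tensor([5., 6., 7.]),       # Group 2
]

# Expected: subtract each group's own mean.
for i, g in enumerate(groups):
    print(i, (g - g.mean()).tolist())
# 0 [-1.5, -0.5, 0.5, 1.5]
# 1 [-1.0, 1.0]
# 2 [-1.0, 0.0, 1.0]

# Actual (buggy): subtract one global mean across the whole batch.
flat = torch.cat(groups)
print((flat - flat.mean()).tolist())  # every sample shifted by ~5.56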

Fix

Use sample.group_index to compute per-group statistics:

else:
    # Normalize within each prompt group using sample.group_index.
    # Assumes `samples` is a flat list[Sample] whose group_index values are
    # contiguous integers starting at 0.
    group_ids = torch.tensor([sample.group_index for sample in samples], device=rewards.device)

    # Compute per-group means via bincount + scatter_add_
    group_counts = torch.bincount(group_ids).float()
    group_sums = torch.zeros_like(group_counts).scatter_add_(0, group_ids, rewards)
    group_means = group_sums / group_counts
    rewards = rewards - group_means[group_ids]

    if use_std:
        # Compute per-group stds of the already-centered rewards
        group_sq_sums = torch.zeros_like(group_counts).scatter_add_(0, group_ids, rewards**2)
        group_stds = (group_sq_sums / group_counts).sqrt()
        rewards = rewards / (group_stds[group_ids] + 1e-6)

    return raw_rewards, rewards.tolist()
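As a sanity check, the scatter-based normalization can be exercised on the example data in isolation (group_ids stands in for the per-sample group_index values; this is a sketch, not the actual rollout code).

import torch

rewards = torch.tensor([1., 2., 3., 4., 10., 12., 5., 6., 7.])
group_ids = torch.tensor([0, 0, 0, 0, 1, 1, 2, 2, 2])

group_counts = torch.bincount(group_ids).float()
group_sums = torch.zeros_like(group_counts).scatter_add_(0, group_ids, rewards)
group_means = group_sums / group_counts
normalized = rewards - group_means[group_ids]

print(normalized.tolist())
# [-1.5, -0.5, 0.5, 1.5, -1.0, 1.0, -1.0, 0.0, 1.0]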
