34 changes: 33 additions & 1 deletion skyrl-train/skyrl_train/trainer.py
@@ -1019,6 +1019,29 @@ def apply_reward_kl_penalty(

return data

def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch:
"""Normalize the advantages in the mini-batch.

Pre-scaling the advantages here lets the loss function reduce the token-level
losses with a plain sum while still producing the correct mini-batch loss for
the configured loss reduction type.
"""
advantages = data["advantages"]
loss_mask = data["loss_mask"]

# NOTE: Do not modify the tensor in place!
# Otherwise subsequent epochs will keep dividing the same tensor.

# Option 1: token mean
if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
data["advantages"] = advantages / loss_mask.sum()

# Option 2: sequence mean
elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
batch_size = len(data)
data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True))

return data

def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
"""
Execute training step for FSDP strategy using forward_backward + optim_step.
@@ -1044,13 +1067,22 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples

all_metrics: Dict[str, List[float]] = defaultdict(list)
num_mini_batches = len(data) // mini_batch_size

# Iterate over mini-batches to normalize advantages at the mini-batch level
for local_step in range(num_mini_batches):
start_idx = local_step * mini_batch_size
end_idx = (local_step + 1) * mini_batch_size
mini_batch = data[start_idx:end_idx]
mini_batch = self._normalize_minibatch_advantages(mini_batch)
# Copy normalized advantages back to original batch
data["advantages"][start_idx:end_idx] = mini_batch["advantages"]

# Stage full batch in object store ONCE to avoid repeated serialization
data_ref = self.dispatch.stage_data(data)

# Training loop over epochs and mini-batches
for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
num_mini_batches = len(data) // mini_batch_size
for local_step in range(num_mini_batches):
start_idx = local_step * mini_batch_size
end_idx = (local_step + 1) * mini_batch_size
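To make the equivalence the new helper relies on concrete, here is a minimal standalone sketch (not part of the diff; `ratio`, `lengths`, and the simplified unclipped per-token loss are illustrative stand-ins): dividing the advantages by the mini-batch token count, or by batch size times per-sequence token counts, makes a plain sum of masked token-level losses equal to the usual token-mean or sequence-mean reduction.

import torch

torch.manual_seed(0)
B, T = 4, 6
ratio = torch.rand(B, T)          # stand-in for the PPO importance ratio
advantages = torch.randn(B, T)
lengths = torch.tensor([6, 4, 5, 3])
loss_mask = (torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)).float()

per_token = -advantages * ratio   # simplified, unclipped per-token policy loss

# Reference reductions on the unscaled loss.
token_mean = (per_token * loss_mask).sum() / loss_mask.sum()
seq_mean = ((per_token * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)).mean()

# Trainer-side scaling as in _normalize_minibatch_advantages, then a plain sum in the loss.
adv_token = advantages / loss_mask.sum()
adv_seq = advantages / (B * loss_mask.sum(dim=-1, keepdim=True))

assert torch.allclose((-adv_token * ratio * loss_mask).sum(), token_mean)
assert torch.allclose((-adv_seq * ratio * loss_mask).sum(), seq_mean)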
9 changes: 8 additions & 1 deletion skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -192,6 +192,7 @@ def ppo_critic_loss(
clipfrac = None
loss = (values - returns) ** 2

# TODO: We separately run into the "mean of means" problem here.
loss = masked_mean(loss, loss_mask, dim=-1).mean()
return 0.5 * loss, clipfrac
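
A tiny standalone illustration of the "mean of means" issue flagged in the TODO above (made-up numbers, not from the repo): when sequence lengths differ, averaging per-sequence means gives each sequence equal weight, so it generally differs from a single mean over all unmasked tokens.

import torch

loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # 4 valid tokens, each with loss 1.0
                     [5.0, 0.0, 0.0, 0.0]])  # 1 valid token with loss 5.0
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 0.0, 0.0, 0.0]])

per_seq_mean = (loss * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)  # tensor([1., 5.])
mean_of_means = per_seq_mean.mean()                                    # 3.0: each sequence weighted equally
token_mean = (loss * loss_mask).sum() / loss_mask.sum()                # 9 / 5 = 1.8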

@@ -592,7 +593,13 @@ def ppo_policy_loss(
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
loss = loss * tis_imp_ratio

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
# NOTE: The advantages were already scaled in the trainer to account for the loss normalization,
# so we only need to sum the token-level losses here.
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.sum()
# loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)

return loss, clip_ratio


16 changes: 8 additions & 8 deletions skyrl-train/skyrl_train/utils/utils.py
@@ -254,14 +254,14 @@ def validate_cfg(cfg: DictConfig):
f"Must be one of {available_advantage_estimators}"
)

assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
), (
f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
)
# assert cfg.trainer.algorithm.loss_reduction in (
# "token_mean",
# "sequence_mean",
# "seq_mean_token_sum_norm",
# ), (
# f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
# f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
# )

# add field to algorithm config needed for loss functions
# create a new config to make it modifiable
12 changes: 5 additions & 7 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -899,17 +899,15 @@ def _forward_backward_micro(

def optim_step(self) -> float:
"""
Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter.
Scale gradients to undo the DDP all-reduce mean, then perform the optimizer step.

Returns:
The gradient norm (before scaling, after clipping)
"""
# Scale accumulated gradients by 1/N to get correct average
if self._micro_batches_accumulated > 0:
scale = 1.0 / self._micro_batches_accumulated
for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(scale)
# Scale gradients by data parallelism size to undo the DDP all-reduce mean.
for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(self.strategy.world_size)
Comment on lines +908 to +910 (Contributor Author):
We could do this at the advantage computation level, but I thought it was a bit weird to have DDP all-reduce implementation details there, so I separated it out to be here.


# Perform optimizer step (includes gradient clipping)
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")
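
A minimal sketch of the arithmetic behind the scaling on lines +908 to +910 and the author's comment above (plain Python, no torch.distributed; it assumes each data-parallel rank computes a loss that is already normalized by mini-batch-wide counts): DDP's all-reduce stores the mean of the per-rank gradients in param.grad, while the gradient of the full mini-batch loss is their sum, so multiplying by the world size recovers it.

world_size = 4
# Toy per-rank gradient contributions to a globally normalized mini-batch loss.
per_rank_grads = [1.0, 2.0, 3.0, 4.0]

ddp_grad = sum(per_rank_grads) / world_size  # what DDP's mean all-reduce leaves behind
full_batch_grad = sum(per_rank_grads)        # gradient of the summed mini-batch loss

assert ddp_grad * world_size == full_batch_grad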