2 changes: 2 additions & 0 deletions skyrl-train/docs/configuration/config.rst
@@ -291,6 +291,7 @@ Algorithm Configuration
advantage_batch_normalize: false
value_head_prefix: "value_head"
ppo_loss_type: "regular" # "regular", "dual_clip"
loss_reduction: "token_mean" # "token_mean", "sequence_mean"

# GAE parameters
lambd: 1.0
@@ -315,6 +316,7 @@ Algorithm Configuration
- ``algorithm.advantage_batch_normalize``: Whether to normalize advantages by the (global) batch mean and standard deviation.
- ``algorithm.value_head_prefix``: The name used to identify the value head in the critic model.
- ``algorithm.ppo_loss_type``: Type of PPO loss to use. Currently, we support ``regular`` and ``dual_clip``. ``regular`` is the vanilla PPO loss, while ``dual_clip`` is the dual clip PPO loss proposed in `this paper <https://arxiv.org/pdf/1912.09729>`_.
- ``algorithm.loss_reduction``: How the per-token PPO loss is reduced to a scalar. Currently, we support ``token_mean`` and ``sequence_mean``. ``token_mean`` averages over all valid tokens in the batch, matching the token-level loss introduced by `DAPO <https://dapo-sia.github.io/>`_. ``sequence_mean`` computes the average token loss within each sequence, then averages over the batch (a numeric sketch follows this diff).
- ``algorithm.lambd``: Lambda parameter for GAE.
- ``algorithm.gamma``: Gamma parameter for GAE.
- ``algorithm.eps_clip_low``: Lower bound for PPO clipping.
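The distinction documented above is easiest to see numerically. The following sketch (illustrative values only, not taken from this PR) contrasts the two reductions when sequences contribute different numbers of valid tokens:

```python
import torch

# Made-up per-token losses and loss mask; the second sequence has only one valid token.
per_token_loss = torch.tensor([[1.0, 1.0, 1.0],
                               [4.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0]])

# token_mean: one global average over all valid tokens (DAPO-style).
token_mean = (per_token_loss * mask).sum() / mask.sum()                 # (1+1+1+4)/4 = 1.75

# sequence_mean: average within each sequence first, then across the batch.
seq_means = (per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1)      # [1.0, 4.0]
sequence_mean = seq_means.mean()                                        # (1.0+4.0)/2 = 2.5

print(token_mean.item(), sequence_mean.item())
```

With ``token_mean``, longer sequences carry proportionally more weight; with ``sequence_mean``, every sequence counts equally regardless of its length.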
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -89,6 +89,7 @@ trainer:
advantage_batch_normalize: false
value_head_prefix: "value_head"
ppo_loss_type: "regular" # "regular", "dual_clip"
loss_reduction: "token_mean" # "token_mean", "sequence_mean"
# GAE parameters
lambd: 1.0
gamma: 1.0
4 changes: 2 additions & 2 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -53,10 +53,10 @@ def update(self, current, n_steps):
pass


def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor:
def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
if mask is None:
return tensor.mean(axis=dim)
return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim)
return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim) + 1e-8)


@torch.no_grad()
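A quick check of the guarded `masked_mean` above (function body copied from this diff; the example values are mine): the `+ 1e-8` keeps a fully masked row from producing `NaN` through a 0/0 division.

```python
import torch
from typing import Optional

def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
    if mask is None:
        return tensor.mean(axis=dim)
    return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim) + 1e-8)

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
m = torch.tensor([[1.0, 1.0], [0.0, 0.0]])   # second row is entirely masked out

print(masked_mean(x, m, dim=-1))             # tensor([1.5000, 0.0000]) -- finite, no NaN
```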
7 changes: 6 additions & 1 deletion skyrl-train/skyrl_train/utils/utils.py
@@ -190,7 +190,12 @@ def validate_cfg(cfg: DictConfig):
assert cfg.trainer.algorithm.ppo_loss_type in (
"regular",
"dual_clip",
), f"invalid loss type: {cfg.trainer.algorithm.ppo_loss_type}. Must be one of `['regular', 'dual_clip']`"
), f"invalid ppo_loss_type: {cfg.trainer.algorithm.ppo_loss_type}. Must be one of `['regular', 'dual_clip']`"

assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"sequence_mean",
), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean']`"

if cfg.trainer.strategy == "deepspeed" and not (
cfg.trainer.policy.optimizer_config.offload_after_step
@@ -98,6 +98,7 @@ def init_model(self, model_id_or_path):
self.cfg.trainer.algorithm.eps_clip_high,
self.cfg.trainer.algorithm.clip_ratio_c,
loss_type=self.cfg.trainer.algorithm.ppo_loss_type,
loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
)

self.use_cuda_ipc = False
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -93,6 +93,7 @@ def init_model(self, model_path):
self.cfg.trainer.algorithm.eps_clip_high,
self.cfg.trainer.algorithm.clip_ratio_c,
loss_type=self.cfg.trainer.algorithm.ppo_loss_type,
loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
)

self.use_cuda_ipc = False
21 changes: 17 additions & 4 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -2,7 +2,7 @@
import logging
import os
import socket
from typing import Dict, Optional, Type, List, Literal, Any
from typing import Dict, Optional, Type, List, Literal, Any, Tuple
from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
from tqdm import tqdm
from collections import defaultdict
@@ -354,24 +354,30 @@ class PolicyLoss(nn.Module):
def __init__(
self,
clip_eps_low: float = 0.2,
clip_eps_high: float = 0.4,
clip_eps_high: float = 0.2,
clip_ratio_c: float = 3.0,
loss_type: Literal["regular", "dual_clip"] = "regular",
loss_reduction: Literal["token_mean", "sequence_mean"] = "token_mean",
) -> None:
super().__init__()
self.clip_eps_low = clip_eps_low
self.clip_eps_high = clip_eps_high
self.clip_ratio_c = clip_ratio_c
self.loss_type = loss_type
assert loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'"
self.loss_reduction = loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
], "loss_reduction must be either 'token_mean' or 'sequence_mean'"

def forward(
self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, float]:

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
@@ -383,7 +389,14 @@ def forward(
pg_losses3 = -advantages * self.clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss = masked_mean(loss, loss_mask, dim=-1).mean()
if self.loss_reduction == "token_mean":
# sum over *all* valid tokens, divide by total valid-token count
loss = masked_mean(loss, loss_mask)
elif self.loss_reduction == "sequence_mean":
# per-sequence token-mean (dim=-1), then batch-mean
loss = masked_mean(loss, loss_mask, dim=-1).mean()
else:
raise ValueError(f"Invalid loss reduction type: {self.loss_reduction}")
return loss, clip_ratio


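For completeness, a hedged usage sketch of the updated `PolicyLoss` (the import path is inferred from the file modified here, and the tensor values are illustrative; the tests below exercise both reduction modes more thoroughly):

```python
import torch
from skyrl_train.workers.worker import PolicyLoss  # path assumed from this diff

loss_fn = PolicyLoss(
    clip_eps_low=0.2,
    clip_eps_high=0.2,
    clip_ratio_c=3.0,
    loss_type="regular",
    loss_reduction="sequence_mean",   # new option; the default remains "token_mean"
)

log_probs = torch.tensor([[-1.5, -0.5, -1.2]])
old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]])
advantages = torch.tensor([[1.0, -1.0, 2.0]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0]])       # last token excluded from the loss

# forward() now returns both the reduced loss and the clip ratio.
loss, clip_ratio = loss_fn(log_probs, old_log_probs, advantages, loss_mask)
print(loss.item(), clip_ratio)
```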
113 changes: 113 additions & 0 deletions skyrl-train/tests/cpu/algorithms/test_losses.py
@@ -44,3 +44,116 @@ def test_policy_loss_dual_clip():
torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8)
# close to hand calculated value
assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4)


def test_policy_loss_reduction_modes():
"""Tests different loss_reduction modes in PolicyLoss function.

Note: token_mean and sequence_mean give the same result when all sequences
have the same length and no mask is applied, but differ when masking creates
different effective sequence lengths.
"""

device = "cpu"

clip_eps_low = 0.2
clip_eps_high = 0.2

advantages = torch.tensor(
[
[2.0, 2.0, 2.0], # sequence 1: consistently higher advantages
[1.0, 1.0, 1.0], # sequence 2: consistently lower advantages
],
device=device,
)

old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device)

log_probs = torch.tensor(
[[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]], # ratios ≈ [[0.61, 1.65, 0.83],[1.22, 0.74, 1.11]]
device=device,
)

# Create masks to test sequences with different numbers of valid tokens
loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device)

# Test token_mean without mask
loss_fn_token = PolicyLoss(
loss_type="regular", loss_reduction="token_mean", clip_eps_low=clip_eps_low, clip_eps_high=clip_eps_high
)
loss_token_no_mask, _ = loss_fn_token(log_probs, old_log_probs, advantages)

# Test token_mean with mask
loss_token_with_mask, _ = loss_fn_token(log_probs, old_log_probs, advantages, loss_mask)

# Test sequence_mean without mask
loss_fn_seq = PolicyLoss(
loss_type="regular", loss_reduction="sequence_mean", clip_eps_low=clip_eps_low, clip_eps_high=clip_eps_high
)
loss_seq_no_mask, _ = loss_fn_seq(log_probs, old_log_probs, advantages)

# Test sequence_mean with mask
loss_seq_with_mask, _ = loss_fn_seq(log_probs, old_log_probs, advantages, loss_mask)

# Manual calculations to verify (using default PolicyLoss parameters)
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages # clip_eps_low=0.2, clip_eps_high=0.2
loss_per_token = -torch.min(surr1, surr2)

# Expected token_mean without mask: mean of all tokens
expected_token_no_mask = loss_per_token.mean()

# Expected token_mean with mask: masked mean of all tokens
expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8)

# Expected sequence_mean without mask: mean of sequence means
expected_seq_no_mask = loss_per_token.mean(dim=1).mean()

# Expected sequence_mean with mask: mean of masked sequence means
seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8)
expected_seq_with_mask = seq_means_masked.mean()

# Verify results
torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8)

# Verify that the two reduction modes give the same results when sequences have equal length and no mask
assert torch.allclose(
loss_token_no_mask, loss_seq_no_mask, rtol=1e-5
), "token_mean and sequence_mean should give same results when sequences have equal length and no mask"
# But they should give different results when mask creates different effective sequence lengths
assert not torch.allclose(
loss_token_with_mask, loss_seq_with_mask, rtol=1e-3
), "token_mean and sequence_mean with mask should give different results"


def test_policy_loss_reduction_edge_cases():
"""Tests edge cases for loss_reduction modes."""

device = "cpu"

# Test with single sequence (should give same result for both modes)
advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device)
old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device)

loss_fn_token = PolicyLoss(loss_type="regular", loss_reduction="token_mean")
loss_fn_seq = PolicyLoss(loss_type="regular", loss_reduction="sequence_mean")

loss_token, _ = loss_fn_token(log_probs, old_log_probs, advantages)
loss_seq, _ = loss_fn_seq(log_probs, old_log_probs, advantages)

# With single sequence, both modes should give same result
torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8)

# Test with completely masked sequence
loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device)
loss_token_masked, _ = loss_fn_token(log_probs, old_log_probs, advantages, loss_mask)
loss_seq_masked, _ = loss_fn_seq(log_probs, old_log_probs, advantages, loss_mask)

# Should handle zero mask gracefully (due to +1e-8 in denominator)
assert torch.isfinite(loss_token_masked)
assert torch.isfinite(loss_seq_masked)
1 change: 1 addition & 0 deletions skyrl-train/tests/cpu/test_trainer.py
@@ -86,6 +86,7 @@ def dummy_config():
"value_clip": 0.2,
"normalize_reward": True,
"ppo_loss_type": "regular",
"loss_reduction": "token_mean",
},
"resume_mode": "none",
},