diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst
index fdb5ad23ba..0ea688175a 100644
--- a/skyrl-train/docs/configuration/config.rst
+++ b/skyrl-train/docs/configuration/config.rst
@@ -291,6 +291,7 @@ Algorithm Configuration
     advantage_batch_normalize: false
     value_head_prefix: "value_head"
     ppo_loss_type: "regular" # "regular", "dual_clip"
+    loss_reduction: "token_mean" # "token_mean", "sequence_mean"
 
     # GAE parameters
     lambd: 1.0
@@ -315,6 +316,7 @@ Algorithm Configuration
 - ``algorithm.advantage_batch_normalize``: Whether to normalize advantages by the (global) batch mean and standard deviation.
 - ``algorithm.value_head_prefix``: The name used to identify the value head in the critic model.
 - ``algorithm.ppo_loss_type``: Type of PPO loss to use. Currently, we support ``regular`` and ``dual_clip``. ``regular`` is the vanilla PPO loss, while ``dual_clip`` is the dual clip PPO loss proposed in `this paper `_.
+- ``algorithm.loss_reduction``: Type of PPO loss reduction to use. Currently, we support ``token_mean`` and ``sequence_mean``. ``token_mean`` averages the loss over all valid tokens in the batch, matching the token-level loss introduced by `DAPO `_. ``sequence_mean`` computes the average token loss within each sequence, then averages those per-sequence means over the batch.
 - ``algorithm.lambd``: Lambda parameter for GAE.
 - ``algorithm.gamma``: Gamma parameter for GAE.
 - ``algorithm.eps_clip_low``: Lower bound for PPO clipping.
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 485d832e9c..949d61f13d 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -89,6 +89,7 @@ trainer:
     advantage_batch_normalize: false
     value_head_prefix: "value_head"
     ppo_loss_type: "regular" # "regular", "dual_clip"
+    loss_reduction: "token_mean" # "token_mean", "sequence_mean"
     # GAE parameters
     lambd: 1.0
     gamma: 1.0
diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index bef0259b69..657f09a7cb 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -53,10 +53,10 @@ def update(self, current, n_steps):
         pass
 
 
-def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor:
+def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
     if mask is None:
         return tensor.mean(axis=dim)
-    return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim)
+    return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim) + 1e-8)
 
 
 @torch.no_grad()
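Note on the ``masked_mean`` change above: the ``+ 1e-8`` guard keeps the division finite when a mask selects no tokens at all. Below is a minimal standalone sketch (illustrative only, not part of the patch; the toy tensors are made up) of the failure mode it protects against:

```python
import torch

def masked_mean(tensor, mask, dim=None):
    # Mirrors the patched helper: the epsilon keeps the division finite
    # when a sequence (or the whole batch) contains no valid tokens.
    if mask is None:
        return tensor.mean(axis=dim)
    return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim) + 1e-8)

per_token_loss = torch.tensor([[0.5, 0.7, 0.9]])
empty_mask = torch.zeros_like(per_token_loss)

unguarded = (per_token_loss * empty_mask).sum() / empty_mask.sum()  # 0 / 0 -> nan
guarded = masked_mean(per_token_loss, empty_mask)                   # 0 / 1e-8 -> 0.0
print(unguarded.item(), guarded.item())  # nan 0.0
```

This is the same situation the new ``test_policy_loss_reduction_edge_cases`` test below exercises with a fully masked sequence.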
diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py
index 75d04d4a5a..172b7f58a6 100644
--- a/skyrl-train/skyrl_train/utils/utils.py
+++ b/skyrl-train/skyrl_train/utils/utils.py
@@ -190,7 +190,12 @@ def validate_cfg(cfg: DictConfig):
     assert cfg.trainer.algorithm.ppo_loss_type in (
         "regular",
         "dual_clip",
-    ), f"invalid loss type: {cfg.trainer.algorithm.ppo_loss_type}. Must be one of `['regular', 'dual_clip']`"
+    ), f"invalid ppo_loss_type: {cfg.trainer.algorithm.ppo_loss_type}. Must be one of `['regular', 'dual_clip']`"
+
+    assert cfg.trainer.algorithm.loss_reduction in (
+        "token_mean",
+        "sequence_mean",
+    ), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean']`"
 
     if cfg.trainer.strategy == "deepspeed" and not (
         cfg.trainer.policy.optimizer_config.offload_after_step
diff --git a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
index c8c4da77bc..09604fdd36 100644
--- a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
+++ b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
@@ -98,6 +98,7 @@ def init_model(self, model_id_or_path):
             self.cfg.trainer.algorithm.eps_clip_high,
             self.cfg.trainer.algorithm.clip_ratio_c,
             loss_type=self.cfg.trainer.algorithm.ppo_loss_type,
+            loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
         )
 
         self.use_cuda_ipc = False
diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
index 0729a693ba..9db0579f6f 100644
--- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
+++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -93,6 +93,7 @@ def init_model(self, model_path):
             self.cfg.trainer.algorithm.eps_clip_high,
             self.cfg.trainer.algorithm.clip_ratio_c,
             loss_type=self.cfg.trainer.algorithm.ppo_loss_type,
+            loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
        )
 
         self.use_cuda_ipc = False
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 77ba078828..84d13206f9 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -2,7 +2,7 @@
 import logging
 import os
 import socket
-from typing import Dict, Optional, Type, List, Literal, Any
+from typing import Dict, Optional, Type, List, Literal, Any, Tuple
 from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
 from tqdm import tqdm
 from collections import defaultdict
@@ -354,9 +354,10 @@ class PolicyLoss(nn.Module):
     def __init__(
         self,
         clip_eps_low: float = 0.2,
-        clip_eps_high: float = 0.4,
+        clip_eps_high: float = 0.2,
         clip_ratio_c: float = 3.0,
         loss_type: Literal["regular", "dual_clip"] = "regular",
+        loss_reduction: Literal["token_mean", "sequence_mean"] = "token_mean",
     ) -> None:
         super().__init__()
         self.clip_eps_low = clip_eps_low
@@ -364,6 +365,11 @@ def __init__(
         self.clip_ratio_c = clip_ratio_c
         self.loss_type = loss_type
         assert loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'"
+        self.loss_reduction = loss_reduction
+        assert loss_reduction in [
+            "token_mean",
+            "sequence_mean",
+        ], "loss_reduction must be either 'token_mean' or 'sequence_mean'"
 
     def forward(
         self,
@@ -371,7 +377,7 @@ def forward(
         old_log_probs: torch.Tensor,
         advantages: torch.Tensor,
         loss_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, float]:
         ratio = (log_probs - old_log_probs).exp()
         surr1 = ratio * advantages
 
@@ -383,7 +389,14 @@
             pg_losses3 = -advantages * self.clip_ratio_c
             clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
             loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
-        loss = masked_mean(loss, loss_mask, dim=-1).mean()
+        if self.loss_reduction == "token_mean":
+            # sum over *all* valid tokens, divide by total valid-token count
+            loss = masked_mean(loss, loss_mask)
+        elif self.loss_reduction == "sequence_mean":
+            # per-sequence token-mean (dim=-1), then batch-mean
+            loss = masked_mean(loss, loss_mask, dim=-1).mean()
+        else:
+            raise ValueError(f"Invalid loss reduction type: {self.loss_reduction}")
 
         return loss, clip_ratio
 
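To make the new ``loss_reduction`` branches concrete, here is a small standalone sketch (illustrative only; the per-token losses and mask are made-up toy values, and the ``1e-8`` guard is omitted for readability) of what the two modes compute on the masked per-token loss. With equal-length sequences and no mask the two modes agree; once the mask makes sequences ragged they diverge, which is exactly what the tests below assert:

```python
import torch

# Toy per-token PPO losses for two sequences of length 3.
per_token_loss = torch.tensor([[0.2, 0.4, 0.6],
                               [1.0, 1.2, 1.4]])
# The second sequence has only one valid token.
loss_mask = torch.tensor([[1.0, 1.0, 1.0],
                          [1.0, 0.0, 0.0]])

# token_mean: one global average over every valid token in the batch.
token_mean = (per_token_loss * loss_mask).sum() / loss_mask.sum()

# sequence_mean: average within each sequence first, then across sequences.
per_sequence = (per_token_loss * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
sequence_mean = per_sequence.mean()

print(token_mean.item())     # (0.2 + 0.4 + 0.6 + 1.0) / 4 = 0.55
print(sequence_mean.item())  # (0.4 + 1.0) / 2 = 0.70
```

``token_mean`` weights every valid token equally (so longer responses contribute more to the gradient), while ``sequence_mean`` weights every sequence equally regardless of its length.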
diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py
index caf2d4f47b..eff9872ae1 100644
--- a/skyrl-train/tests/cpu/algorithms/test_losses.py
+++ b/skyrl-train/tests/cpu/algorithms/test_losses.py
@@ -44,3 +44,116 @@ def test_policy_loss_dual_clip():
     torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8)
     # close to hand calculated value
     assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4)
+
+
+def test_policy_loss_reduction_modes():
+    """Tests the different loss_reduction modes of PolicyLoss.
+
+    Note: token_mean and sequence_mean give the same result when all sequences
+    have the same length and no mask is applied, but differ when masking creates
+    different effective sequence lengths.
+    """
+
+    device = "cpu"
+
+    clip_eps_low = 0.2
+    clip_eps_high = 0.2
+
+    advantages = torch.tensor(
+        [
+            [2.0, 2.0, 2.0],  # sequence 1: consistently higher advantages
+            [1.0, 1.0, 1.0],  # sequence 2: consistently lower advantages
+        ],
+        device=device,
+    )
+
+    old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device)
+
+    log_probs = torch.tensor(
+        [[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]],  # ratios ≈ [[0.61, 1.65, 0.83], [1.22, 0.74, 1.11]]
+        device=device,
+    )
+
+    # Create a mask so the two sequences have different numbers of valid tokens
+    loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device)
+
+    # Test token_mean without mask
+    loss_fn_token = PolicyLoss(
+        loss_type="regular", loss_reduction="token_mean", clip_eps_low=clip_eps_low, clip_eps_high=clip_eps_high
+    )
+    loss_token_no_mask, _ = loss_fn_token(log_probs, old_log_probs, advantages)
+
+    # Test token_mean with mask
+    loss_token_with_mask, _ = loss_fn_token(log_probs, old_log_probs, advantages, loss_mask)
+
+    # Test sequence_mean without mask
+    loss_fn_seq = PolicyLoss(
+        loss_type="regular", loss_reduction="sequence_mean", clip_eps_low=clip_eps_low, clip_eps_high=clip_eps_high
+    )
+    loss_seq_no_mask, _ = loss_fn_seq(log_probs, old_log_probs, advantages)
+
+    # Test sequence_mean with mask
+    loss_seq_with_mask, _ = loss_fn_seq(log_probs, old_log_probs, advantages, loss_mask)
+
+    # Manual calculations to verify (using the same clipping parameters as above)
+    ratio = torch.exp(log_probs - old_log_probs)
+    surr1 = ratio * advantages
+    surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages  # clip_eps_low=0.2, clip_eps_high=0.2
+    loss_per_token = -torch.min(surr1, surr2)
+
+    # Expected token_mean without mask: mean over all tokens
+    expected_token_no_mask = loss_per_token.mean()
+
+    # Expected token_mean with mask: masked mean over all tokens
+    expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8)
+
+    # Expected sequence_mean without mask: mean of per-sequence means
+    expected_seq_no_mask = loss_per_token.mean(dim=1).mean()
+
+    # Expected sequence_mean with mask: mean of masked per-sequence means
+    seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8)
+    expected_seq_with_mask = seq_means_masked.mean()
+
+    # Verify results
+    torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8)
+    torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8)
+    torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8)
+    torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8)
+
+    # Verify that the two reduction modes give the same results when sequences have equal length and no mask
+    assert torch.allclose(
+        loss_token_no_mask, loss_seq_no_mask, rtol=1e-5
+    ), "token_mean and sequence_mean should give same results when sequences have equal length and no mask"
+    # But they should give different results when the mask creates different effective sequence lengths
+    assert not torch.allclose(
+        loss_token_with_mask, loss_seq_with_mask, rtol=1e-3
+    ), "token_mean and sequence_mean with mask should give different results"
+
+
+def test_policy_loss_reduction_edge_cases():
+    """Tests edge cases for loss_reduction modes."""
+
+    device = "cpu"
+
+    # Test with single sequence (should give same result for both modes)
+    advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device)
+    old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+    log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device)
+
+    loss_fn_token = PolicyLoss(loss_type="regular", loss_reduction="token_mean")
+    loss_fn_seq = PolicyLoss(loss_type="regular", loss_reduction="sequence_mean")
+
+    loss_token, _ = loss_fn_token(log_probs, old_log_probs, advantages)
+    loss_seq, _ = loss_fn_seq(log_probs, old_log_probs, advantages)
+
+    # With single sequence, both modes should give same result
+    torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8)
+
+    # Test with completely masked sequence
+    loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device)
+    loss_token_masked, _ = loss_fn_token(log_probs, old_log_probs, advantages, loss_mask)
+    loss_seq_masked, _ = loss_fn_seq(log_probs, old_log_probs, advantages, loss_mask)
+
+    # Should handle zero mask gracefully (due to +1e-8 in denominator)
+    assert torch.isfinite(loss_token_masked)
+    assert torch.isfinite(loss_seq_masked)
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index 1d15b59d8e..78700b6bee 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -86,6 +86,7 @@ def dummy_config():
                 "value_clip": 0.2,
                 "normalize_reward": True,
                 "ppo_loss_type": "regular",
+                "loss_reduction": "token_mean",
             },
             "resume_mode": "none",
         },
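For completeness, the new option is read from ``cfg.trainer.algorithm.loss_reduction``, the path validated in ``validate_cfg`` above. Below is a tiny hand-built OmegaConf fragment, purely to illustrate where the field lives (the real config comes from ``ppo_base_config.yaml`` plus whatever override mechanism the trainer already uses, so this is not how a run is actually configured):

```python
from omegaconf import OmegaConf

# Hand-built fragment mirroring the nesting used in the asserts above;
# illustrative only, not how the trainer actually builds its config.
cfg = OmegaConf.create(
    {"trainer": {"algorithm": {"ppo_loss_type": "regular", "loss_reduction": "sequence_mean"}}}
)

assert cfg.trainer.algorithm.loss_reduction in ("token_mean", "sequence_mean")
print(cfg.trainer.algorithm.loss_reduction)  # sequence_mean
```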