16 changes: 16 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -43,6 +43,22 @@ def test_ppo_policy_loss(self):
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))

    def test_gspo_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("gspo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(logprob=self.logprob, **self.input_data.batch)
        gspo_loss_expected = torch.tensor(0.27235108613967896)
        pg_clipfrac_expected = torch.tensor(0.375)
        ppo_kl_seq_expected = torch.tensor(-0.21027061343193054)
        ppo_kl_expected = torch.tensor(-0.21663446724414825)
        print(f"{loss.item()=}, {metrics=}")
        self.assertTrue(torch.allclose(loss, gspo_loss_expected))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected))
        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected))
        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_expected))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), gspo_loss_expected))

    def test_sft_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
2 changes: 2 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -1,4 +1,5 @@
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
@@ -13,4 +14,5 @@
"DPOLossFn",
"SFTLossFn",
"MIXPolicyLossFn",
"GSPOLossFn",
]
76 changes: 76 additions & 0 deletions trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -0,0 +1,76 @@
"""GSPO-token policy loss function.

Implemented from https://arxiv.org/pdf/2507.18071
"""

from typing import Dict, Optional, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("gspo")
class GSPOLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
) -> None:
super().__init__(backend=backend)
_clip_range_low = clip_range_low if clip_range_low is not None else clip_range
if _clip_range_low is None:
raise ValueError("Either clip_range or clip_range_low must be specified.")
self.clip_range_low = _clip_range_low

_clip_range_high = clip_range_high if clip_range_high is not None else clip_range
if _clip_range_high is None:
raise ValueError("Either clip_range or clip_range_high must be specified.")
self.clip_range_high = _clip_range_high

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,  # [batch_size, seq_len]
        old_logprob: torch.Tensor,  # [batch_size, seq_len]
        action_mask: torch.Tensor,  # [batch_size, seq_len]
        advantages: torch.Tensor,  # [batch_size, seq_len]
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        negative_approx_kl = logprob - old_logprob  # [batch_size, seq_len]
        # Sequence-level log importance ratio: mean of the per-token log ratios.
        negative_approx_kl_seq = masked_mean(
            negative_approx_kl, action_mask, axis=-1
        )  # [batch_size]
        # GSPO-token ratio: the forward value equals the sequence-level ratio for
        # every token, while gradients still flow through the per-token log-probs.
        log_seq_importance_ratio = (
            logprob - logprob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
        )  # [batch_size, seq_len]
        ratio = torch.exp(log_seq_importance_ratio)  # [batch_size, seq_len]
        pg_losses = -advantages * ratio  # [batch_size, seq_len]
        pg_losses_clipped = -advantages * torch.clamp(
            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high
        )  # [batch_size, seq_len]

        seq_losses = masked_mean(
            torch.max(pg_losses, pg_losses_clipped), action_mask, axis=-1
        )  # [batch_size]
        pg_loss = torch.mean(seq_losses)
        pg_clipfrac = masked_mean(torch.gt(pg_losses_clipped, pg_losses).float(), action_mask)
        ppo_kl = masked_mean(-negative_approx_kl, action_mask)
        ppo_kl_seq = torch.mean(-negative_approx_kl_seq)
        metrics = {
            "pg_clipfrac": pg_clipfrac.detach().item(),
            "ppo_kl": ppo_kl.detach().item(),
            "pg_loss": pg_loss.detach().item(),
            "ppo_kl_seq": ppo_kl_seq.detach().item(),
        }
        return pg_loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        # See discussion in https://github.com/volcengine/verl/pull/2775#issuecomment-3130065984
        return {
            "clip_range_low": 0.0003,
            "clip_range_high": 0.0004,
        }