From a9417a84f0475c7660ae16830715a09994f8ab02 Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Wed, 30 Jul 2025 14:43:58 -0700 Subject: [PATCH 1/6] implement gspo-token policy loss function --- tests/algorithm/policy_loss_test.py | 13 ++++ trinity/algorithm/policy_loss_fn/__init__.py | 2 + .../policy_loss_fn/gspo_policy_loss.py | 71 +++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 trinity/algorithm/policy_loss_fn/gspo_policy_loss.py diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index b6bd5da21d..f21d44a0a4 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -43,6 +43,19 @@ def test_ppo_policy_loss(self): self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss)) + def test_gspo_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("gspo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + gspo_loss = torch.tensor(0.1385955959558487) + pg_clipfrac = torch.tensor(0.0416666679084301) + ppo_kl = torch.tensor(-0.021903187036514282) + self.assertTrue(torch.allclose(loss, gspo_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), gspo_loss)) + def test_sft_policy_loss(self): policy_loss_fn_cls = POLICY_LOSS_FN.get("sft") policy_loss_fn_args = policy_loss_fn_cls.default_args() diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 705fb2525a..4d88215a0a 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,4 +1,5 @@ from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn +from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn @@ -13,4 +14,5 @@ "DPOLossFn", "SFTLossFn", "MIXPolicyLossFn", + "GSPOLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py new file mode 100644 index 0000000000..ab691d46ec --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -0,0 +1,71 @@ +"""GSPO-token policy loss function. 
+ +Implemented from https://arxiv.org/pdf/2507.18071 +""" + +from typing import Dict, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("gspo") +class GSPOLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + ) -> None: + super().__init__(backend=backend) + if clip_range_low is None: + if clip_range is None: + raise ValueError("Either clip_range or clip_range_low must be specified.") + self.clip_range_low = clip_range + else: + self.clip_range_low = clip_range_low + + if clip_range_high is None: + if clip_range is None: + raise ValueError("Either clip_range or clip_range_high must be specified.") + self.clip_range_high = clip_range + else: + self.clip_range_high = clip_range_high + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + negative_approx_kl = logprob - old_logprob + seq_lengths = torch.sum(action_mask, dim=-1).clamp(min=1).unsqueeze(-1) + negative_approx_kl = negative_approx_kl / seq_lengths + log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl.detach() + ratio = torch.exp(log_seq_importance_ratio) + ppo_kl = masked_mean(-negative_approx_kl, action_mask) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp( + ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high + ) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask) + metrics = { + "pg_clipfrac": pg_clipfrac.detach().item(), + "ppo_kl": ppo_kl.detach().item(), + "pg_loss": pg_loss.detach().item(), + } + return pg_loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "clip_range": 0.2, + } From 8217109f15713ba658975ccf11bee1277bf19efb Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Wed, 30 Jul 2025 15:49:33 -0700 Subject: [PATCH 2/6] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/algorithm/policy_loss_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index f21d44a0a4..18384d4ee9 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -48,13 +48,13 @@ def test_gspo_policy_loss(self): policy_loss_fn_args = policy_loss_fn_cls.default_args() policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) - gspo_loss = torch.tensor(0.1385955959558487) - pg_clipfrac = torch.tensor(0.0416666679084301) - ppo_kl = torch.tensor(-0.021903187036514282) - self.assertTrue(torch.allclose(loss, gspo_loss)) - self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac)) - self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) - self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), gspo_loss)) + gspo_loss_expected = torch.tensor(0.1385955959558487) + pg_clipfrac_expected = torch.tensor(0.0416666679084301) + ppo_kl_expected = 
torch.tensor(-0.021903187036514282) + self.assertTrue(torch.allclose(loss, gspo_loss_expected)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_expected)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), gspo_loss_expected)) def test_sft_policy_loss(self): policy_loss_fn_cls = POLICY_LOSS_FN.get("sft") From f2698dc1bbca0c43c92a62746c2d1f9076839cd2 Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Thu, 31 Jul 2025 13:57:41 -0700 Subject: [PATCH 3/6] adress comments --- tests/algorithm/policy_loss_test.py | 7 +++- .../policy_loss_fn/gspo_policy_loss.py | 41 +++++++++---------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index 18384d4ee9..358eee600c 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -48,9 +48,12 @@ def test_gspo_policy_loss(self): policy_loss_fn_args = policy_loss_fn_cls.default_args() policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) - gspo_loss_expected = torch.tensor(0.1385955959558487) - pg_clipfrac_expected = torch.tensor(0.0416666679084301) + gspo_loss_expected = torch.tensor(0.15352888405323029) + pg_clipfrac_expected = torch.tensor(0.4791666567325592) ppo_kl_expected = torch.tensor(-0.021903187036514282) + print("gspo_loss_expected:", loss.item()) + print("pg_clipfrac_expected:", metrics["pg_clipfrac"]) + print("ppo_kl_expected:", metrics["ppo_kl"]) self.assertTrue(torch.allclose(loss, gspo_loss_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_expected)) diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py index ab691d46ec..45da10af9d 100644 --- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -21,19 +21,15 @@ def __init__( clip_range_high: Optional[float] = None, ) -> None: super().__init__(backend=backend) - if clip_range_low is None: - if clip_range is None: - raise ValueError("Either clip_range or clip_range_low must be specified.") - self.clip_range_low = clip_range - else: - self.clip_range_low = clip_range_low + _clip_range_low = clip_range_low if clip_range_low is not None else clip_range + if _clip_range_low is None: + raise ValueError("Either clip_range or clip_range_low must be specified.") + self.clip_range_low = _clip_range_low - if clip_range_high is None: - if clip_range is None: - raise ValueError("Either clip_range or clip_range_high must be specified.") - self.clip_range_high = clip_range - else: - self.clip_range_high = clip_range_high + _clip_range_high = clip_range_high if clip_range_high is not None else clip_range + if _clip_range_high is None: + raise ValueError("Either clip_range or clip_range_high must be specified.") + self.clip_range_high = _clip_range_high def __call__( # type: ignore self, @@ -43,20 +39,21 @@ def __call__( # type: ignore advantages: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict]: - negative_approx_kl = logprob - old_logprob - seq_lengths = torch.sum(action_mask, dim=-1).clamp(min=1).unsqueeze(-1) - negative_approx_kl = negative_approx_kl / seq_lengths - log_seq_importance_ratio = logprob - 
logprob.detach() + negative_approx_kl.detach() + seq_lengths = action_mask.sum(dim=-1, keepdim=True).clamp_min(1) + negative_approx_kl_seq = (logprob - old_logprob) * action_mask / seq_lengths + log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach() ratio = torch.exp(log_seq_importance_ratio) - ppo_kl = masked_mean(-negative_approx_kl, action_mask) + pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp( + pg_losses_clipped = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high ) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses_clipped), action_mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses_clipped, pg_losses).float(), action_mask) + ppo_kl = masked_mean(-negative_approx_kl_seq, action_mask) - pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask) - pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask) metrics = { "pg_clipfrac": pg_clipfrac.detach().item(), "ppo_kl": ppo_kl.detach().item(), @@ -66,6 +63,8 @@ def __call__( # type: ignore @classmethod def default_args(cls) -> Dict: + # See discussion in https://github.com/volcengine/verl/pull/2775#issuecomment-3130065984 return { - "clip_range": 0.2, + "clip_range_low": 0.0003, + "clip_range_high": 0.0004, } From dce3ecd3501144ae533b824d86f9584f05e6461f Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Fri, 1 Aug 2025 12:52:07 -0700 Subject: [PATCH 4/6] Refactor GSPO policy loss function to improve clarity and accuracy of metrics --- tests/algorithm/policy_loss_test.py | 7 ++--- .../policy_loss_fn/gspo_policy_loss.py | 28 +++++++++---------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index 358eee600c..1dccb4c801 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -50,12 +50,11 @@ def test_gspo_policy_loss(self): loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) gspo_loss_expected = torch.tensor(0.15352888405323029) pg_clipfrac_expected = torch.tensor(0.4791666567325592) - ppo_kl_expected = torch.tensor(-0.021903187036514282) - print("gspo_loss_expected:", loss.item()) - print("pg_clipfrac_expected:", metrics["pg_clipfrac"]) - print("ppo_kl_expected:", metrics["ppo_kl"]) + ppo_kl_seq_expected = torch.tensor(-0.021903187036514282) + ppo_kl_expected = torch.tensor(-0.21663446724414825) self.assertTrue(torch.allclose(loss, gspo_loss_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), gspo_loss_expected)) diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py index 45da10af9d..a810927525 100644 --- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -33,31 +33,31 @@ def __init__( def __call__( # type: ignore self, - logprob: torch.Tensor, - old_logprob: torch.Tensor, - action_mask: torch.Tensor, - advantages: torch.Tensor, + logprob: torch.Tensor, # [batch_size, seq_len] + old_logprob: torch.Tensor, # [batch_size, seq_len] + action_mask: torch.Tensor, # [batch_size, seq_len] + advantages: torch.Tensor, # [batch_size, seq_len] 
**kwargs, ) -> Tuple[torch.Tensor, Dict]: - seq_lengths = action_mask.sum(dim=-1, keepdim=True).clamp_min(1) - negative_approx_kl_seq = (logprob - old_logprob) * action_mask / seq_lengths - log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach() - ratio = torch.exp(log_seq_importance_ratio) - - - pg_losses = -advantages * ratio + seq_lengths = action_mask.sum(dim=-1, keepdim=True).clamp_min(1) # [batch_size, 1] + negative_approx_kl = logprob - old_logprob # [batch_size, seq_len] + negative_approx_kl_seq = negative_approx_kl * action_mask / seq_lengths # [batch_size, seq_len] + log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach() # [batch_size, seq_len] + ratio = torch.exp(log_seq_importance_ratio) # [batch_size, seq_len] + pg_losses = -advantages * ratio # [batch_size, seq_len] pg_losses_clipped = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high - ) + ) # [batch_size, seq_len] pg_loss = masked_mean(torch.max(pg_losses, pg_losses_clipped), action_mask) pg_clipfrac = masked_mean(torch.gt(pg_losses_clipped, pg_losses).float(), action_mask) - ppo_kl = masked_mean(-negative_approx_kl_seq, action_mask) - + ppo_kl = masked_mean(-negative_approx_kl, action_mask) + ppo_kl_seq = masked_mean(-negative_approx_kl_seq, action_mask) metrics = { "pg_clipfrac": pg_clipfrac.detach().item(), "ppo_kl": ppo_kl.detach().item(), "pg_loss": pg_loss.detach().item(), + "ppo_kl_seq": ppo_kl_seq.detach().item(), } return pg_loss, metrics From 80112a78f7edeb4fb45b7d7243a82d068ea905ed Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Fri, 1 Aug 2025 22:28:36 -0700 Subject: [PATCH 5/6] apply suggestions --- tests/algorithm/policy_loss_test.py | 7 ++++--- trinity/algorithm/policy_loss_fn/gspo_policy_loss.py | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index 1dccb4c801..134635c05a 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -48,10 +48,11 @@ def test_gspo_policy_loss(self): policy_loss_fn_args = policy_loss_fn_cls.default_args() policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) - gspo_loss_expected = torch.tensor(0.15352888405323029) - pg_clipfrac_expected = torch.tensor(0.4791666567325592) - ppo_kl_seq_expected = torch.tensor(-0.021903187036514282) + gspo_loss_expected = torch.tensor(0.27235108613967896) + pg_clipfrac_expected = torch.tensor(0.375) + ppo_kl_seq_expected = torch.tensor(-0.21027061343193054) ppo_kl_expected = torch.tensor(-0.21663446724414825) + print(f"{loss.item()=}, {metrics=}") self.assertTrue(torch.allclose(loss, gspo_loss_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected)) self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected)) diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py index a810927525..7f96d65386 100644 --- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -39,20 +39,20 @@ def __call__( # type: ignore advantages: torch.Tensor, # [batch_size, seq_len] **kwargs, ) -> Tuple[torch.Tensor, Dict]: - seq_lengths = action_mask.sum(dim=-1, keepdim=True).clamp_min(1) # [batch_size, 1] negative_approx_kl = logprob - old_logprob # 
[batch_size, seq_len] - negative_approx_kl_seq = negative_approx_kl * action_mask / seq_lengths # [batch_size, seq_len] - log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach() # [batch_size, seq_len] + negative_approx_kl_seq = masked_mean(negative_approx_kl, action_mask, axis=-1) # [batch_size] + log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1) # [batch_size, seq_len] ratio = torch.exp(log_seq_importance_ratio) # [batch_size, seq_len] pg_losses = -advantages * ratio # [batch_size, seq_len] pg_losses_clipped = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high ) # [batch_size, seq_len] - pg_loss = masked_mean(torch.max(pg_losses, pg_losses_clipped), action_mask) + seq_losses = masked_mean(torch.max(pg_losses, pg_losses_clipped), action_mask, axis=-1) # [batch_size] + pg_loss = torch.mean(seq_losses) pg_clipfrac = masked_mean(torch.gt(pg_losses_clipped, pg_losses).float(), action_mask) ppo_kl = masked_mean(-negative_approx_kl, action_mask) - ppo_kl_seq = masked_mean(-negative_approx_kl_seq, action_mask) + ppo_kl_seq = torch.mean(-negative_approx_kl_seq) metrics = { "pg_clipfrac": pg_clipfrac.detach().item(), "ppo_kl": ppo_kl.detach().item(), From 9d6e5965dca158b101275445a01c4b36f83f6ab8 Mon Sep 17 00:00:00 2001 From: Nikolai Karpov Date: Sat, 2 Aug 2025 10:46:09 -0700 Subject: [PATCH 6/6] make pre-commit happy --- .../policy_loss_fn/gspo_policy_loss.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py index 7f96d65386..2cac58244c 100644 --- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -33,22 +33,28 @@ def __init__( def __call__( # type: ignore self, - logprob: torch.Tensor, # [batch_size, seq_len] - old_logprob: torch.Tensor, # [batch_size, seq_len] - action_mask: torch.Tensor, # [batch_size, seq_len] - advantages: torch.Tensor, # [batch_size, seq_len] + logprob: torch.Tensor, # [batch_size, seq_len] + old_logprob: torch.Tensor, # [batch_size, seq_len] + action_mask: torch.Tensor, # [batch_size, seq_len] + advantages: torch.Tensor, # [batch_size, seq_len] **kwargs, ) -> Tuple[torch.Tensor, Dict]: - negative_approx_kl = logprob - old_logprob # [batch_size, seq_len] - negative_approx_kl_seq = masked_mean(negative_approx_kl, action_mask, axis=-1) # [batch_size] - log_seq_importance_ratio = logprob - logprob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1) # [batch_size, seq_len] - ratio = torch.exp(log_seq_importance_ratio) # [batch_size, seq_len] - pg_losses = -advantages * ratio # [batch_size, seq_len] + negative_approx_kl = logprob - old_logprob # [batch_size, seq_len] + negative_approx_kl_seq = masked_mean( + negative_approx_kl, action_mask, axis=-1 + ) # [batch_size] + log_seq_importance_ratio = ( + logprob - logprob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1) + ) # [batch_size, seq_len] + ratio = torch.exp(log_seq_importance_ratio) # [batch_size, seq_len] + pg_losses = -advantages * ratio # [batch_size, seq_len] pg_losses_clipped = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high - ) # [batch_size, seq_len] - - seq_losses = masked_mean(torch.max(pg_losses, pg_losses_clipped), action_mask, axis=-1) # [batch_size] + ) # [batch_size, seq_len] + + seq_losses = masked_mean( + torch.max(pg_losses, 
pg_losses_clipped), action_mask, axis=-1 + ) # [batch_size] pg_loss = torch.mean(seq_losses) pg_clipfrac = masked_mean(torch.gt(pg_losses_clipped, pg_losses).float(), action_mask) ppo_kl = masked_mean(-negative_approx_kl, action_mask)
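
Note on the final implementation: after patch 6, GSPOLossFn computes the GSPO-token variant of the cited paper (arXiv:2507.18071). The per-token ratio is s_{i,t} = sg[s_i] * pi_theta(y_{i,t}) / sg[pi_theta(y_{i,t})], where s_i = exp((1/|y_i|) * sum_t log(pi_theta(y_{i,t}) / pi_old(y_{i,t}))) and sg[.] denotes stop-gradient (.detach()), so every token sees the length-normalized sequence-level importance ratio as its value while the gradient still flows through each token's log-probability. The sketch below mirrors that math in standalone form; it is illustrative only: the helper name gspo_token_loss is not part of the repository, explicit masked sums stand in for trinity.algorithm.utils.masked_mean, and the default clip ranges are the values from patch 3's default_args.

import torch


def gspo_token_loss(
    logprob: torch.Tensor,      # [batch_size, seq_len] log-probs under the current policy
    old_logprob: torch.Tensor,  # [batch_size, seq_len] log-probs under the rollout policy
    action_mask: torch.Tensor,  # [batch_size, seq_len] float mask, 1.0 on response tokens
    advantages: torch.Tensor,   # [batch_size, seq_len]
    clip_range_low: float = 0.0003,   # defaults taken from patch 3's default_args
    clip_range_high: float = 0.0004,
) -> torch.Tensor:
    # Per-token log importance ratio: log pi_theta(y_t) - log pi_old(y_t).
    negative_approx_kl = logprob - old_logprob
    # Length-normalized sequence-level log ratio (masked mean over response tokens).
    seq_lengths = action_mask.sum(dim=-1).clamp(min=1.0)                          # [batch_size]
    log_seq_ratio = (negative_approx_kl * action_mask).sum(dim=-1) / seq_lengths  # [batch_size]
    # GSPO-token ratio: equal in value to exp(log_seq_ratio) at every token position,
    # but the gradient flows through each token's logprob via the stop-gradient trick.
    log_token_ratio = logprob - logprob.detach() + log_seq_ratio.detach().unsqueeze(-1)
    ratio = torch.exp(log_token_ratio)                                            # [batch_size, seq_len]
    # PPO-style clipped surrogate applied to the sequence-level ratio.
    pg_losses = -advantages * ratio
    pg_losses_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
    per_token = torch.max(pg_losses, pg_losses_clipped)
    # Mean over tokens within each sequence, then mean over sequences in the batch.
    seq_losses = (per_token * action_mask).sum(dim=-1) / seq_lengths
    return seq_losses.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, seq_len = 2, 5
    mask = torch.ones(batch_size, seq_len)
    loss = gspo_token_loss(
        logprob=-torch.rand(batch_size, seq_len),
        old_logprob=-torch.rand(batch_size, seq_len),
        action_mask=mask,
        advantages=torch.randn(batch_size, seq_len),
    )
    print(f"{loss.item()=}")

The clipped surrogate and the token-then-sequence averaging mirror the seq_losses / pg_loss computation in patch 6; only the metrics dictionary (pg_clipfrac, ppo_kl, ppo_kl_seq) is omitted from this sketch.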