From 221b0759e0d11c51b1bd3b7512c8b5c6c8431d69 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Tue, 2 Dec 2025 14:48:24 +0800
Subject: [PATCH 1/5] add sequence mask for grpo

---
 tests/algorithm/policy_loss_test.py           | 25 ++++++++++++
 .../policy_loss_fn/ppo_policy_loss.py         | 38 +++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index b1484d4f41..3d53b6333f 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -115,3 +115,28 @@ def test_mix_policy_loss(self):
         )
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
+
+    def test_ppo_policy_loss_with_sequence_masking(self):
+        """Test PPO policy loss with sequence masking enabled"""
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn_args["enable_sequence_masking"] = True
+        policy_loss_fn_args["delta"] = 0.1
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+
+        # Test that sequence masking metrics are present
+        self.assertIn("seq_mask/masked_tokens", metrics)
+        self.assertIn("seq_mask/mean_sequence_kl", metrics)
+
+        # Test that masked_tokens is between 0 and 1
+        self.assertGreaterEqual(metrics["seq_mask/masked_tokens"], 0.0)
+        self.assertLessEqual(metrics["seq_mask/masked_tokens"], 1.0)
+
+        # Test that loss is different from non-masked version (if masking occurred)
+        policy_loss_fn_no_mask = policy_loss_fn_cls(**policy_loss_fn_cls.default_args())
+        loss_no_mask, _ = policy_loss_fn_no_mask(log_prob=self.logprob, **self.input_data.batch)
+
+        # Loss should be different if tokens were masked
+        if metrics["seq_mask/masked_tokens"] > 0:
+            self.assertFalse(torch.allclose(loss, loss_no_mask))
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index f2c812a0b5..632a1148da 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -21,6 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
+        enable_sequence_masking: bool = False,
+        delta: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
+        self.enable_sequence_masking = enable_sequence_masking
+        self.delta = delta
 
     def __call__(  # type: ignore
         self,
@@ -51,6 +55,33 @@ def __call__(  # type: ignore
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)
 
+        # Compute sequence masking if enabled
+        sequence_mask = torch.ones_like(advantages)
+        if self.enable_sequence_masking:
+            # Compute sequence-level KL divergence: mean KL per sequence
+            # Shape: (batch_size, seq_len) -> (batch_size,)
+            kl_per_token = -negative_approx_kl  # KL divergence per token
+            sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
+                action_mask.sum(dim=-1) + 1e-10
+            )
+
+            # Create mask: mask out tokens with negative advantages when sequence KL is high
+            # Token-level advantage check: (batch_size, seq_len)
+            has_negative_advantage = advantages < 0
+            # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
+            exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
+            # Mask tokens that are both negative advantage AND in high-KL sequences
+            should_mask = has_negative_advantage & exceeds_kl_threshold
+            sequence_mask = (~should_mask).float()
+
+            metrics_seq_mask = {
+                "seq_mask/masked_tokens": should_mask.float().sum().item()
+                / (action_mask.sum().item() + 1e-10),
+                "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
+            }
+        else:
+            metrics_seq_mask = {}
+
         pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
@@ -66,6 +97,10 @@ def __call__(  # type: ignore
             torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence mask to the losses
+        pg_losses = pg_losses * sequence_mask
+
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
             "pg_clipfrac": pg_clip_frac.detach().item(),
@@ -73,6 +108,7 @@ def __call__(  # type: ignore
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+        metrics.update(metrics_seq_mask)
         return pg_loss, metrics
 
     @classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
             "clip_range": 0.2,
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
+            "enable_sequence_masking": False,
+            "delta": 0.1,
         }

From 6400b5d9e82be90eba5731e72a261e5952d13896 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Tue, 2 Dec 2025 17:00:24 +0800
Subject: [PATCH 2/5] update policy loss value test

---
 tests/algorithm/policy_loss_test.py | 36 +++++++++++++++--------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index 3d53b6333f..7ae0337236 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -117,26 +117,28 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
 
     def test_ppo_policy_loss_with_sequence_masking(self):
-        """Test PPO policy loss with sequence masking enabled"""
         policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn_args["enable_sequence_masking"] = True
         policy_loss_fn_args["delta"] = 0.1
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-
-        # Test that sequence masking metrics are present
-        self.assertIn("seq_mask/masked_tokens", metrics)
-        self.assertIn("seq_mask/mean_sequence_kl", metrics)
-
-        # Test that masked_tokens is between 0 and 1
-        self.assertGreaterEqual(metrics["seq_mask/masked_tokens"], 0.0)
-        self.assertLessEqual(metrics["seq_mask/masked_tokens"], 1.0)
-
-        # Test that loss is different from non-masked version (if masking occurred)
-        policy_loss_fn_no_mask = policy_loss_fn_cls(**policy_loss_fn_cls.default_args())
-        loss_no_mask, _ = policy_loss_fn_no_mask(log_prob=self.logprob, **self.input_data.batch)
-
-        # Loss should be different if tokens were masked
-        if metrics["seq_mask/masked_tokens"] > 0:
-            self.assertFalse(torch.allclose(loss, loss_no_mask))
+        ppo_loss_masked = torch.tensor(0.22175675630569458)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
+        masked_tokens = torch.tensor(0.16666666666631944)
+        mean_sequence_kl = torch.tensor(-0.21027061343193054)
+        self.assertTrue(torch.allclose(loss, ppo_loss_masked))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
+        )

From a2ce985eb8bab289d0a32e1416f107d5df308337 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Thu, 4 Dec 2025 10:15:35 +0800
Subject: [PATCH 3/5] fix pre-commit after merging main

---
 tests/algorithm/policy_loss_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index 32f7c4389b..0daad6ea46 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -142,7 +142,7 @@ def test_ppo_policy_loss_with_sequence_masking(self):
         self.assertTrue(
             torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
         )
-
+
     def test_sapo_policy_loss(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()

From e221c8a8fef732e82a3c24d896b61955583e31c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Thu, 4 Dec 2025 10:24:55 +0800
Subject: [PATCH 4/5] avoid redundant compute

---
 trinity/algorithm/policy_loss_fn/ppo_policy_loss.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 632a1148da..0f135e7310 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -99,7 +99,8 @@ def __call__(  # type: ignore
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
 
         # Apply sequence mask to the losses
-        pg_losses = pg_losses * sequence_mask
+        if self.enable_sequence_masking:
+            pg_losses = pg_losses * sequence_mask
 
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {

From 7763bdde6e1a0bdb3baeb56f7c97341cbae3dd53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Fri, 5 Dec 2025 11:46:03 +0800
Subject: [PATCH 5/5] clean up code as suggested

---
 tests/algorithm/policy_loss_test.py                 |  2 +-
 .../policy_loss_fn/ppo_policy_loss.py               | 59 +++++++++----------
 2 files changed, 30 insertions(+), 31 deletions(-)

diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index 0daad6ea46..d4ddbbf87c 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -120,7 +120,7 @@ def test_ppo_policy_loss_with_sequence_masking(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn_args["enable_sequence_masking"] = True
-        policy_loss_fn_args["delta"] = 0.1
+        policy_loss_fn_args["delta_sequence_masking"] = 0.1
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
         ppo_loss_masked = torch.tensor(0.22175675630569458)
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 0f135e7310..d5e17a83e9 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -21,8 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
-        enable_sequence_masking: bool = False,
-        delta: float = 0.1,
+        enable_sequence_masking: bool = False,  # introduced in DeepseekV3.2
+        delta_sequence_masking: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -39,7 +39,7 @@ def __init__(
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
         self.enable_sequence_masking = enable_sequence_masking
-        self.delta = delta
+        self.delta_sequence_masking = delta_sequence_masking
 
     def __call__(  # type: ignore
         self,
@@ -55,8 +55,23 @@ def __call__(  # type: ignore
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)
 
-        # Compute sequence masking if enabled
-        sequence_mask = torch.ones_like(advantages)
+        pg_losses1 = -advantages * ratio
+        pg_losses2 = -advantages * torch.clamp(
+            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
+        )
+
+        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+
+        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+        pg_losses3 = -advantages * self.clip_ratio_c
+        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+        pg_clipfrac_lower = masked_mean(
+            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
+        )
+        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence masking if enabled
         if self.enable_sequence_masking:
             # Compute sequence-level KL divergence: mean KL per sequence
             # Shape: (batch_size, seq_len) -> (batch_size,)
@@ -69,38 +84,21 @@ def __call__(  # type: ignore
             # Token-level advantage check: (batch_size, seq_len)
             has_negative_advantage = advantages < 0
             # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
-            exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
+            exceeds_kl_threshold = (
+                (sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
+            )
             # Mask tokens that are both negative advantage AND in high-KL sequences
             should_mask = has_negative_advantage & exceeds_kl_threshold
             sequence_mask = (~should_mask).float()
 
+            # Apply sequence mask to the losses
+            pg_losses = pg_losses * sequence_mask
+
             metrics_seq_mask = {
                 "seq_mask/masked_tokens": should_mask.float().sum().item()
                 / (action_mask.sum().item() + 1e-10),
                 "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
             }
-        else:
-            metrics_seq_mask = {}
-
-        pg_losses1 = -advantages * ratio
-        pg_losses2 = -advantages * torch.clamp(
-            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
-        )
-
-        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
-
-        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
-
-        pg_losses3 = -advantages * self.clip_ratio_c
-        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
-        pg_clipfrac_lower = masked_mean(
-            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
-        )
-        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
-
-        # Apply sequence mask to the losses
-        if self.enable_sequence_masking:
-            pg_losses = pg_losses * sequence_mask
 
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
@@ -109,7 +107,8 @@ def __call__(  # type: ignore
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
-        metrics.update(metrics_seq_mask)
+        if self.enable_sequence_masking:
+            metrics.update(metrics_seq_mask)
         return pg_loss, metrics
 
     @classmethod
@@ -119,5 +118,5 @@ def default_args(cls) -> Dict:
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
             "enable_sequence_masking": False,
-            "delta": 0.1,
+            "delta_sequence_masking": 0.1,
         }
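
Note: the snippet below is a minimal, standalone sketch of the sequence-masking rule this series adds, handy for sanity-checking the logic outside the trainer. It mirrors the patched PPOPolicyLossFn (per-token approximate KL, the delta_sequence_masking threshold, "token-mean" aggregation) but uses made-up synthetic tensors and omits the dual-clip branch (clip_ratio_c), so treat it as an illustration rather than the library API.

import torch

# Synthetic batch: 2 sequences of 4 tokens; all values are illustrative only.
log_prob = torch.tensor([[-1.0, -0.8, -1.2, -0.9], [-0.5, -2.0, -0.7, -1.5]])
old_log_prob = torch.tensor([[-1.1, -0.9, -1.0, -1.0], [-0.6, -1.0, -0.8, -1.0]])
advantages = torch.tensor([[0.5, -0.3, 0.2, -0.1], [-0.4, 0.1, -0.2, 0.3]])
action_mask = torch.ones(2, 4)
clip_range, delta_sequence_masking = 0.2, 0.1

# Standard clipped PPO objective per token (dual clip omitted for brevity).
ratio = torch.exp(log_prob - old_log_prob)
pg_losses = torch.maximum(
    -advantages * ratio,
    -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range),
)

# Sequence-level KL: masked mean of the per-token approximate KL (old_log_prob - log_prob).
kl_per_token = old_log_prob - log_prob
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-10)

# Drop tokens that have negative advantage AND sit in a sequence whose KL exceeds the threshold.
exceeds_kl_threshold = (sequence_kl > delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
sequence_mask = (~((advantages < 0) & exceeds_kl_threshold)).float()
pg_losses = pg_losses * sequence_mask

# "token-mean" aggregation over valid tokens.
pg_loss = (pg_losses * action_mask).sum() / action_mask.sum()
print(sequence_kl)    # tensor([-0.0250, 0.3250]): only the second sequence exceeds 0.1
print(sequence_mask)  # zeros at the negative-advantage tokens of the second sequence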