From 4b87ed4d0234bc91bc3706bf4e8dcdbcd488126c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 14 Aug 2025 15:40:13 +0800 Subject: [PATCH 01/14] add topr and cispo algorithm --- .../algorithm/add_strategy/add_strategy.py | 32 ++++++++ trinity/algorithm/algorithm.py | 46 +++++++++++ .../policy_loss_fn/cispo_policy_loss.py | 76 +++++++++++++++++ .../policy_loss_fn/trpo_policy_loss.py | 82 +++++++++++++++++++ 4 files changed, 236 insertions(+) create mode 100644 trinity/algorithm/policy_loss_fn/cispo_policy_loss.py create mode 100644 trinity/algorithm/policy_loss_fn/trpo_policy_loss.py diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index 3984d49759..f909b4d589 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -91,6 +91,38 @@ async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: return cnt, metrics +@ADD_STRATEGY.register_module("reinforce") +class REINFORCEAddStrategy(GroupAdvantageStrategy): + """An example AddStrategy that simply use rewards as advantages.""" + + def __init__(self, writer: BufferWriter, **kwargs) -> None: + super().__init__(writer) + + def group_experiences(self, exps): + return group_by(exps, id_type="task") + + def calculate_group_advantage( + self, group_id: str, exps: List[Experience] + ) -> Tuple[List[Experience], Dict]: + with torch.no_grad(): + rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) + group_reward_mean = torch.mean(rewards) + for exp in exps: + score = torch.tensor(exp.reward, dtype=torch.float32) + exp.advantages = score * exp.action_mask + exp.returns = exp.advantages.clone() + + metrics = { + "reward_mean": group_reward_mean.item(), + } + + return exps, metrics + + @classmethod + def default_args(cls) -> dict: + return {} + + @ADD_STRATEGY.register_module("grpo") class GRPOAddStrategy(GroupAdvantageStrategy): """An example AddStrategy 
that calculates GRPO advantages.""" diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 640ceff1ef..51322f5640 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -176,6 +176,52 @@ def check_config(cls, config: Config) -> None: logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") +@ALGORITHM_TYPE.register_module("topr") +class TOPRAlgorithm(AlgorithmType): + """TOPR algorithm. See https://arxiv.org/pdf/2503.14286v1""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "add_strategy": "reinforce", # or simply use grpo + "sample_strategy": "warmup", + "policy_loss_fn": "topr", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + +@ALGORITHM_TYPE.register_module("cispo") +class CISPOAlgorithm(AlgorithmType): + """CISPO algorithm. See https://arxiv.org/abs/2506.13585""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "add_strategy": "grpo", + "sample_strategy": "warmup", + "policy_loss_fn": "cispo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): """MIX algorithm.""" diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py new file mode 100644 index 0000000000..f344389a91 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py @@ -0,0 +1,76 @@ +"""CISPO policy loss function. +Refer to https://arxiv.org/abs/2506.13585 for details. 
+""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("cispo") +class CISPOPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + clip_range_low: float = 1.0, + clip_range_high: float = 0.28, + ) -> None: + super().__init__(backend=backend) + self.clip_range_low = clip_range_low + self.clip_range_high = clip_range_high + assert self.clip_range_low is not None, "clip_range_low must be specified." + assert self.clip_range_high is not None, "clip_range_high must be specified." + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + negative_approx_kl = logprob - old_logprob + ratio = torch.exp(negative_approx_kl) + ppo_kl = masked_mean(-negative_approx_kl, action_mask) + ratio_clamped = torch.clamp( + ratio, min=1.0 - self.clip_range_low, max=1.0 + self.clip_range_high + ) + + # mask = 0 if ratio > 1.0 + self.clip_range_high and advantages > 0 + # mask = 0 if ratio < 1.0 - self.clip_range_low and advantages < 0 + # else 1 + mask = torch.ones_like(ratio) + mask = torch.where( + (ratio > 1.0 + self.clip_range_high) & (advantages > 0), torch.zeros_like(ratio), mask + ) + mask = torch.where( + (ratio < 1.0 - self.clip_range_low) & (advantages < 0), torch.zeros_like(ratio), mask + ) + + cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob + + loss = masked_mean(cispo_loss, action_mask) + masked_frac = masked_mean(mask, action_mask) + + metrics = { + "cispo_loss": cispo_loss.detach().item(), + "ppo_kl": ppo_kl.detach().item(), + "masked_frac": masked_frac.detach().item(), + } + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + """ + In the original paper: + we did not impose a lower 
bound on the IS weight by setting clip_range_low to a high value, instead, we only tuned clip_range_high + + """ + return { + "clip_range_low": 1.0, + "clip_range_high": 0.28, + } diff --git a/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py b/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py new file mode 100644 index 0000000000..cb31c6d528 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py @@ -0,0 +1,82 @@ +"""TOPR policy loss function. +Refer to https://arxiv.org/pdf/2503.14286v1 +""" +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("topr") +class TOPRPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + advantage_threshold: float = 0.0, + ) -> None: + super().__init__(backend=backend) + self.advantage_threshold = advantage_threshold + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, # In TOPR, this is actually the rewards R(x,y) + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """ + Compute TOPR policy loss. + + In TOPR: + - α = [π(y|x)/μ(y|x)]^1 if R(x,y) < threshold else 1 + - loss = -α * r(x,y) * log π(y|x) + + Since we want to maximize α * r(x,y) * log π(y|x), we minimize its negative. 
+ + Args: + logprob: Log probabilities from current policy π + old_logprob: Log probabilities from reference policy μ + action_mask: Mask for valid actions + advantages: Rewards R(x,y) in TOPR terminology + """ + # in Orginal TOPR paper, advantages are simply rewards + # However, we can use advantages as rewards(Baseline Trick) + rewards = advantages + + # Compute ratio π(y|x) / μ(y|x) in log space for numerical stability + log_ratio = logprob - old_logprob + ratio = torch.exp(log_ratio) + ratio_clamped = torch.clamp(ratio, min=0.0, max=1.0) + + # Apply TOPR's conditional weighting: + # α = ratio clamp min=0 max=1 if R(x,y) <= threshold else 1 + alpha = torch.where( + rewards <= self.advantage_threshold, ratio_clamped, torch.ones_like(ratio) + ) + + # TOPR loss: l = -α * r(x,y) * log π(y|x) + # We want to maximize α * r(x,y) * log π(y|x), so minimize the negative + topr_loss = -alpha.detach() * rewards * logprob # detach alpha as it's used with stop-grad + + # Apply masking and compute mean + loss = masked_mean(topr_loss, action_mask) + + # Average alpha value for monitoring + avg_alpha = masked_mean(alpha, action_mask) + + metrics = { + "topr_loss": loss.detach().item(), + "avg_alpha": avg_alpha.detach().item(), + "avg_ratio": masked_mean(ratio, action_mask).detach().item(), + } + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "advantage_threshold": 0.0, + } From c51053a064354f5e7fd11827130b6769db4e61aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 14 Aug 2025 15:45:56 +0800 Subject: [PATCH 02/14] register new algorithm and polish functions --- trinity/algorithm/add_strategy/__init__.py | 2 ++ trinity/algorithm/policy_loss_fn/__init__.py | 4 ++++ trinity/algorithm/policy_loss_fn/trpo_policy_loss.py | 12 ++---------- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index 1a9ee3af7f..0cc0573f74 100644 
--- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -3,6 +3,7 @@ AddStrategy, GRPOAddStrategy, OPMDAddStrategy, + REINFORCEAddStrategy, RewardVarianceAddStrategy, ) from trinity.algorithm.add_strategy.correct_bias_add_strategy import ( @@ -20,6 +21,7 @@ "OPMDAddStrategy", "StepWiseGRPOStrategy", "RewardVarianceAddStrategy", + "REINFORCEAddStrategy", "CorrectBiasAddStrategy", "DuplicateInformativeAddStrategy", ] diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 4d88215a0a..d734c51123 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,3 +1,4 @@ +from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn @@ -5,6 +6,7 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn +from trinity.algorithm.policy_loss_fn.trpo_policy_loss import TRPOPolicyLossFn __all__ = [ "POLICY_LOSS_FN", @@ -15,4 +17,6 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", + "TRPOPolicyLossFn", + "CISPOPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py b/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py index cb31c6d528..b58618465d 100644 --- a/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py @@ -31,16 +31,8 @@ def __call__( # type: ignore Compute TOPR policy loss. 
In TOPR: - - α = [π(y|x)/μ(y|x)]^1 if R(x,y) < threshold else 1 - - loss = -α * r(x,y) * log π(y|x) - - Since we want to maximize α * r(x,y) * log π(y|x), we minimize its negative. - - Args: - logprob: Log probabilities from current policy π - old_logprob: Log probabilities from reference policy μ - action_mask: Mask for valid actions - advantages: Rewards R(x,y) in TOPR terminology + - α = [π(y|x)/μ(y|x)]_0^1 if R(x,y) <= threshold else 1 + - loss = -sg(α) * r(x,y) * log π(y|x) """ # in Orginal TOPR paper, advantages are simply rewards # However, we can use advantages as rewards(Baseline Trick) From 7f66302c94cd0943fffbd42000ae395457db7097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 14 Aug 2025 15:48:21 +0800 Subject: [PATCH 03/14] fix typos and mistakes --- trinity/algorithm/policy_loss_fn/__init__.py | 2 +- trinity/algorithm/policy_loss_fn/cispo_policy_loss.py | 2 +- .../policy_loss_fn/{trpo_policy_loss.py => topr_policy_loss.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename trinity/algorithm/policy_loss_fn/{trpo_policy_loss.py => topr_policy_loss.py} (100%) diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index d734c51123..8233d4d11e 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -6,7 +6,7 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn -from trinity.algorithm.policy_loss_fn.trpo_policy_loss import TRPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.topr_policy_loss import TRPOPolicyLossFn __all__ = [ "POLICY_LOSS_FN", diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py index f344389a91..d775077214 100644 --- 
a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py @@ -56,7 +56,7 @@ def __call__( # type: ignore masked_frac = masked_mean(mask, action_mask) metrics = { - "cispo_loss": cispo_loss.detach().item(), + "cispo_loss": loss.detach().item(), "ppo_kl": ppo_kl.detach().item(), "masked_frac": masked_frac.detach().item(), } diff --git a/trinity/algorithm/policy_loss_fn/trpo_policy_loss.py b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py similarity index 100% rename from trinity/algorithm/policy_loss_fn/trpo_policy_loss.py rename to trinity/algorithm/policy_loss_fn/topr_policy_loss.py From 5db9d8692cfe9b8ede6696f109a53834db1970b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Fri, 15 Aug 2025 17:15:25 +0800 Subject: [PATCH 04/14] fix typo --- trinity/algorithm/policy_loss_fn/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 8233d4d11e..91b826baf7 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -6,7 +6,7 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn -from trinity.algorithm.policy_loss_fn.topr_policy_loss import TRPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn __all__ = [ "POLICY_LOSS_FN", @@ -17,6 +17,6 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", - "TRPOPolicyLossFn", + "TOPRPolicyLossFn", "CISPOPolicyLossFn", ] From 375f90d95337c4511bc62bfa41d9438277d9e85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Wed, 27 Aug 2025 16:00:26 +0800 Subject: [PATCH 05/14] del add strategy, make it consistent with main --- 
trinity/algorithm/add_strategy/__init__.py | 27 -- .../algorithm/add_strategy/add_strategy.py | 262 ------------------ trinity/algorithm/advantage_fn/__init__.py | 2 + trinity/algorithm/algorithm.py | 4 +- trinity/algorithm/policy_loss_fn/__init__.py | 8 +- 5 files changed, 5 insertions(+), 298 deletions(-) delete mode 100644 trinity/algorithm/add_strategy/__init__.py delete mode 100644 trinity/algorithm/add_strategy/add_strategy.py diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py deleted file mode 100644 index 0cc0573f74..0000000000 --- a/trinity/algorithm/add_strategy/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from trinity.algorithm.add_strategy.add_strategy import ( - ADD_STRATEGY, - AddStrategy, - GRPOAddStrategy, - OPMDAddStrategy, - REINFORCEAddStrategy, - RewardVarianceAddStrategy, -) -from trinity.algorithm.add_strategy.correct_bias_add_strategy import ( - CorrectBiasAddStrategy, -) -from trinity.algorithm.add_strategy.duplicate_add_strategy import ( - DuplicateInformativeAddStrategy, -) -from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy - -__all__ = [ - "ADD_STRATEGY", - "AddStrategy", - "GRPOAddStrategy", - "OPMDAddStrategy", - "StepWiseGRPOStrategy", - "RewardVarianceAddStrategy", - "REINFORCEAddStrategy", - "CorrectBiasAddStrategy", - "DuplicateInformativeAddStrategy", -] diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py deleted file mode 100644 index f909b4d589..0000000000 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ /dev/null @@ -1,262 +0,0 @@ -import asyncio -from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Tuple - -import numpy as np -import torch - -from trinity.buffer import BufferWriter -from trinity.common.experience import Experience -from trinity.utils.monitor import gather_metrics -from trinity.utils.registry import Registry -from trinity.utils.timer import 
Timer - -ADD_STRATEGY = Registry("add_strategy") - - -class AddStrategy(ABC): - def __init__(self, writer: BufferWriter, **kwargs) -> None: - self.writer = writer - - @abstractmethod - async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]: - """Add experiences to the buffer. - - Args: - experiences (`Experience`): The experiences to be added. - step (`int`): The current step number. - - Returns: - `int`: The number of experiences added to the buffer. - `Dict`: Metrics for logging. - """ - - @classmethod - @abstractmethod - def default_args(cls) -> dict: - """Get the default arguments of the add strategy. - - Returns: - `dict`: The default arguments. - """ - - -class GroupAdvantageStrategy(AddStrategy): - """An example AddStrategy that calculates group advantages.""" - - @abstractmethod - def group_experiences(self, exps: List[Experience]) -> Dict[str, List[Experience]]: - """Group experiences by a certain criterion. - - Args: - exps (List[Experience]): List of experiences to be grouped. - - Returns: - Dict[str, List[Experience]]: A dictionary where keys are group identifiers and values are lists of experiences. - """ - - @abstractmethod - def calculate_group_advantage( - self, group_id: str, exps: List[Experience] - ) -> Tuple[List[Experience], Dict]: - """Calculate advantages for a group of experiences. - - Args: - group_id (str): The identifier for the group of experiences. - exps (List[Experience]): List of experiences in the group. - - Returns: - Tuple[List[Experience], Dict]: A tuple containing the modified list of experiences and a dictionary of metrics. 
- """ - - async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: - if len(exps) == 0: - return 0, {} - exp_groups = self.group_experiences(exps) - cnt = 0 - metric_list = [] - tasks = [] - for group_id, group_exps in exp_groups.items(): - group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps) - metric_list.append(group_metrics) - cnt += len(group_exps) - if len(group_exps) > 0: - tasks.append(self.writer.write_async(group_exps)) - if tasks: - await asyncio.gather(*tasks) - try: - metrics = gather_metrics(metric_list, "group_advantages") - except ValueError: - metrics = {} # empty metric list causes ValueError, ignore it - return cnt, metrics - - -@ADD_STRATEGY.register_module("reinforce") -class REINFORCEAddStrategy(GroupAdvantageStrategy): - """An example AddStrategy that simply use rewards as advantages.""" - - def __init__(self, writer: BufferWriter, **kwargs) -> None: - super().__init__(writer) - - def group_experiences(self, exps): - return group_by(exps, id_type="task") - - def calculate_group_advantage( - self, group_id: str, exps: List[Experience] - ) -> Tuple[List[Experience], Dict]: - with torch.no_grad(): - rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) - group_reward_mean = torch.mean(rewards) - for exp in exps: - score = torch.tensor(exp.reward, dtype=torch.float32) - exp.advantages = score * exp.action_mask - exp.returns = exp.advantages.clone() - - metrics = { - "reward_mean": group_reward_mean.item(), - } - - return exps, metrics - - @classmethod - def default_args(cls) -> dict: - return {} - - -@ADD_STRATEGY.register_module("grpo") -class GRPOAddStrategy(GroupAdvantageStrategy): - """An example AddStrategy that calculates GRPO advantages.""" - - def __init__(self, writer: BufferWriter, epsilon: float = 1e-6, **kwargs) -> None: - super().__init__(writer) - self.epsilon = epsilon - - def group_experiences(self, exps): - return group_by(exps, id_type="task") - - def 
calculate_group_advantage( - self, group_id: str, exps: List[Experience] - ) -> Tuple[List[Experience], Dict]: - with torch.no_grad(): - if len(exps) == 1: - group_reward_mean = torch.tensor(0.0) - group_reward_std = torch.tensor(1.0) - else: - rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) - group_reward_mean = torch.mean(rewards) - group_reward_std = torch.std(rewards) - for exp in exps: - score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon) - exp.advantages = score * exp.action_mask - exp.returns = exp.advantages.clone() - - metrics = { - "reward_mean": group_reward_mean.item(), - "reward_std": group_reward_std.item(), - } - - return exps, metrics - - @classmethod - def default_args(cls) -> dict: - return {"epsilon": 1e-6} - - -@ADD_STRATEGY.register_module("opmd") -class OPMDAddStrategy(GroupAdvantageStrategy): - """An example AddStrategy that calculates OPMD advantages.""" - - def __init__( - self, writer: BufferWriter, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs - ) -> None: - super().__init__(writer) - assert opmd_baseline in [ - "mean", - "logavgexp", - ], f"opmd_baseline must be 'mean' or 'logavgexp', got {opmd_baseline}" - self.opmd_baseline = opmd_baseline - self.tau = tau - - def group_experiences(self, exps): - return group_by(exps, id_type="task") - - def calculate_group_advantage( - self, group_id: str, exps: List[Experience] - ) -> Tuple[List[Experience], Dict]: - with torch.no_grad(): - if len(exps) == 1: - group_baseline = torch.tensor(0.0) - else: - group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) - if self.opmd_baseline == "mean": - group_baseline = torch.mean(group_rewards) - else: - group_baseline = self.tau * ( - torch.logsumexp(group_rewards / self.tau, dim=-1) - - torch.log(torch.tensor(len(exps))) - ) - for exp in exps: - score = exp.reward - group_baseline - exp.advantages = score * exp.action_mask - exp.returns = exp.advantages.clone() - metrics 
= { - "group_baseline": group_baseline, - } - return exps, metrics - - @classmethod - def default_args(cls) -> dict: - return {"opmd_baseline": "mean", "tau": 1.0} - - -@ADD_STRATEGY.register_module("reward_variance") -class RewardVarianceAddStrategy(AddStrategy): - """An example AddStrategy that filters experiences based on a reward variance threshold.""" - - def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None: - super().__init__(writer) - self.variance_threshold = variance_threshold - - async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]: - cnt = 0 - metrics = {} - tasks = [] - with Timer(metrics, "add_strategy_time"): - grouped_experiences = group_by(experiences, id_type="task") - for _, group_exps in grouped_experiences.items(): - if len(group_exps) < 2: - continue - rewards = [exp.reward for exp in group_exps] - variance = np.var(rewards) - if variance <= self.variance_threshold: - continue - cnt += len(group_exps) - tasks.append(self.writer.write_async(group_exps)) - if tasks: - await asyncio.gather(*tasks) - return cnt, metrics - - @classmethod - def default_args(cls) -> dict: - return {"variance_threshold": 0.0} - - -def group_by( - experiences: List[Experience], id_type: Literal["task", "run", "step"] -) -> Dict[str, List[Experience]]: - """Group experiences by ID.""" - if id_type == "task": - id_type = "tid" - elif id_type == "run": - id_type = "rid" - elif id_type == "step": - id_type = "sid" - else: - raise ValueError(f"Unknown id_type: {id_type}") - grouped = {} - for exp in experiences: - group_id = getattr(exp.eid, id_type) - if group_id not in grouped: - grouped[group_id] = [] - grouped[group_id].append(exp) - return grouped diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 7e99c2cb5c..c5ca9d4899 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -15,6 +15,7 @@ 
OPMDGroupAdvantage, ) from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.advantage_fn.reinforce_advantage import REINFORCEGroupAdvantage from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( REINFORCEPLUSPLUSAdvantageFn, ) @@ -34,4 +35,5 @@ "RLOOAdvantageFn", "OPMDAdvantageFn", "OPMDGroupAdvantage", + "REINFORCEGroupAdvantage", ] diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 81286b4d6d..c7720b3ba3 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -190,7 +190,7 @@ class TOPRAlgorithm(AlgorithmType): def default_config(cls) -> Dict: return { "repeat_times": 2, - "add_strategy": "reinforce", # or simply use grpo + "advantage_fn": "reinforce", # or simply use grpo "sample_strategy": "warmup", "policy_loss_fn": "topr", "kl_penalty_fn": "none", @@ -213,7 +213,7 @@ class CISPOAlgorithm(AlgorithmType): def default_config(cls) -> Dict: return { "repeat_times": 2, - "add_strategy": "grpo", + "advantage_fn": "grpo", "sample_strategy": "warmup", "policy_loss_fn": "cispo", "kl_penalty_fn": "none", diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 545e8d7e9e..fca64d7630 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,12 +1,9 @@ -<<<<<<< HEAD -from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn -======= from trinity.algorithm.policy_loss_fn.chord_policy_loss import ( MIXCHORDPolicyLossFn, SFTISLossFn, SFTPhiLossFn, ) ->>>>>>> main +from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn @@ -25,12 +22,9 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", 
-<<<<<<< HEAD "TOPRPolicyLossFn", "CISPOPolicyLossFn", -======= "MIXCHORDPolicyLossFn", "SFTISLossFn", "SFTPhiLossFn", ->>>>>>> main ] From 3579e48de4a446b1e931abd224944b8066a02248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 28 Aug 2025 16:47:19 +0800 Subject: [PATCH 06/14] add examples for topr and cispo --- examples/cispo_gsm8k/README.md | 5 ++ examples/cispo_gsm8k/gsm8k.yaml | 69 +++++++++++++++++++++++++++ examples/cispo_gsm8k/train_gsm8k.yaml | 49 +++++++++++++++++++ examples/topr_gsm8k/README.md | 5 ++ examples/topr_gsm8k/gsm8k.yaml | 69 +++++++++++++++++++++++++++ examples/topr_gsm8k/train_gsm8k.yaml | 49 +++++++++++++++++++ 6 files changed, 246 insertions(+) create mode 100644 examples/cispo_gsm8k/README.md create mode 100644 examples/cispo_gsm8k/gsm8k.yaml create mode 100644 examples/cispo_gsm8k/train_gsm8k.yaml create mode 100644 examples/topr_gsm8k/README.md create mode 100644 examples/topr_gsm8k/gsm8k.yaml create mode 100644 examples/topr_gsm8k/train_gsm8k.yaml diff --git a/examples/cispo_gsm8k/README.md b/examples/cispo_gsm8k/README.md new file mode 100644 index 0000000000..aa7f5e568b --- /dev/null +++ b/examples/cispo_gsm8k/README.md @@ -0,0 +1,5 @@ +# CISPO on GSM8K dataset + +This example shows the usage of [CISPO](https://arxiv.org/abs/2506.13585) on the GSM8K dataset. + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). 
diff --git a/examples/cispo_gsm8k/gsm8k.yaml b/examples/cispo_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..d7dd956c62 --- /dev/null +++ b/examples/cispo_gsm8k/gsm8k.yaml @@ -0,0 +1,69 @@ +project: "Trinity-RFT-gsm8k" +name: "qwen2.5-1.5B-gsm8k-cispo" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: cispo + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 4 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/cispo_gsm8k/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/cispo_gsm8k/train_gsm8k.yaml b/examples/cispo_gsm8k/train_gsm8k.yaml new file mode 100644 index 0000000000..7af0c9e19f --- /dev/null +++ b/examples/cispo_gsm8k/train_gsm8k.yaml @@ -0,0 +1,49 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + 
override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/examples/topr_gsm8k/README.md b/examples/topr_gsm8k/README.md new file mode 100644 index 0000000000..b82ed2b0f6 --- /dev/null +++ b/examples/topr_gsm8k/README.md @@ -0,0 +1,5 @@ +# TOPR on GSM8K dataset + +This example shows the usage of [TOPR](https://arxiv.org/pdf/2503.14286v1) on the GSM8K dataset, with sync_interval=8. + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). 
diff --git a/examples/topr_gsm8k/gsm8k.yaml b/examples/topr_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..f7b1eb6761 --- /dev/null +++ b/examples/topr_gsm8k/gsm8k.yaml @@ -0,0 +1,69 @@ +project: "Trinity-RFT-gsm8k" +name: "qwen2.5-1.5B-gsm8k-topr" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: topr + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 8 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/topr_gsm8k/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/topr_gsm8k/train_gsm8k.yaml b/examples/topr_gsm8k/train_gsm8k.yaml new file mode 100644 index 0000000000..b6f3a4f2de --- /dev/null +++ b/examples/topr_gsm8k/train_gsm8k.yaml @@ -0,0 +1,49 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + 
override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 2e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be overridden by the program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False From d2cfbe09e0ec3b843bfae82a15e9a943095b20bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 28 Aug 2025 17:24:33 +0800 Subject: [PATCH 07/14] fix lr --- examples/topr_gsm8k/train_gsm8k.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/topr_gsm8k/train_gsm8k.yaml b/examples/topr_gsm8k/train_gsm8k.yaml index b6f3a4f2de..7af0c9e19f 100644 --- a/examples/topr_gsm8k/train_gsm8k.yaml +++ b/examples/topr_gsm8k/train_gsm8k.yaml @@ -15,7 +15,7 @@ actor_rollout_ref: shuffle: False ulysses_sequence_parallel_size: 1 # sp size optim: - lr: 2e-6 + lr: 1e-5 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine From fedd1af025fcec2742e783f0c9a222315acf3124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Thu, 28 Aug 2025 21:02:16 +0800 Subject: [PATCH 08/14] separate mask clip for cispo --- .../policy_loss_fn/cispo_policy_loss.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py index d775077214..690147691b 100644 --- a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py @@ -17,12 +17,16 @@ def __init__( backend: str = "verl", clip_range_low: float = 1.0, clip_range_high: float = 0.28, + enable_mask_clip: bool = False, + mask_clip_range_low: float = 1.0, + mask_clip_range_high: float = 0.28, ) -> None: super().__init__(backend=backend) self.clip_range_low = clip_range_low self.clip_range_high = clip_range_high - assert self.clip_range_low is not None, "clip_range_low 
must be specified." - assert self.clip_range_high is not None, "clip_range_high must be specified." + self.enable_mask_clip = enable_mask_clip + self.mask_clip_range_low = mask_clip_range_low + self.mask_clip_range_high = mask_clip_range_high def __call__( # type: ignore self, @@ -43,12 +47,17 @@ def __call__( # type: ignore # mask = 0 if ratio < 1.0 - self.clip_range_low and advantages < 0 # else 1 mask = torch.ones_like(ratio) - mask = torch.where( - (ratio > 1.0 + self.clip_range_high) & (advantages > 0), torch.zeros_like(ratio), mask - ) - mask = torch.where( - (ratio < 1.0 - self.clip_range_low) & (advantages < 0), torch.zeros_like(ratio), mask - ) + if self.enable_mask_clip: + mask = torch.where( + (ratio > 1.0 + self.mask_clip_range_high) & (advantages > 0), + torch.zeros_like(ratio), + mask, + ) + mask = torch.where( + (ratio < 1.0 - self.mask_clip_range_low) & (advantages < 0), + torch.zeros_like(ratio), + mask, + ) cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob @@ -73,4 +82,7 @@ def default_args(cls) -> Dict: return { "clip_range_low": 1.0, "clip_range_high": 0.28, + "enable_mask_clip": False, + "mask_clip_range_low": 1.0, + "mask_clip_range_high": 0.28, } From 5f2df786d13cca4094766e04da59a4298403772b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 1 Sep 2025 11:52:45 +0800 Subject: [PATCH 09/14] merge main again --- trinity/algorithm/algorithm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index ad8106aee9..6220496b1f 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -212,7 +212,7 @@ class TOPRAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -235,7 +235,7 @@ class CISPOAlgorithm(AlgorithmType): 
use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: From 2c20092aa7364461311440d31ad304fa4cd0694a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 1 Sep 2025 12:00:26 +0800 Subject: [PATCH 10/14] fix precommit --- trinity/algorithm/policy_loss_fn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 739cfe999a..f66d7dbde4 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -11,8 +11,8 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn -from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn __all__ = [ "POLICY_LOSS_FN", From 9340067ecdfe468ff62a5ef760e45839ed802f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 1 Sep 2025 23:34:24 +0800 Subject: [PATCH 11/14] add missing file --- .../advantage_fn/reinforce_advantage.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 trinity/algorithm/advantage_fn/reinforce_advantage.py diff --git a/trinity/algorithm/advantage_fn/reinforce_advantage.py b/trinity/algorithm/advantage_fn/reinforce_advantage.py new file mode 100644 index 0000000000..ffcf398187 --- /dev/null +++ b/trinity/algorithm/advantage_fn/reinforce_advantage.py @@ -0,0 +1,39 @@ +"""Reinforce advantage computation""" + +from typing import Dict, List, Tuple + +import torch + +from 
trinity.algorithm.advantage_fn.advantage_fn import ( + ADVANTAGE_FN, + GroupAdvantage, +) +from trinity.common.experience import Experience, group_by + + +@ADVANTAGE_FN.register_module("reinforce") +class REINFORCEGroupAdvantage(GroupAdvantage): + """Reinforce Group Advantage computation""" + + def group_experiences(self, exps): + return group_by(exps, id_type="task") + + def calculate_group_advantage( + self, group_id: str, exps: List[Experience] + ) -> Tuple[List[Experience], Dict]: + with torch.no_grad(): + rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) + group_reward_mean = torch.mean(rewards) + for exp in exps: + score = torch.tensor(exp.reward, dtype=torch.float32) + exp.advantages = score * exp.action_mask + exp.returns = exp.advantages.clone() + + metrics = { + "reward_mean": group_reward_mean.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> dict: + return {} From c07d914ac1011e4b1de19f23b82272ad8a08e52a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 1 Sep 2025 23:36:01 +0800 Subject: [PATCH 12/14] fix pre-commit --- trinity/algorithm/advantage_fn/reinforce_advantage.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/trinity/algorithm/advantage_fn/reinforce_advantage.py b/trinity/algorithm/advantage_fn/reinforce_advantage.py index ffcf398187..8c06451eda 100644 --- a/trinity/algorithm/advantage_fn/reinforce_advantage.py +++ b/trinity/algorithm/advantage_fn/reinforce_advantage.py @@ -4,10 +4,7 @@ import torch -from trinity.algorithm.advantage_fn.advantage_fn import ( - ADVANTAGE_FN, - GroupAdvantage, -) +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage from trinity.common.experience import Experience, group_by From ac117db255d7a8c16f3fa2802235ef8aa4da9afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Tue, 2 Sep 2025 11:41:59 +0800 Subject: [PATCH 13/14] update with main again --- 
examples/cispo_gsm8k/gsm8k.yaml | 1 - examples/sppo_gsm8k/gsm8k.yaml | 1 - examples/topr_gsm8k/gsm8k.yaml | 1 - 3 files changed, 3 deletions(-) diff --git a/examples/cispo_gsm8k/gsm8k.yaml b/examples/cispo_gsm8k/gsm8k.yaml index d7dd956c62..66a1777f8c 100644 --- a/examples/cispo_gsm8k/gsm8k.yaml +++ b/examples/cispo_gsm8k/gsm8k.yaml @@ -14,7 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 max_retry_interval: 1 explorer_input: taskset: diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml index d72459afe3..d3f645f606 100644 --- a/examples/sppo_gsm8k/gsm8k.yaml +++ b/examples/sppo_gsm8k/gsm8k.yaml @@ -20,7 +20,6 @@ cluster: buffer: total_steps: 100 batch_size: 96 - max_retry_times: 3 max_retry_interval: 1 explorer_input: taskset: diff --git a/examples/topr_gsm8k/gsm8k.yaml b/examples/topr_gsm8k/gsm8k.yaml index f7b1eb6761..0dbc9590b8 100644 --- a/examples/topr_gsm8k/gsm8k.yaml +++ b/examples/topr_gsm8k/gsm8k.yaml @@ -14,7 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 max_retry_interval: 1 explorer_input: taskset: From 9c52bf8b6b7d55bb2e33461e030e3234fd65195a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Tue, 2 Sep 2025 11:45:27 +0800 Subject: [PATCH 14/14] update with main again and again --- examples/cispo_gsm8k/gsm8k.yaml | 1 - examples/sppo_gsm8k/gsm8k.yaml | 1 - examples/topr_gsm8k/gsm8k.yaml | 1 - 3 files changed, 3 deletions(-) diff --git a/examples/cispo_gsm8k/gsm8k.yaml b/examples/cispo_gsm8k/gsm8k.yaml index 66a1777f8c..176c4a17bb 100644 --- a/examples/cispo_gsm8k/gsm8k.yaml +++ b/examples/cispo_gsm8k/gsm8k.yaml @@ -14,7 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml index d3f645f606..f6b08c8ca9 100644 --- a/examples/sppo_gsm8k/gsm8k.yaml +++ b/examples/sppo_gsm8k/gsm8k.yaml @@ -20,7 +20,6 @@ cluster: buffer: 
total_steps: 100 batch_size: 96 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/topr_gsm8k/gsm8k.yaml b/examples/topr_gsm8k/gsm8k.yaml index 0dbc9590b8..ed9081de1e 100644 --- a/examples/topr_gsm8k/gsm8k.yaml +++ b/examples/topr_gsm8k/gsm8k.yaml @@ -14,7 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k