diff --git a/examples/cispo_gsm8k/README.md b/examples/cispo_gsm8k/README.md new file mode 100644 index 0000000000..aa7f5e568b --- /dev/null +++ b/examples/cispo_gsm8k/README.md @@ -0,0 +1,5 @@ +# CISPO on GSM8K dataset + +This example shows the usage of [CISPO](https://arxiv.org/abs/2506.13585) on the GSM8K dataset. + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). diff --git a/examples/cispo_gsm8k/gsm8k.yaml b/examples/cispo_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..176c4a17bb --- /dev/null +++ b/examples/cispo_gsm8k/gsm8k.yaml @@ -0,0 +1,67 @@ +project: "Trinity-RFT-gsm8k" +name: "qwen2.5-1.5B-gsm8k-cispo" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: cispo + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 4 + sync_timeout: 1200 +trainer: + trainer_type: 
'verl' + trainer_config_path: 'examples/cispo_gsm8k/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/cispo_gsm8k/train_gsm8k.yaml b/examples/cispo_gsm8k/train_gsm8k.yaml new file mode 100644 index 0000000000..7af0c9e19f --- /dev/null +++ b/examples/cispo_gsm8k/train_gsm8k.yaml @@ -0,0 +1,49 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml index d72459afe3..f6b08c8ca9 100644 --- a/examples/sppo_gsm8k/gsm8k.yaml +++ b/examples/sppo_gsm8k/gsm8k.yaml @@ -20,8 +20,6 @@ cluster: buffer: total_steps: 100 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/topr_gsm8k/README.md b/examples/topr_gsm8k/README.md new file mode 100644 index 0000000000..b82ed2b0f6 --- /dev/null +++ b/examples/topr_gsm8k/README.md @@ -0,0 +1,5 @@ +# TOPR on GSM8K dataset + +This example shows the usage of [TOPR](https://arxiv.org/pdf/2503.14286v1) on the GSM8K dataset, with sync_interval=8. + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). diff --git a/examples/topr_gsm8k/gsm8k.yaml b/examples/topr_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..ed9081de1e --- /dev/null +++ b/examples/topr_gsm8k/gsm8k.yaml @@ -0,0 +1,67 @@ +project: "Trinity-RFT-gsm8k" +name: "qwen2.5-1.5B-gsm8k-topr" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: topr + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + 
name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 8 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/topr_gsm8k/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/topr_gsm8k/train_gsm8k.yaml b/examples/topr_gsm8k/train_gsm8k.yaml new file mode 100644 index 0000000000..7af0c9e19f --- /dev/null +++ b/examples/topr_gsm8k/train_gsm8k.yaml @@ -0,0 +1,49 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index f7e1932d98..52e316895a 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -16,6 +16,7 @@ OPMDGroupAdvantage, ) from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.advantage_fn.reinforce_advantage import REINFORCEGroupAdvantage from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( REINFORCEPLUSPLUSAdvantageFn, ) @@ -35,5 +36,6 @@ "RLOOAdvantageFn", "OPMDAdvantageFn", "OPMDGroupAdvantage", + "REINFORCEGroupAdvantage", "ASYMREAdvantageFn", ] diff --git a/trinity/algorithm/advantage_fn/reinforce_advantage.py b/trinity/algorithm/advantage_fn/reinforce_advantage.py new file mode 100644 index 0000000000..8c06451eda --- /dev/null +++ 
b/trinity/algorithm/advantage_fn/reinforce_advantage.py @@ -0,0 +1,36 @@ +"""Reinforce advantage computation""" + +from typing import Dict, List, Tuple + +import torch + +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage +from trinity.common.experience import Experience, group_by + + +@ADVANTAGE_FN.register_module("reinforce") +class REINFORCEGroupAdvantage(GroupAdvantage): + """Reinforce Group Advantage computation""" + + def group_experiences(self, exps): + return group_by(exps, id_type="task") + + def calculate_group_advantage( + self, group_id: str, exps: List[Experience] + ) -> Tuple[List[Experience], Dict]: + with torch.no_grad(): + rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) + group_reward_mean = torch.mean(rewards) + for exp in exps: + score = torch.tensor(exp.reward, dtype=torch.float32) + exp.advantages = score * exp.action_mask + exp.returns = exp.advantages.clone() + + metrics = { + "reward_mean": group_reward_mean.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> dict: + return {} diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index c987bcf608..b17ba7abec 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -204,6 +204,52 @@ def check_config(cls, config: Config) -> None: logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") +@ALGORITHM_TYPE.register_module("topr") +class TOPRAlgorithm(AlgorithmType): + """TOPR algorithm. 
See https://arxiv.org/pdf/2503.14286v1""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "advantage_fn": "reinforce", # or simply use grpo + "sample_strategy": "warmup", + "policy_loss_fn": "topr", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + +@ALGORITHM_TYPE.register_module("cispo") +class CISPOAlgorithm(AlgorithmType): + """CISPO algorithm. See https://arxiv.org/abs/2506.13585""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "advantage_fn": "grpo", + "sample_strategy": "warmup", + "policy_loss_fn": "cispo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): """MIX algorithm.""" diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index fd604f7f33..f66d7dbde4 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -3,6 +3,7 @@ SFTISLossFn, SFTPhiLossFn, ) +from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn @@ -11,6 +12,7 @@ from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn +from 
trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn __all__ = [ "POLICY_LOSS_FN", @@ -21,6 +23,8 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", + "TOPRPolicyLossFn", + "CISPOPolicyLossFn", "MIXCHORDPolicyLossFn", "SFTISLossFn", "SFTPhiLossFn", diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py new file mode 100644 index 0000000000..690147691b --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py @@ -0,0 +1,88 @@ +"""CISPO policy loss function. +Refer to https://arxiv.org/abs/2506.13585 for details. +""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("cispo") +class CISPOPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + clip_range_low: float = 1.0, + clip_range_high: float = 0.28, + enable_mask_clip: bool = False, + mask_clip_range_low: float = 1.0, + mask_clip_range_high: float = 0.28, + ) -> None: + super().__init__(backend=backend) + self.clip_range_low = clip_range_low + self.clip_range_high = clip_range_high + self.enable_mask_clip = enable_mask_clip + self.mask_clip_range_low = mask_clip_range_low + self.mask_clip_range_high = mask_clip_range_high + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + negative_approx_kl = logprob - old_logprob + ratio = torch.exp(negative_approx_kl) + ppo_kl = masked_mean(-negative_approx_kl, action_mask) + ratio_clamped = torch.clamp( + ratio, min=1.0 - self.clip_range_low, max=1.0 + self.clip_range_high + ) + + # mask = 0 if ratio > 1.0 + self.clip_range_high and advantages > 0 + # mask = 0 if ratio < 1.0 - self.clip_range_low and advantages < 0 + # else 1 + mask 
= torch.ones_like(ratio) + if self.enable_mask_clip: + mask = torch.where( + (ratio > 1.0 + self.mask_clip_range_high) & (advantages > 0), + torch.zeros_like(ratio), + mask, + ) + mask = torch.where( + (ratio < 1.0 - self.mask_clip_range_low) & (advantages < 0), + torch.zeros_like(ratio), + mask, + ) + + cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob + + loss = masked_mean(cispo_loss, action_mask) + masked_frac = masked_mean(mask, action_mask) + + metrics = { + "cispo_loss": loss.detach().item(), + "ppo_kl": ppo_kl.detach().item(), + "masked_frac": masked_frac.detach().item(), + } + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + """ + In the original paper: + we did not impose a lower bound on the IS weight by setting clip_range_low to a high value, instead, we only tuned clip_range_high + + """ + return { + "clip_range_low": 1.0, + "clip_range_high": 0.28, + "enable_mask_clip": False, + "mask_clip_range_low": 1.0, + "mask_clip_range_high": 0.28, + } diff --git a/trinity/algorithm/policy_loss_fn/topr_policy_loss.py b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py new file mode 100644 index 0000000000..b58618465d --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py @@ -0,0 +1,74 @@ +"""TOPR policy loss function. 
+Refer to https://arxiv.org/pdf/2503.14286v1 +""" +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("topr") +class TOPRPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + advantage_threshold: float = 0.0, + ) -> None: + super().__init__(backend=backend) + self.advantage_threshold = advantage_threshold + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, # In TOPR, this is actually the rewards R(x,y) + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """ + Compute TOPR policy loss. + + In TOPR: + - α = [π(y|x)/μ(y|x)]_0^1 if R(x,y) <= threshold else 1 + - loss = -sg(α) * r(x,y) * log π(y|x) + """ + # in the original TOPR paper, advantages are simply rewards + # However, we can use advantages as rewards (Baseline Trick) + rewards = advantages + + # Compute ratio π(y|x) / μ(y|x) in log space for numerical stability + log_ratio = logprob - old_logprob + ratio = torch.exp(log_ratio) + ratio_clamped = torch.clamp(ratio, min=0.0, max=1.0) + + # Apply TOPR's conditional weighting: + # α = ratio clamp min=0 max=1 if R(x,y) <= threshold else 1 + alpha = torch.where( + rewards <= self.advantage_threshold, ratio_clamped, torch.ones_like(ratio) + ) + + # TOPR loss: l = -α * r(x,y) * log π(y|x) + # We want to maximize α * r(x,y) * log π(y|x), so minimize the negative + topr_loss = -alpha.detach() * rewards * logprob # detach alpha as it's used with stop-grad + + # Apply masking and compute mean + loss = masked_mean(topr_loss, action_mask) + + # Average alpha value for monitoring + avg_alpha = masked_mean(alpha, action_mask) + + metrics = { + "topr_loss": loss.detach().item(), + "avg_alpha": avg_alpha.detach().item(), + "avg_ratio": masked_mean(ratio, action_mask).detach().item(), + } + + 
return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "advantage_threshold": 0.0, + }