Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/cispo_gsm8k/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# CISPO on GSM8K dataset

This example shows the usage of [CISPO](https://arxiv.org/abs/2506.13585) on the GSM8K dataset.

The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
67 changes: 67 additions & 0 deletions examples/cispo_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
project: "Trinity-RFT-gsm8k"
name: "qwen2.5-1.5B-gsm8k-cispo"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: cispo
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
max_response_tokens: 1024
max_model_len: 1280
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
# sft_warmup_steps: 0
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
# path: '/PATH/TO/WARMUP_DATA/'
explorer:
eval_interval: 50
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 2
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 4
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/cispo_gsm8k/train_gsm8k.yaml'
save_interval: 100
49 changes: 49 additions & 0 deletions examples/cispo_gsm8k/train_gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True # False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program at runtime
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 4
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
# total_training_steps: null
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # one of: auto, disable, resume_path
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
2 changes: 0 additions & 2 deletions examples/sppo_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ cluster:
buffer:
total_steps: 100
batch_size: 96
max_retry_times: 3
max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
Expand Down
5 changes: 5 additions & 0 deletions examples/topr_gsm8k/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TOPR on GSM8K dataset

This example shows the usage of [TOPR](https://arxiv.org/pdf/2503.14286v1) on the GSM8K dataset, with sync_interval=8.

The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
67 changes: 67 additions & 0 deletions examples/topr_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
project: "Trinity-RFT-gsm8k"
name: "qwen2.5-1.5B-gsm8k-topr"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: topr
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
max_response_tokens: 1024
max_model_len: 1280
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
# sft_warmup_steps: 0
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
# path: '/PATH/TO/WARMUP_DATA/'
explorer:
eval_interval: 50
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 2
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 8
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/topr_gsm8k/train_gsm8k.yaml'
save_interval: 100
49 changes: 49 additions & 0 deletions examples/topr_gsm8k/train_gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True # False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program at runtime
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 4
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
# total_training_steps: null
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # one of: auto, disable, resume_path
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
2 changes: 2 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
OPMDGroupAdvantage,
)
from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
from trinity.algorithm.advantage_fn.reinforce_advantage import REINFORCEGroupAdvantage
from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
REINFORCEPLUSPLUSAdvantageFn,
)
Expand All @@ -35,5 +36,6 @@
"RLOOAdvantageFn",
"OPMDAdvantageFn",
"OPMDGroupAdvantage",
"REINFORCEGroupAdvantage",
"ASYMREAdvantageFn",
]
36 changes: 36 additions & 0 deletions trinity/algorithm/advantage_fn/reinforce_advantage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Reinforce advantage computation"""

from typing import Dict, List, Tuple

import torch

from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage
from trinity.common.experience import Experience, group_by


@ADVANTAGE_FN.register_module("reinforce")
class REINFORCEGroupAdvantage(GroupAdvantage):
    """Vanilla REINFORCE advantage.

    Each token's advantage is the raw episode reward — no baseline is
    subtracted. The group's mean reward is reported as a metric only.
    """

    def group_experiences(self, exps):
        # Group rollouts belonging to the same task so that metrics are per-task.
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Fill ``advantages`` and ``returns`` on each experience in one group.

        Args:
            group_id: identifier of the task group (not used in the computation).
            exps: experiences sharing the same task id.

        Returns:
            The experiences (mutated in place) and a metrics dict holding the
            group's mean reward.
        """
        with torch.no_grad():
            reward_vec = torch.as_tensor(
                [exp.reward for exp in exps], dtype=torch.float32
            )
            for exp in exps:
                scalar = torch.as_tensor(exp.reward, dtype=torch.float32)
                # Broadcast the scalar reward over the action mask so only
                # action tokens carry a non-zero advantage.
                exp.advantages = scalar * exp.action_mask
                exp.returns = exp.advantages.clone()
            metrics = {"reward_mean": reward_vec.mean().item()}
        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        """This advantage function takes no configurable arguments."""
        return {}
46 changes: 46 additions & 0 deletions trinity/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,52 @@ def check_config(cls, config: Config) -> None:
logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`")


@ALGORITHM_TYPE.register_module("topr")
class TOPRAlgorithm(AlgorithmType):
    """TOPR (Tapered Off-Policy REINFORCE) algorithm.

    Reference: https://arxiv.org/pdf/2503.14286v1
    """

    # Critic-free; keeps a reference policy, and advantages are computed
    # outside the trainer.
    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        """Return the default algorithm configuration for TOPR."""
        config: Dict = dict(
            repeat_times=2,
            advantage_fn="reinforce",  # or simply use grpo
            sample_strategy="warmup",
            policy_loss_fn="topr",
            kl_penalty_fn="none",
            kl_loss_fn="k2",
            entropy_loss_fn="default",
        )
        return config


@ALGORITHM_TYPE.register_module("cispo")
class CISPOAlgorithm(AlgorithmType):
    """CISPO (Clipped IS-weight Policy Optimization) algorithm.

    Reference: https://arxiv.org/abs/2506.13585
    """

    # Critic-free; keeps a reference policy, and advantages are computed
    # outside the trainer.
    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        """Return the default algorithm configuration for CISPO."""
        config: Dict = dict(
            repeat_times=2,
            advantage_fn="grpo",
            sample_strategy="warmup",
            policy_loss_fn="cispo",
            kl_penalty_fn="none",
            kl_loss_fn="k2",
            entropy_loss_fn="default",
        )
        return config


@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
"""MIX algorithm."""
Expand Down
4 changes: 4 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
SFTISLossFn,
SFTPhiLossFn,
)
from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
Expand All @@ -11,6 +12,7 @@
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn

__all__ = [
"POLICY_LOSS_FN",
Expand All @@ -21,6 +23,8 @@
"SFTLossFn",
"MIXPolicyLossFn",
"GSPOLossFn",
"TOPRPolicyLossFn",
"CISPOPolicyLossFn",
"MIXCHORDPolicyLossFn",
"SFTISLossFn",
"SFTPhiLossFn",
Expand Down
Loading