From 041433056e9067e35a42a3943ecbaf89bef24518 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=97=AE=E6=98=8A?=
Date: Mon, 18 Aug 2025 11:01:49 +0800
Subject: [PATCH 1/6] feat: migrate CHORD to new version with phi function support

- Migrate CHORD from legacy code version to new code version
- Enable phi function
- Merge latest changes from main
---
 examples/mix_chord/README.md                  |  89 ++++++
 examples/mix_chord/get_openr1_data.py         | 205 ++++++++++++++
 examples/mix_chord/mix_chord.yaml             |  86 ++++++
 examples/mix_chord/train_mix_chord.yaml       |  48 ++++
 trinity/algorithm/algorithm.py                |  24 ++
 .../entropy_loss_fn/entropy_loss_fn.py        |  27 ++
 trinity/algorithm/policy_loss_fn/__init__.py  |   8 +
 .../policy_loss_fn/chord_policy_loss.py       | 258 ++++++++++++++++++
 .../sample_strategy/mix_sample_strategy.py    |  13 +-
 trinity/trainer/verl/dp_actor.py              |   1 +
 trinity/trainer/verl_trainer.py               |   2 +-
 11 files changed, 759 insertions(+), 2 deletions(-)
 create mode 100644 examples/mix_chord/README.md
 create mode 100644 examples/mix_chord/get_openr1_data.py
 create mode 100644 examples/mix_chord/mix_chord.yaml
 create mode 100644 examples/mix_chord/train_mix_chord.yaml
 create mode 100644 trinity/algorithm/policy_loss_fn/chord_policy_loss.py

diff --git a/examples/mix_chord/README.md b/examples/mix_chord/README.md
new file mode 100644
index 0000000000..9a74cb079e
--- /dev/null
+++ b/examples/mix_chord/README.md
@@ -0,0 +1,89 @@
+# Example: CHORD Algorithm
+
+Below we show an example of implementing the [CHORD](https://arxiv.org/pdf/2508.11408) algorithm.
+
+Here we provide a basic runnable example demonstrating the core functionality of CHORD. The hyperparameters used in our experiments may not be optimal across different datasets; we encourage researchers to build upon this implementation and explore further improvements.
+
+If you are interested in implementing your own algorithm, you may refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) for guidance.
+
+## How to run
+
+### Install Trinity-RFT
+
+First, you should install Trinity-RFT.
+
+Please follow the guide in [README.md](../../README.md) to install the dependencies and set up the environment.
+
+### Prepare the models and datasets
+
+Then you should prepare the models and datasets, and fill in their paths in the configuration files.
+
+You should first download the model you want from Hugging Face or ModelScope, for example:
+```bash
+# Using Hugging Face
+huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen/Qwen2.5-1.5B-Instruct
+
+# Using ModelScope
+modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
+```
+
+For the dataset, you need to prepare both the SFT dataset and the RL dataset. Below we provide a script for processing the dataset into our required format.
+
+Before running the dataset processing script, you need to fill in the tokenizer path in the script; the tokenizer is used to filter out SFT data that is too long.
+You can also change the sample sizes if you want.
+```python
+TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH"
+MAX_TOKEN_LENGTH = 8196
+SFT_SAMPLE_SIZE = 5000
+RL_SAMPLE_SIZE = 20000
+```
+
+Then just run the script:
+```bash
+python examples/mix_chord/get_openr1_data.py
+```
+This may take a while to run.
+
+> **Note**: Here we provide scripts for sampling SFT and RL data from the OpenR1 dataset. However, since our original experiments did not use a fixed random seed, the data selection and ordering may differ from those used in the paper.
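+
+### Understand the CHORD loss weighting (optional)
+
+The `policy_loss_fn_args` in [`mix_chord.yaml`](mix_chord.yaml) control how the SFT loss on expert data is blended with the GRPO loss on rollout data: the final loss is `(1 - mu) * grpo_loss + mu * sft_loss`, where the global weight `mu` follows a warmup-then-cosine-decay schedule (`mu_warmup_steps`, `mu_decay_steps`, `mu_peak`, `mu_valley`), and setting `enable_phi_function: true` additionally reweights each expert token by `p * (1 - p)`. The snippet below is a minimal standalone sketch that mirrors `mu_schedule_function` and `phi_function` in `trinity/algorithm/policy_loss_fn/chord_policy_loss.py`, so you can preview the schedule for your own settings; it is illustrative only and not part of the training code path.
+
+```python
+import math
+
+
+def mu_schedule(step: int, warmup: int, decay: int, peak: float, valley: float) -> float:
+    """Linear warmup to `peak`, then cosine decay to `valley` (mirrors mu_schedule_function)."""
+    if step < warmup:
+        return step / warmup * peak
+    if step >= warmup + decay:
+        return valley
+    cosine = 0.5 * (1 + math.cos(math.pi * (step - warmup) / decay))
+    return (peak - valley) * cosine + valley
+
+
+def phi(p: float) -> float:
+    """Token weight applied to expert data when `enable_phi_function` is true; peaks at p = 0.5."""
+    return p * (1 - p)
+
+
+# Preview the schedule for the values used in mix_chord.yaml
+# (mu_warmup_steps=200, mu_decay_steps=400, mu_peak=0.5, mu_valley=0.02).
+for step in (0, 100, 200, 400, 600, 1000):
+    print(step, round(mu_schedule(step, 200, 400, 0.5, 0.02), 3))
+```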
+ +### Modify the running script + +Fill in the config files in [`mix_chord.yaml`](mix_chord.yaml) and [`train_mix_chord.yaml`](train_mix_chord.yaml). + +### Run the script + +```bash +# Stop existing ray processes +ray stop + +# Start ray +ray start --head + +# Run Trinity +trinity run --config examples/mix_chord/mix_chord.yaml +``` + +## Citation + +If you find this code useful, please consider citing our paper: +```bibtex +@misc{TrinityRFT, + title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models}, + author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2505.17826}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.17826}, +} + +@misc{MIXCHORD, + title={On-Policy RL Meets Off-Policy Experts: Harmonizing Supervised Fine-Tuning and Reinforcement Learning via Dynamic Weighting}, + author={Wenhao Zhang and Yuexiang Xie and Yuchang Sun and Yanxi Chen and Guoyin Wang and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2508.11408}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2508.11408}, +} +``` diff --git a/examples/mix_chord/get_openr1_data.py b/examples/mix_chord/get_openr1_data.py new file mode 100644 index 0000000000..71d188066c --- /dev/null +++ b/examples/mix_chord/get_openr1_data.py @@ -0,0 +1,205 @@ +""" +We provide scripts for generating the SFT and RL dataset. +""" + +import json +import os +import random + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +# import re + +# Set random seed for reproducibility +random.seed(42) + +# Configuration parameters +TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH" +MAX_TOKEN_LENGTH = 8196 +SFT_SAMPLE_SIZE = 5000 +RL_SAMPLE_SIZE = 20000 +SYSTEM_PROMPT = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: \n ...your reasoning process here... \n first. You should always include your final answer in \\boxed{} as closed-form results.""" + + +def can_convert_to_int(answer): + """Check if answer can be directly converted to integer""" + try: + int(answer) + return True + except (ValueError, TypeError): + return False + + +def contains_chinese(text): + """There are many incorrect translated problems in OpenR1. 
We may want to filter them out.""" + for char in text: + if "\u4e00" <= char <= "\u9fff": + return True + if "\u3400" <= char <= "\u4dbf": + return True + if "\u20000" <= char <= "\u2a6df": + return True + return False + + +def process_dataset(openr1_ds, tokenizer): + """Process dataset and filter out instances that don't meet criteria""" + processed_data = [] + + for instance in tqdm(openr1_ds, desc="Processing dataset"): + # Filter out answers that cannot be directly converted to int + # if not can_convert_to_int(instance["answer"]): + # continue + + if contains_chinese(instance["problem"]): + continue + + # Process generations + generations_keeped = [] + correctness_list = instance.get("correctness_math_verify", []) + generations = instance.get("generations", []) + + for i, generation in enumerate(generations): + # Check correctness_math_verify + if i >= len(correctness_list) or not correctness_list[i]: + continue + + # Check token length + tokenized_length = len(tokenizer.tokenize(generation)) + if tokenized_length > MAX_TOKEN_LENGTH: + continue + + generations_keeped.append(generation) + + # Add to processed data if there are kept generations + if generations_keeped: + processed_data.append( + { + "problem": instance["problem"], + "answer": instance["answer"], + "generations": generations_keeped, + } + ) + + return processed_data + + +def create_sft_dataset(data, sample_size): + """Create SFT dataset with message format""" + sft_messages = [] + + # Random sample specified number of instances + sampled_data = random.sample(data, min(sample_size, len(data))) + + for instance in sampled_data: + # Randomly select one generation as response + generation = random.choice(instance["generations"]) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": instance["problem"]}, + {"role": "assistant", "content": generation}, + ] + sft_messages.append({"messages": messages}) + + return sft_messages, sampled_data + + +def create_rl_dataset(data, sample_size, output_dir): + """Create RL dataset in HuggingFace format""" + + filtered_data = [d for d in data if can_convert_to_int(d["answer"])] + print("Number of instances can convert to int: ", len(filtered_data)) + + # Filter instances with at least 2 generations + filtered_data = [d for d in filtered_data if len(d["generations"]) >= 2] + print(f"Number of instances with >= 2 generations: {len(filtered_data)}") + + # or No filter + # filtered_data = data + + # Random sample + sampled_data = random.sample(filtered_data, min(sample_size, len(filtered_data))) + + # Prepare data in HuggingFace format (only problem and answer) + rl_data = [] + for instance in sampled_data: + rl_data.append({"problem": instance["problem"], "answer": instance["answer"]}) + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Save as JSONL format for HuggingFace datasets + train_file = os.path.join(output_dir, "train.jsonl") + with open(train_file, "w", encoding="utf-8") as f: + for item in rl_data: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + # Create dataset_dict.json for HuggingFace format + dataset_info = { + "citation": "", + "description": "OpenR1 RLVR dataset subset", + "splits": {"train": {"name": "train", "num_examples": len(rl_data)}}, + } + + with open(os.path.join(output_dir, "dataset_dict.json"), "w", encoding="utf-8") as f: + json.dump(dataset_info, f, indent=2) + + print(f"Saved RL dataset to {output_dir}") + print(f"Total instances: {len(rl_data)}") + + return 
sampled_data # Return sampled data with generations for reference + + +def save_json(data, filename): + """Save data to JSON file""" + with open(filename, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + print(f"Saved {len(data)} instances to {filename}") + + +def main(): + # Load dataset from HuggingFace + print("Loading dataset from HuggingFace...") + openr1_ds = load_dataset("open-r1/OpenR1-Math-220k", "default", split="train").to_list() + + print(f"Original dataset size: {len(openr1_ds)}") + + # Load tokenizer + print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL_PATH, use_fast=False) + + # Process dataset with filtering + print("Processing dataset with filters...") + processed_data = process_dataset(openr1_ds, tokenizer) + print(f"Processed dataset size: {len(processed_data)}") + + # Create SFT dataset + print(f"\nCreating SFT dataset (sampling {SFT_SAMPLE_SIZE} instances)...") + sft_dataset, sampled_for_sft = create_sft_dataset(processed_data, SFT_SAMPLE_SIZE) + save_json(sft_dataset, "openr1_sft_dataset.json") + + # Create RL dataset from remaining data + print("\nCreating RL dataset...") + # Remove instances already used for SFT + remaining_data = [d for d in processed_data if d not in sampled_for_sft] + print(f"Remaining data after SFT sampling: {len(remaining_data)}") + + # Create RL dataset in HuggingFace format + rl_output_dir = "openr1_rl_dataset" + sampled_rl_data = create_rl_dataset(remaining_data, RL_SAMPLE_SIZE, rl_output_dir) + + # Optionally save the full RL data with generations for reference + # save_json(sampled_rl_data, "openr1_rl_dataset_with_generations.json") + + print("\n" + "=" * 50) + print("Dataset generation completed!") + print(f"SFT dataset: {len(sft_dataset)} instances") + print(f"RL dataset: {len(sampled_rl_data)} instances") + print("=" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml new file mode 100644 index 0000000000..1d81c97fa8 --- /dev/null +++ b/examples/mix_chord/mix_chord.yaml @@ -0,0 +1,86 @@ +project: "mix_chord" +name: "test_mix_chord" +checkpoint_root_dir: /PATH/TO/checkpoints/ +algorithm: + algorithm_type: mix_chord + repeat_times: 8 + kl_loss_fn_args: + kl_coef: 0.0 + sample_strategy_args: + expert_data_ratio: 0.20 + policy_loss_fn_args: # feel free to change, we encourage you to try out different hyperparameters + mu_warmup_steps: 200 # 0 for chord-mu and chord-phi + mu_decay_steps: 400 # 200 for chord-mu and 0 for chord-phi + mu_peak: 0.5 # 0.9 for chord-mu and 0.1 for chord-phi + mu_valley: 0.02 # 0.05 for chord-mu and 0.1 for chord-phi + enable_phi_function: true # false for chord-mu and true for chord-phi + clip_range: 0.2 + use_token_level_loss_in_sft: true + use_dynamic_bsz: true + ppo_mini_batch_size: 320 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + train_batch_size_expert: 64 + train_batch_size_usual: 256 +model: + model_path: /cpfs/data/shared/qwen/Qwen2.5-7B-Instruct + max_response_tokens: 10240 + max_model_len: 11264 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 25 + batch_size: 40 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: openr1_data_filtered_int + storage_type: file + path: /your_path/to/rl_dataset + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + workflow_args: + with_think: true + eval_tasksets: [] # you can add your own eval tasksets here + 
default_workflow_type: 'math_boxed_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: 'sqlite:///test_mix_chord.db' + sft_warmup_dataset: + name: SFT_data + storage_type: file + algorithm_type: sft + path: /your_path/to/sft_dataset + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 16 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/mix_chord/train_chord.yaml' + save_interval: 50 +monitor: + monitor_type: wandb diff --git a/examples/mix_chord/train_mix_chord.yaml b/examples/mix_chord/train_mix_chord.yaml new file mode 100644 index 0000000000..0e853f40d7 --- /dev/null +++ b/examples/mix_chord/train_mix_chord.yaml @@ -0,0 +1,48 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 25600 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 2 # sp size + optim: + lr: 1e-6 # or 5e-6, larger lr with warm up can result in better performance for SFT training. + lr_warmup_steps_ratio: 0. # in experimence lr warmup is helpful for chord-mu + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index e652ac8f4e..eb45e839a3 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -195,6 +195,30 @@ def default_config(cls) -> Dict: "policy_loss_fn": "mix", "advantage_fn": "grpo", "sample_strategy": "mix", + "entropy_loss_fn": "mix", + } + + +@ALGORITHM_TYPE.register_module("mix_chord") +class MIXCHORDAlgorithm(AlgorithmType): + """MIX algorithm.""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 8, + "add_strategy": "grpo", + "policy_loss_fn": "mix_chord", + "advantage_fn": "grpo", + "sample_strategy": "mix", + "entropy_loss_fn": "mix", } diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index d6179a832c..7a81fe96bf 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -59,6 +59,33 @@ def __call__( return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} +@ENTROPY_LOSS_FN.register_module("mix") +class MixEntropyLossFn(EntropyLossFn): + """ + Basic entropy loss function for mix algorithm. + """ + + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef + + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + expert_mask: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + if expert_mask is None: + raise ValueError("expert_mask is required for MixEntropyLossFn") + assert ( + len(expert_mask) == entropy.shape[0] + ), f"Error: {len(expert_mask)=} != {entropy.shape[0]=}" + entropy = entropy[~expert_mask] + action_mask = action_mask[~expert_mask] + entropy_loss = masked_mean(entropy, action_mask) + return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} + + @ENTROPY_LOSS_FN.register_module("none") class DummyEntropyLossFn(EntropyLossFn): """ diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 4d88215a0a..7416e94d5b 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,3 +1,8 @@ +from trinity.algorithm.policy_loss_fn.chord_policy_loss import ( + MIXCHORDPolicyLossFn, + SFTISLossFn, + SFTPhiLossFn, +) from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn @@ -15,4 +20,7 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", + "MIXCHORDPolicyLossFn", + "SFTISLossFn", + "SFTPhiLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py new file mode 100644 index 0000000000..b05d971bbd --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py @@ -0,0 +1,258 @@ +"""Implements the CHORD policy loss function.""" + +import math +from typing import Dict, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import 
POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn +from trinity.algorithm.utils import masked_mean + + +def mu_schedule_function( + global_step: int, mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float +): + """ + Cosine decay with warmup phase + """ + # Warmup + if global_step < mu_warmup_steps: + return (global_step / mu_warmup_steps) * mu_peak + + # Decay + if global_step >= (mu_warmup_steps + mu_decay_steps): + return mu_valley + + adjusted_step = global_step - mu_warmup_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * adjusted_step / mu_decay_steps)) + decayed_mu = (mu_peak - mu_valley) * cosine_decay + mu_valley + return decayed_mu + + +@POLICY_LOSS_FN.register_module("sft_is") +class SFTISLossFn(PolicyLossFn): + """ + SFT loss with importance sampling + """ + + def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: + super().__init__(backend=backend) + self.use_token_level_loss = use_token_level_loss + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + token_prob = torch.exp(logprob) + if self.use_token_level_loss: + sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask) + else: + sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask, axis=1).mean() + return sft_loss, {"sft_is_loss": sft_loss.detach().item()} + + @classmethod + def default_args(cls): + return { + "use_token_level_loss": True, + } + + +def phi_function(token_prob): + """ + The phi function downweights token with extreme probability. + Feel free to modify this function. + """ + return token_prob * (1 - token_prob) + + +@POLICY_LOSS_FN.register_module("sft_phi") +class SFTPhiLossFn(PolicyLossFn): + """ + SFT loss with transformed phi function + """ + + def __init__( + self, backend: str = "verl", use_token_level_loss: bool = True, cutoff_prob: float = 1.0 + ) -> None: + super().__init__(backend=backend) + self.use_token_level_loss = use_token_level_loss + self.cutoff_prob = cutoff_prob + assert 0.0 <= self.cutoff_prob <= 1.0 + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + token_prob = torch.exp(logprob) + if self.cutoff_prob < 1.0: + logprob = torch.clamp(logprob, max=math.log(self.cutoff_prob)) + + weighted_phi = phi_function(token_prob) + + if self.use_token_level_loss: + sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask) + else: + sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask, axis=1).mean() + return sft_loss, {"sft_phi_loss": sft_loss.detach().item()} + + @classmethod + def default_args(cls): + return { + "use_token_level_loss": True, + "cutoff_prob": 1.0, + } + + +@POLICY_LOSS_FN.register_module("mix_chord") +class MIXCHORDPolicyLossFn(PolicyLossFn): + """Implements a mixed policy loss combining GRPO and SFT losses. + + This loss function applies different loss components to data based on whether + it comes from an expert or not, as indicated by `expert_mask`. 
It combines: + - GRPO loss (self.grpo_loss_fn) for non-expert data + - SFT loss (self.sft_loss_fn) for expert data + the weight of SFT loss is globally controled by `mu_schedule` function + the tokenwise weights are calculated using different SFT loss formulas + + The per-sample weights are normalized using either `experience_per_gpu` or + `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, + to ensure consistent weighting across different batches of the same type experiences. + """ + + def __init__( + self, + backend: str = "verl", + mu_warmup_steps: int = 0, + mu_decay_steps: int = 0, + mu_peak: float = 0.1, + mu_valley: float = 0.1, + enable_phi_function: bool = True, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[bool] = None, + ppo_mini_batch_size: int = 1, + ppo_micro_batch_size_per_gpu: int = 1, + ngpus_trainer: int = 1, + train_batch_size_usual: int = 1, + train_batch_size_expert: int = 1, + use_token_level_loss_in_sft: bool = True, + ) -> None: + super().__init__(backend=backend) + self.mu_warmup_steps = mu_warmup_steps + self.mu_decay_steps = mu_decay_steps + self.mu_peak = mu_peak + self.mu_valley = mu_valley + self.enable_phi_function = enable_phi_function + self.use_dynamic_bsz = use_dynamic_bsz + self.experience_per_gpu = ppo_mini_batch_size // ngpus_trainer + self.gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu + self.train_batch_size_usual = train_batch_size_usual // ngpus_trainer + self.train_batch_size_expert = train_batch_size_expert // ngpus_trainer + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + if enable_phi_function: + self.sft_loss_fn = SFTPhiLossFn(use_token_level_loss=use_token_level_loss_in_sft) + else: + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + expert_mask: torch.Tensor, + step: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + assert ( + len(expert_mask) == logprob.shape[0] + ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}" + + assert len(step) == logprob.shape[0], f"Error: {len(step)=} != {logprob.shape[0]=}" + + assert ( + step.max().item() == step.min().item() + ), f"Error: {step.max().item()} != {step.min().item()}" + current_step = step.max().item() + + n_usual_exp = torch.sum(~expert_mask).item() + n_expert_exp = torch.sum(expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.experience_per_gpu / ( + logprob.shape[0] * self.train_batch_size_usual + ) + per_micro_batch_weight_expert = self.experience_per_gpu / ( + logprob.shape[0] * self.train_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert # type: ignore + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~expert_mask], + old_logprob[~expert_mask], + action_mask[~expert_mask], + advantages[~expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } + else: + grpo_loss = 
torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[expert_mask], + action_mask[expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + mu = mu_schedule_function( + current_step, self.mu_warmup_steps, self.mu_decay_steps, self.mu_peak, self.mu_valley + ) + + loss = (1 - mu) * grpo_loss + mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) + metrics.update({"mu": mu}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + """ + mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float + """ + return { + "mu_warmup_steps": 0, + "mu_decay_steps": 0, + "mu_peak": 0.1, + "mu_valley": 0.1, + "clip_range": 0.2, + "enable_phi_function": True, + } diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 32a34834bf..a66ca32bf7 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -52,16 +52,22 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: if exp.info is None: exp.info = {} exp.info["is_expert"] = False + exp.info["step"] = step expert_exp_list = await self.expert_exp_buffer.read_async() for exp in expert_exp_list: + # we add fake rewards and logprobs to make it compatible exp.reward = 0.0 exp.logprobs = torch.zeros_like( exp.tokens[exp.prompt_length :], dtype=torch.float32 ) + exp.advantages = torch.zeros_like( + exp.tokens[exp.prompt_length :], dtype=torch.float32 + ) if exp.info is None: exp.info = {} exp.info["is_expert"] = True + exp.info["step"] = step exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) @@ -75,7 +81,12 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: source_field="is_expert", destination_field="expert_mask", data_type=torch.bool, - ) + ), + CustomField( + source_field="step", + destination_field="step", + data_type=torch.int32, + ), ], ) # type: ignore return exps, metrics, repr_samples diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index c1ead6cbf6..e6d5a91458 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -196,6 +196,7 @@ def update_policy(self, data: DataProto): # noqa: C901 entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore entropy=entropy, action_mask=response_mask, + **data, ) prefix_metrics( src_metrics=entropy_loss_metrics, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 59453592d2..e1707629a8 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -375,7 +375,7 @@ def save_checkpoint(self, block_until_saved: bool = False) -> None: self._save_checkpoint() if block_until_saved: self.actor_rollout_wg.wait_on_save_thread() - if self.algorithm.use_critic: + if self.algorithm and self.algorithm.use_critic: self.critic_wg.wait_on_save_thread() def sync_weight(self) -> None: From 83e5c400e991bf59422b41e77b9c00d638344d19 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 18 Aug 2025 17:55:46 +0800 Subject: [PATCH 2/6] fix old mix config bug --- examples/mix_math/mix_math.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index e4c4dd9172..f86c09527d 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -9,8 +9,8 @@ algorithm: policy_loss_fn_args: mu: 0.1 clip_range: 0.2 - use_token_level_loss_in_sft: False - use_dynamic_bsz: False + use_token_level_loss_in_sft: true + use_dynamic_bsz: true repeat_times: 8 ppo_mini_batch_size: 256 ppo_micro_batch_size_per_gpu: 4 From 382c38edee3207850e6488adcd188a927f4b017e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 18 Aug 2025 17:57:16 +0800 Subject: [PATCH 3/6] old mix algorithm config bug fix --- examples/mix_math/mix_math.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index f86c09527d..4dc588cc51 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -1,22 +1,22 @@ project: "mix_math" -name: "expert0.25_mu0.1" +name: "expert0.20_mu0.1" checkpoint_root_dir: /PATH/TO/CHECKPOINT/ algorithm: algorithm_type: mix repeat_times: 8 sample_strategy_args: - expert_data_ratio: 0.25 + expert_data_ratio: 0.20 policy_loss_fn_args: mu: 0.1 clip_range: 0.2 use_token_level_loss_in_sft: true use_dynamic_bsz: true repeat_times: 8 - ppo_mini_batch_size: 256 + ppo_mini_batch_size: 320 ppo_micro_batch_size_per_gpu: 4 ngpus_trainer: 4 train_batch_size_expert: 64 - train_batch_size_usual: 192 + train_batch_size_usual: 256 model: model_path: /PATH/TO/MODEL/ max_response_tokens: 10240 From 68b7d319177a2df01d3f5acae4a611a00ddf7e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 18 Aug 2025 18:13:14 +0800 Subject: [PATCH 4/6] fix config --- examples/mix_chord/mix_chord.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml index 1d81c97fa8..b5e994dd95 100644 --- a/examples/mix_chord/mix_chord.yaml +++ b/examples/mix_chord/mix_chord.yaml @@ -1,6 +1,6 @@ project: "mix_chord" name: "test_mix_chord" -checkpoint_root_dir: /PATH/TO/checkpoints/ +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ algorithm: algorithm_type: mix_chord repeat_times: 8 @@ -23,7 +23,7 @@ algorithm: train_batch_size_expert: 64 train_batch_size_usual: 256 model: - model_path: /cpfs/data/shared/qwen/Qwen2.5-7B-Instruct + model_path: /PATH/TO/MODEL/ max_response_tokens: 10240 max_model_len: 11264 cluster: @@ -38,7 +38,7 @@ buffer: taskset: name: openr1_data_filtered_int storage_type: file - path: /your_path/to/rl_dataset + path: /PATH/TO/RL_DATASET format: prompt_key: 'problem' response_key: 'answer' @@ -58,7 +58,7 @@ buffer: name: SFT_data storage_type: file algorithm_type: sft - path: /your_path/to/sft_dataset + path: /PATH/TO/SFT_DATASET split: 'train' format: prompt_type: messages From b32a2206918f973f71ae886917179f1747484b3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 18 Aug 2025 18:20:13 +0800 Subject: [PATCH 5/6] resolve comments --- trinity/algorithm/policy_loss_fn/chord_policy_loss.py | 7 +++---- trinity/algorithm/policy_loss_fn/mix_policy_loss.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py 
b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py index b05d971bbd..dc1a5504bc 100644 --- a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py @@ -13,9 +13,9 @@ def mu_schedule_function( global_step: int, mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float -): +) -> float: """ - Cosine decay with warmup phase + Computes a cosine decay schedule with a warmup phase for the mu parameter. """ # Warmup if global_step < mu_warmup_steps: @@ -238,8 +238,7 @@ def __call__( # type: ignore metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) - metrics.update({"loss": loss.item()}) - metrics.update({"mu": mu}) + metrics.update({"loss": loss.item(), "mu": mu}) return loss, metrics diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index b8c6a54a50..2f70b788f4 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -114,7 +114,7 @@ def __call__( # type: ignore metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) - metrics.update({"loss": loss.item()}) + metrics["loss"] = loss.item() return loss, metrics From e191b119cdb51b2b09bd550988a1ec4947ea64ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 18 Aug 2025 19:46:51 +0800 Subject: [PATCH 6/6] fix more config promblems --- examples/mix_chord/mix_chord.yaml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml index b5e994dd95..caa44c573b 100644 --- a/examples/mix_chord/mix_chord.yaml +++ b/examples/mix_chord/mix_chord.yaml @@ -3,7 +3,7 @@ name: "test_mix_chord" checkpoint_root_dir: /PATH/TO/CHECKPOINT/ algorithm: algorithm_type: mix_chord - repeat_times: 8 + repeat_times: 8 # or 16 for better performance in math related tasks kl_loss_fn_args: kl_coef: 0.0 sample_strategy_args: @@ -17,11 +17,11 @@ algorithm: clip_range: 0.2 use_token_level_loss_in_sft: true use_dynamic_bsz: true - ppo_mini_batch_size: 320 + ppo_mini_batch_size: 320 # 320 = 256 + 64; if you set repeat times = 16, then it shoudle be 32 * 16 + 64 ppo_micro_batch_size_per_gpu: 4 ngpus_trainer: 4 train_batch_size_expert: 64 - train_batch_size_usual: 256 + train_batch_size_usual: 256 # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times model: model_path: /PATH/TO/MODEL/ max_response_tokens: 10240 @@ -30,7 +30,7 @@ cluster: node_num: 1 gpu_per_node: 8 buffer: - total_epochs: 25 + total_epochs: 4 batch_size: 40 max_retry_times: 3 max_retry_interval: 1 @@ -55,6 +55,7 @@ buffer: storage_type: queue path: 'sqlite:///test_mix_chord.db' sft_warmup_dataset: + total_epochs: 25 name: SFT_data storage_type: file algorithm_type: sft @@ -80,7 +81,7 @@ synchronizer: sync_timeout: 1200 trainer: trainer_type: 'verl' - trainer_config_path: 'examples/mix_chord/train_chord.yaml' + trainer_config_path: 'examples/mix_chord/train_mix_chord.yaml' save_interval: 50 monitor: monitor_type: wandb
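
The batch-size comments added above in `examples/mix_chord/mix_chord.yaml` follow a simple bookkeeping rule: the usual (rollout) batch is the explorer `batch_size` times `(1 - expert_data_ratio)` times `repeat_times`, and `ppo_mini_batch_size` is that plus `train_batch_size_expert`. Below is a minimal sketch of that arithmetic for re-deriving these values when you change `repeat_times` or `expert_data_ratio`; the helper name is ours for illustration and is not part of the patch.

```python
def chord_batch_sizes(
    batch_size: int,
    expert_data_ratio: float,
    repeat_times: int,
    train_batch_size_expert: int,
) -> dict:
    """Re-derive the batch-size settings described in the mix_chord.yaml comments."""
    # Usual (rollout) experiences per step: non-expert tasks, each repeated `repeat_times`.
    train_batch_size_usual = int(batch_size * (1 - expert_data_ratio)) * repeat_times
    # The PPO mini-batch holds both the usual and the expert experiences.
    ppo_mini_batch_size = train_batch_size_usual + train_batch_size_expert
    return {
        "train_batch_size_usual": train_batch_size_usual,
        "ppo_mini_batch_size": ppo_mini_batch_size,
    }


# Values from mix_chord.yaml: 40 * (1 - 0.2) * 8 = 256 usual, 256 + 64 = 320 mini-batch.
print(chord_batch_sizes(40, 0.2, 8, 64))
# With repeat_times = 16, as suggested in the comment: 32 * 16 + 64 = 576.
print(chord_batch_sizes(40, 0.2, 16, 64))
```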