diff --git a/examples/mix_chord/README.md b/examples/mix_chord/README.md
new file mode 100644
index 0000000000..9a74cb079e
--- /dev/null
+++ b/examples/mix_chord/README.md
@@ -0,0 +1,89 @@
+# Example: CHORD Algorithm
+
+Below we show an example of implementing the [CHORD](https://arxiv.org/pdf/2508.11408) algorithm.
+
+Here we provide a basic runnable example demonstrating the core functionality of CHORD. The hyperparameters used in our experiments may not be optimal across different datasets; we encourage researchers to build upon this implementation and explore further improvements.
+
+If you are interested in implementing your own algorithm, you may refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) for guidance.
+
+## How to run
+
+### Install Trinity-RFT
+
+First, install Trinity-RFT. Please follow the guide in [README.md](../../README.md) to install the dependencies and set up the environment.
+
+### Prepare the models and datasets
+
+Next, prepare the model and datasets, and fill in their paths in the configuration files.
+
+Download the model you want from Hugging Face or ModelScope, for example:
+```bash
+# Using Hugging Face
+huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen/Qwen2.5-1.5B-Instruct
+
+# Using ModelScope
+modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
+```
+
+For the datasets, you need to prepare both an SFT dataset and an RL dataset. Below we provide a script that processes the data into the required format.
+
+Before running the dataset processing script, fill in the tokenizer path in the script; the tokenizer is used to filter out SFT samples that are too long.
+You can also change the sample sizes if you want.
+```python
+TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH"
+MAX_TOKEN_LENGTH = 8196
+SFT_SAMPLE_SIZE = 5000
+RL_SAMPLE_SIZE = 20000
+```
+
+Then just run the script:
+```bash
+python examples/mix_chord/get_openr1_data.py
+```
+This may take a while to run.
+
+> **Note**: We provide scripts for sampling SFT and RL data from the OpenR1 dataset. Unfortunately, since our original experiments did not use a fixed random seed, the data selection and ordering may differ from those used in the paper.
+
+### Modify the configuration files
+
+Fill in the paths in the config files [`mix_chord.yaml`](mix_chord.yaml) and [`train_mix_chord.yaml`](train_mix_chord.yaml).
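+
+For reference, the fields you typically need to update in [`mix_chord.yaml`](mix_chord.yaml) are the model, checkpoint, and dataset paths. Below is a minimal excerpt of the relevant keys from that file (the inline comments are added here only for orientation; point the placeholders to your downloaded model and to the outputs of `get_openr1_data.py`):
+```yaml
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/  # e.g. $MODEL_PATH/Qwen/Qwen2.5-1.5B-Instruct
+buffer:
+  explorer_input:
+    taskset:
+      path: /PATH/TO/RL_DATASET  # e.g. the openr1_rl_dataset directory produced by the script
+  trainer_input:
+    sft_warmup_dataset:
+      path: /PATH/TO/SFT_DATASET  # e.g. openr1_sft_dataset.json produced by the script
+```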
+ +### Run the script + +```bash +# Stop existing ray processes +ray stop + +# Start ray +ray start --head + +# Run Trinity +trinity run --config examples/mix_chord/mix_chord.yaml +``` + +## Citation + +If you find this code useful, please consider citing our paper: +```bibtex +@misc{TrinityRFT, + title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models}, + author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2505.17826}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.17826}, +} + +@misc{MIXCHORD, + title={On-Policy RL Meets Off-Policy Experts: Harmonizing Supervised Fine-Tuning and Reinforcement Learning via Dynamic Weighting}, + author={Wenhao Zhang and Yuexiang Xie and Yuchang Sun and Yanxi Chen and Guoyin Wang and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2508.11408}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2508.11408}, +} +``` diff --git a/examples/mix_chord/get_openr1_data.py b/examples/mix_chord/get_openr1_data.py new file mode 100644 index 0000000000..71d188066c --- /dev/null +++ b/examples/mix_chord/get_openr1_data.py @@ -0,0 +1,205 @@ +""" +We provide scripts for generating the SFT and RL dataset. +""" + +import json +import os +import random + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +# import re + +# Set random seed for reproducibility +random.seed(42) + +# Configuration parameters +TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH" +MAX_TOKEN_LENGTH = 8196 +SFT_SAMPLE_SIZE = 5000 +RL_SAMPLE_SIZE = 20000 +SYSTEM_PROMPT = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: \n ...your reasoning process here... \n first. You should always include your final answer in \\boxed{} as closed-form results.""" + + +def can_convert_to_int(answer): + """Check if answer can be directly converted to integer""" + try: + int(answer) + return True + except (ValueError, TypeError): + return False + + +def contains_chinese(text): + """There are many incorrect translated problems in OpenR1. 
We may want to filter them out.""" + for char in text: + if "\u4e00" <= char <= "\u9fff": + return True + if "\u3400" <= char <= "\u4dbf": + return True + if "\u20000" <= char <= "\u2a6df": + return True + return False + + +def process_dataset(openr1_ds, tokenizer): + """Process dataset and filter out instances that don't meet criteria""" + processed_data = [] + + for instance in tqdm(openr1_ds, desc="Processing dataset"): + # Filter out answers that cannot be directly converted to int + # if not can_convert_to_int(instance["answer"]): + # continue + + if contains_chinese(instance["problem"]): + continue + + # Process generations + generations_keeped = [] + correctness_list = instance.get("correctness_math_verify", []) + generations = instance.get("generations", []) + + for i, generation in enumerate(generations): + # Check correctness_math_verify + if i >= len(correctness_list) or not correctness_list[i]: + continue + + # Check token length + tokenized_length = len(tokenizer.tokenize(generation)) + if tokenized_length > MAX_TOKEN_LENGTH: + continue + + generations_keeped.append(generation) + + # Add to processed data if there are kept generations + if generations_keeped: + processed_data.append( + { + "problem": instance["problem"], + "answer": instance["answer"], + "generations": generations_keeped, + } + ) + + return processed_data + + +def create_sft_dataset(data, sample_size): + """Create SFT dataset with message format""" + sft_messages = [] + + # Random sample specified number of instances + sampled_data = random.sample(data, min(sample_size, len(data))) + + for instance in sampled_data: + # Randomly select one generation as response + generation = random.choice(instance["generations"]) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": instance["problem"]}, + {"role": "assistant", "content": generation}, + ] + sft_messages.append({"messages": messages}) + + return sft_messages, sampled_data + + +def create_rl_dataset(data, sample_size, output_dir): + """Create RL dataset in HuggingFace format""" + + filtered_data = [d for d in data if can_convert_to_int(d["answer"])] + print("Number of instances can convert to int: ", len(filtered_data)) + + # Filter instances with at least 2 generations + filtered_data = [d for d in filtered_data if len(d["generations"]) >= 2] + print(f"Number of instances with >= 2 generations: {len(filtered_data)}") + + # or No filter + # filtered_data = data + + # Random sample + sampled_data = random.sample(filtered_data, min(sample_size, len(filtered_data))) + + # Prepare data in HuggingFace format (only problem and answer) + rl_data = [] + for instance in sampled_data: + rl_data.append({"problem": instance["problem"], "answer": instance["answer"]}) + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Save as JSONL format for HuggingFace datasets + train_file = os.path.join(output_dir, "train.jsonl") + with open(train_file, "w", encoding="utf-8") as f: + for item in rl_data: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + # Create dataset_dict.json for HuggingFace format + dataset_info = { + "citation": "", + "description": "OpenR1 RLVR dataset subset", + "splits": {"train": {"name": "train", "num_examples": len(rl_data)}}, + } + + with open(os.path.join(output_dir, "dataset_dict.json"), "w", encoding="utf-8") as f: + json.dump(dataset_info, f, indent=2) + + print(f"Saved RL dataset to {output_dir}") + print(f"Total instances: {len(rl_data)}") + + return 
sampled_data # Return sampled data with generations for reference + + +def save_json(data, filename): + """Save data to JSON file""" + with open(filename, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + print(f"Saved {len(data)} instances to {filename}") + + +def main(): + # Load dataset from HuggingFace + print("Loading dataset from HuggingFace...") + openr1_ds = load_dataset("open-r1/OpenR1-Math-220k", "default", split="train").to_list() + + print(f"Original dataset size: {len(openr1_ds)}") + + # Load tokenizer + print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL_PATH, use_fast=False) + + # Process dataset with filtering + print("Processing dataset with filters...") + processed_data = process_dataset(openr1_ds, tokenizer) + print(f"Processed dataset size: {len(processed_data)}") + + # Create SFT dataset + print(f"\nCreating SFT dataset (sampling {SFT_SAMPLE_SIZE} instances)...") + sft_dataset, sampled_for_sft = create_sft_dataset(processed_data, SFT_SAMPLE_SIZE) + save_json(sft_dataset, "openr1_sft_dataset.json") + + # Create RL dataset from remaining data + print("\nCreating RL dataset...") + # Remove instances already used for SFT + remaining_data = [d for d in processed_data if d not in sampled_for_sft] + print(f"Remaining data after SFT sampling: {len(remaining_data)}") + + # Create RL dataset in HuggingFace format + rl_output_dir = "openr1_rl_dataset" + sampled_rl_data = create_rl_dataset(remaining_data, RL_SAMPLE_SIZE, rl_output_dir) + + # Optionally save the full RL data with generations for reference + # save_json(sampled_rl_data, "openr1_rl_dataset_with_generations.json") + + print("\n" + "=" * 50) + print("Dataset generation completed!") + print(f"SFT dataset: {len(sft_dataset)} instances") + print(f"RL dataset: {len(sampled_rl_data)} instances") + print("=" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml new file mode 100644 index 0000000000..caa44c573b --- /dev/null +++ b/examples/mix_chord/mix_chord.yaml @@ -0,0 +1,87 @@ +project: "mix_chord" +name: "test_mix_chord" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: mix_chord + repeat_times: 8 # or 16 for better performance in math related tasks + kl_loss_fn_args: + kl_coef: 0.0 + sample_strategy_args: + expert_data_ratio: 0.20 + policy_loss_fn_args: # feel free to change, we encourage you to try out different hyperparameters + mu_warmup_steps: 200 # 0 for chord-mu and chord-phi + mu_decay_steps: 400 # 200 for chord-mu and 0 for chord-phi + mu_peak: 0.5 # 0.9 for chord-mu and 0.1 for chord-phi + mu_valley: 0.02 # 0.05 for chord-mu and 0.1 for chord-phi + enable_phi_function: true # false for chord-mu and true for chord-phi + clip_range: 0.2 + use_token_level_loss_in_sft: true + use_dynamic_bsz: true + ppo_mini_batch_size: 320 # 320 = 256 + 64; if you set repeat times = 16, then it shoudle be 32 * 16 + 64 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + train_batch_size_expert: 64 + train_batch_size_usual: 256 # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 10240 + max_model_len: 11264 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 4 + batch_size: 40 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: openr1_data_filtered_int + storage_type: file + path: /PATH/TO/RL_DATASET + format: + prompt_key: 'problem' + 
response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + workflow_args: + with_think: true + eval_tasksets: [] # you can add your own eval tasksets here + default_workflow_type: 'math_boxed_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: 'sqlite:///test_mix_chord.db' + sft_warmup_dataset: + total_epochs: 25 + name: SFT_data + storage_type: file + algorithm_type: sft + path: /PATH/TO/SFT_DATASET + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 16 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/mix_chord/train_mix_chord.yaml' + save_interval: 50 +monitor: + monitor_type: wandb diff --git a/examples/mix_chord/train_mix_chord.yaml b/examples/mix_chord/train_mix_chord.yaml new file mode 100644 index 0000000000..0e853f40d7 --- /dev/null +++ b/examples/mix_chord/train_mix_chord.yaml @@ -0,0 +1,48 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 25600 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 2 # sp size + optim: + lr: 1e-6 # or 5e-6, larger lr with warm up can result in better performance for SFT training. + lr_warmup_steps_ratio: 0. # in experimence lr warmup is helpful for chord-mu + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. 
If none can be found, start from scratch
+  resume_mode: auto # or resume_path
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index e4c4dd9172..4dc588cc51 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -1,22 +1,22 @@
 project: "mix_math"
-name: "expert0.25_mu0.1"
+name: "expert0.20_mu0.1"
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: mix
   repeat_times: 8
   sample_strategy_args:
-    expert_data_ratio: 0.25
+    expert_data_ratio: 0.20
   policy_loss_fn_args:
     mu: 0.1
     clip_range: 0.2
-    use_token_level_loss_in_sft: False
-    use_dynamic_bsz: False
+    use_token_level_loss_in_sft: true
+    use_dynamic_bsz: true
     repeat_times: 8
-    ppo_mini_batch_size: 256
+    ppo_mini_batch_size: 320
     ppo_micro_batch_size_per_gpu: 4
     ngpus_trainer: 4
     train_batch_size_expert: 64
-    train_batch_size_usual: 192
+    train_batch_size_usual: 256
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 10240
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index e652ac8f4e..eb45e839a3 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -195,6 +195,30 @@ def default_config(cls) -> Dict:
             "policy_loss_fn": "mix",
             "advantage_fn": "grpo",
             "sample_strategy": "mix",
+            "entropy_loss_fn": "mix",
         }


+@ALGORITHM_TYPE.register_module("mix_chord")
+class MIXCHORDAlgorithm(AlgorithmType):
+    """MIX-CHORD algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    compute_advantage_in_trainer: bool = False
+    use_rollout: bool = True
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 8,
+            "add_strategy": "grpo",
+            "policy_loss_fn": "mix_chord",
+            "advantage_fn": "grpo",
+            "sample_strategy": "mix",
+            "entropy_loss_fn": "mix",
+        }
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index d6179a832c..7a81fe96bf 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -59,6 +59,33 @@ def __call__(
         return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


+@ENTROPY_LOSS_FN.register_module("mix")
+class MixEntropyLossFn(EntropyLossFn):
+    """
+    Basic entropy loss function for the mix algorithm.
+ """ + + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef + + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + expert_mask: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + if expert_mask is None: + raise ValueError("expert_mask is required for MixEntropyLossFn") + assert ( + len(expert_mask) == entropy.shape[0] + ), f"Error: {len(expert_mask)=} != {entropy.shape[0]=}" + entropy = entropy[~expert_mask] + action_mask = action_mask[~expert_mask] + entropy_loss = masked_mean(entropy, action_mask) + return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} + + @ENTROPY_LOSS_FN.register_module("none") class DummyEntropyLossFn(EntropyLossFn): """ diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 4d88215a0a..7416e94d5b 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,3 +1,8 @@ +from trinity.algorithm.policy_loss_fn.chord_policy_loss import ( + MIXCHORDPolicyLossFn, + SFTISLossFn, + SFTPhiLossFn, +) from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn @@ -15,4 +20,7 @@ "SFTLossFn", "MIXPolicyLossFn", "GSPOLossFn", + "MIXCHORDPolicyLossFn", + "SFTISLossFn", + "SFTPhiLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py new file mode 100644 index 0000000000..dc1a5504bc --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py @@ -0,0 +1,257 @@ +"""Implements the CHORD policy loss function.""" + +import math +from typing import Dict, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn +from trinity.algorithm.utils import masked_mean + + +def mu_schedule_function( + global_step: int, mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float +) -> float: + """ + Computes a cosine decay schedule with a warmup phase for the mu parameter. 
+ """ + # Warmup + if global_step < mu_warmup_steps: + return (global_step / mu_warmup_steps) * mu_peak + + # Decay + if global_step >= (mu_warmup_steps + mu_decay_steps): + return mu_valley + + adjusted_step = global_step - mu_warmup_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * adjusted_step / mu_decay_steps)) + decayed_mu = (mu_peak - mu_valley) * cosine_decay + mu_valley + return decayed_mu + + +@POLICY_LOSS_FN.register_module("sft_is") +class SFTISLossFn(PolicyLossFn): + """ + SFT loss with importance sampling + """ + + def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: + super().__init__(backend=backend) + self.use_token_level_loss = use_token_level_loss + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + token_prob = torch.exp(logprob) + if self.use_token_level_loss: + sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask) + else: + sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask, axis=1).mean() + return sft_loss, {"sft_is_loss": sft_loss.detach().item()} + + @classmethod + def default_args(cls): + return { + "use_token_level_loss": True, + } + + +def phi_function(token_prob): + """ + The phi function downweights token with extreme probability. + Feel free to modify this function. + """ + return token_prob * (1 - token_prob) + + +@POLICY_LOSS_FN.register_module("sft_phi") +class SFTPhiLossFn(PolicyLossFn): + """ + SFT loss with transformed phi function + """ + + def __init__( + self, backend: str = "verl", use_token_level_loss: bool = True, cutoff_prob: float = 1.0 + ) -> None: + super().__init__(backend=backend) + self.use_token_level_loss = use_token_level_loss + self.cutoff_prob = cutoff_prob + assert 0.0 <= self.cutoff_prob <= 1.0 + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + token_prob = torch.exp(logprob) + if self.cutoff_prob < 1.0: + logprob = torch.clamp(logprob, max=math.log(self.cutoff_prob)) + + weighted_phi = phi_function(token_prob) + + if self.use_token_level_loss: + sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask) + else: + sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask, axis=1).mean() + return sft_loss, {"sft_phi_loss": sft_loss.detach().item()} + + @classmethod + def default_args(cls): + return { + "use_token_level_loss": True, + "cutoff_prob": 1.0, + } + + +@POLICY_LOSS_FN.register_module("mix_chord") +class MIXCHORDPolicyLossFn(PolicyLossFn): + """Implements a mixed policy loss combining GRPO and SFT losses. + + This loss function applies different loss components to data based on whether + it comes from an expert or not, as indicated by `expert_mask`. It combines: + - GRPO loss (self.grpo_loss_fn) for non-expert data + - SFT loss (self.sft_loss_fn) for expert data + the weight of SFT loss is globally controled by `mu_schedule` function + the tokenwise weights are calculated using different SFT loss formulas + + The per-sample weights are normalized using either `experience_per_gpu` or + `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, + to ensure consistent weighting across different batches of the same type experiences. 
+ """ + + def __init__( + self, + backend: str = "verl", + mu_warmup_steps: int = 0, + mu_decay_steps: int = 0, + mu_peak: float = 0.1, + mu_valley: float = 0.1, + enable_phi_function: bool = True, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[bool] = None, + ppo_mini_batch_size: int = 1, + ppo_micro_batch_size_per_gpu: int = 1, + ngpus_trainer: int = 1, + train_batch_size_usual: int = 1, + train_batch_size_expert: int = 1, + use_token_level_loss_in_sft: bool = True, + ) -> None: + super().__init__(backend=backend) + self.mu_warmup_steps = mu_warmup_steps + self.mu_decay_steps = mu_decay_steps + self.mu_peak = mu_peak + self.mu_valley = mu_valley + self.enable_phi_function = enable_phi_function + self.use_dynamic_bsz = use_dynamic_bsz + self.experience_per_gpu = ppo_mini_batch_size // ngpus_trainer + self.gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu + self.train_batch_size_usual = train_batch_size_usual // ngpus_trainer + self.train_batch_size_expert = train_batch_size_expert // ngpus_trainer + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + if enable_phi_function: + self.sft_loss_fn = SFTPhiLossFn(use_token_level_loss=use_token_level_loss_in_sft) + else: + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + expert_mask: torch.Tensor, + step: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + assert ( + len(expert_mask) == logprob.shape[0] + ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}" + + assert len(step) == logprob.shape[0], f"Error: {len(step)=} != {logprob.shape[0]=}" + + assert ( + step.max().item() == step.min().item() + ), f"Error: {step.max().item()} != {step.min().item()}" + current_step = step.max().item() + + n_usual_exp = torch.sum(~expert_mask).item() + n_expert_exp = torch.sum(expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.experience_per_gpu / ( + logprob.shape[0] * self.train_batch_size_usual + ) + per_micro_batch_weight_expert = self.experience_per_gpu / ( + logprob.shape[0] * self.train_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert # type: ignore + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~expert_mask], + old_logprob[~expert_mask], + action_mask[~expert_mask], + advantages[~expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[expert_mask], + action_mask[expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + mu = mu_schedule_function( + 
current_step, self.mu_warmup_steps, self.mu_decay_steps, self.mu_peak, self.mu_valley + ) + + loss = (1 - mu) * grpo_loss + mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item(), "mu": mu}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + """ + mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float + """ + return { + "mu_warmup_steps": 0, + "mu_decay_steps": 0, + "mu_peak": 0.1, + "mu_valley": 0.1, + "clip_range": 0.2, + "enable_phi_function": True, + } diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index b8c6a54a50..2f70b788f4 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -114,7 +114,7 @@ def __call__( # type: ignore metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) - metrics.update({"loss": loss.item()}) + metrics["loss"] = loss.item() return loss, metrics diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 32a34834bf..a66ca32bf7 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -52,16 +52,22 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: if exp.info is None: exp.info = {} exp.info["is_expert"] = False + exp.info["step"] = step expert_exp_list = await self.expert_exp_buffer.read_async() for exp in expert_exp_list: + # we add fake rewards and logprobs to make it compatible exp.reward = 0.0 exp.logprobs = torch.zeros_like( exp.tokens[exp.prompt_length :], dtype=torch.float32 ) + exp.advantages = torch.zeros_like( + exp.tokens[exp.prompt_length :], dtype=torch.float32 + ) if exp.info is None: exp.info = {} exp.info["is_expert"] = True + exp.info["step"] = step exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) @@ -75,7 +81,12 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: source_field="is_expert", destination_field="expert_mask", data_type=torch.bool, - ) + ), + CustomField( + source_field="step", + destination_field="step", + data_type=torch.int32, + ), ], ) # type: ignore return exps, metrics, repr_samples diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index c1ead6cbf6..e6d5a91458 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -196,6 +196,7 @@ def update_policy(self, data: DataProto): # noqa: C901 entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore entropy=entropy, action_mask=response_mask, + **data, ) prefix_metrics( src_metrics=entropy_loss_metrics, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 59453592d2..e1707629a8 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -375,7 +375,7 @@ def save_checkpoint(self, block_until_saved: bool = False) -> None: self._save_checkpoint() if block_until_saved: self.actor_rollout_wg.wait_on_save_thread() - if self.algorithm.use_critic: + if self.algorithm and self.algorithm.use_critic: self.critic_wg.wait_on_save_thread() def sync_weight(self) -> None:
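
For intuition, here is a small standalone sketch (not part of the patch) of how the `mu_schedule_function` added in `trinity/algorithm/policy_loss_fn/chord_policy_loss.py` evolves the SFT weight `mu` over training: linear warmup to `mu_peak`, cosine decay to `mu_valley`, then constant. In `MIXCHORDPolicyLossFn` the final loss is `(1 - mu) * grpo_loss + mu * sft_loss`. The printed values below use the defaults from `examples/mix_chord/mix_chord.yaml` (`mu_warmup_steps=200`, `mu_decay_steps=400`, `mu_peak=0.5`, `mu_valley=0.02`).

```python
import math


def mu_schedule(step: int, warmup: int, decay: int, peak: float, valley: float) -> float:
    """Mirrors mu_schedule_function: linear warmup, cosine decay, then a constant valley."""
    if step < warmup:
        return (step / warmup) * peak  # linear warmup towards mu_peak
    if step >= warmup + decay:
        return valley  # fully decayed, stay at mu_valley
    adjusted = step - warmup
    cosine = 0.5 * (1 + math.cos(math.pi * adjusted / decay))
    return (peak - valley) * cosine + valley  # cosine decay from mu_peak to mu_valley


if __name__ == "__main__":
    for step in (0, 100, 200, 400, 600, 1000):
        mu = mu_schedule(step, warmup=200, decay=400, peak=0.5, valley=0.02)
        print(f"step {step:4d}: mu = {mu:.3f}")  # 0.000, 0.250, 0.500, 0.260, 0.020, 0.020
```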