89 changes: 89 additions & 0 deletions examples/mix_chord/README.md
# Example: CHORD Algorithm

Below we show an example of implementing the [CHORD](https://arxiv.org/pdf/2508.11408) algorithm.

Here we provide a basic, runnable example demonstrating the core functionality of CHORD. The hyperparameters used in our experiments may not be optimal across datasets; we encourage researchers to build upon this implementation and explore further improvements.

If you are interested in implementing your own algorithm, you may refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) for guidance.

## How to run

### Install Trinity-RFT

First, install Trinity-RFT. Follow the guide in [README.md](../../README.md) to install the dependencies and set up the environment.

### Prepare the models and datasets

Next, prepare the models and datasets, and fill their paths into the configuration file.

Download the model you want from Hugging Face or ModelScope, for example:
```bash
# Using Hugging Face
huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen/Qwen2.5-1.5B-Instruct

# Using ModelScope
modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
```

For the dataset, you need to prepare both the SFT dataset and the RL dataset. Below we provide a script for processing the dataset into our required format.

Before running the dataset processing script, fill in the tokenizer path in the script; the tokenizer is used to filter out SFT samples that are too long.
You can also change the sample sizes if you want.
```python
TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH"
MAX_TOKEN_LENGTH = 8196
SFT_SAMPLE_SIZE = 5000
RL_SAMPLE_SIZE = 20000
```
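The tokenizer is only used for length filtering: generations whose token count exceeds `MAX_TOKEN_LENGTH` are dropped. A minimal sketch of that filter, with a whitespace split standing in for the real `AutoTokenizer` so it runs without downloading a model:

```python
# Sketch of the length filter applied in get_openr1_data.py.
# A whitespace split stands in for tokenizer.tokenize() here.
MAX_TOKEN_LENGTH = 8196

def tokenize(text):
    return text.split()  # placeholder for the real tokenizer

generations = ["a short correct solution", " ".join(["token"] * 9000)]
kept = [g for g in generations if len(tokenize(g)) <= MAX_TOKEN_LENGTH]
print(len(kept))  # the 9000-token generation is filtered out
```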

Then just run the script:
```bash
python examples/mix_chord/get_openr1_data.py
```
This may take a while to run.

> **Note**: Here we provide scripts for sampling SFT and RL data from the OpenR1 dataset. Unfortunately, since our original experiments did not use a fixed random seed, the data selection and ordering may differ from those in the paper.
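For reference, the script emits two artifacts: `openr1_sft_dataset.json`, a list of chat-style records, and `openr1_rl_dataset/train.jsonl`, with one problem/answer pair per line. The records below are illustrative placeholders, not real OpenR1 data:

```python
import json

# Illustrative SFT record (chat-messages format written to openr1_sft_dataset.json).
sft_record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant that solves MATH problems. ..."},
        {"role": "user", "content": "Compute 2 + 2."},
        {"role": "assistant", "content": "<think>\n2 + 2 = 4. </think>\nThe answer is \\boxed{4}."},
    ]
}

# Illustrative RL record (one JSON object per line in openr1_rl_dataset/train.jsonl).
rl_record = {"problem": "Compute 2 + 2.", "answer": "4"}

print(json.dumps(rl_record, ensure_ascii=False))
```

The `problem` and `answer` keys match `prompt_key` and `response_key` in `mix_chord.yaml`, and the `messages` key matches `messages_key` in the SFT warmup dataset config.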

### Modify the running script

Fill in the config files in [`mix_chord.yaml`](mix_chord.yaml) and [`train_mix_chord.yaml`](train_mix_chord.yaml).

### Run the script

```bash
# Stop existing ray processes
ray stop

# Start ray
ray start --head

# Run Trinity
trinity run --config examples/mix_chord/mix_chord.yaml
```

## Citation

If you find this code useful, please consider citing our paper:
```bibtex
@misc{TrinityRFT,
title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models},
author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou},
year={2025},
eprint={2505.17826},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.17826},
}

@misc{MIXCHORD,
title={On-Policy RL Meets Off-Policy Experts: Harmonizing Supervised Fine-Tuning and Reinforcement Learning via Dynamic Weighting},
author={Wenhao Zhang and Yuexiang Xie and Yuchang Sun and Yanxi Chen and Guoyin Wang and Yaliang Li and Bolin Ding and Jingren Zhou},
year={2025},
eprint={2508.11408},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2508.11408},
}
```
205 changes: 205 additions & 0 deletions examples/mix_chord/get_openr1_data.py
"""
This script generates the SFT and RL datasets sampled from the OpenR1 dataset.
"""

import json
import os
import random

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

# Set random seed for reproducibility
random.seed(42)

# Configuration parameters
TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH"
MAX_TOKEN_LENGTH = 8196
SFT_SAMPLE_SIZE = 5000
RL_SAMPLE_SIZE = 20000
SYSTEM_PROMPT = """You are a helpful assistant that solves MATH problems. You should first think about the reasoning process in your mind and then provide the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as a closed-form result."""


def can_convert_to_int(answer):
    """Check whether the answer can be directly converted to an integer."""
    try:
        int(answer)
        return True
    except (ValueError, TypeError):
        return False


def contains_chinese(text):
    """OpenR1 contains many incorrectly translated problems; we filter out those with Chinese characters."""
    for char in text:
        # CJK Unified Ideographs
        if "\u4e00" <= char <= "\u9fff":
            return True
        # CJK Unified Ideographs Extension A
        if "\u3400" <= char <= "\u4dbf":
            return True
        # CJK Unified Ideographs Extension B (code points above U+FFFF need \U escapes)
        if "\U00020000" <= char <= "\U0002A6DF":
            return True
    return False


def process_dataset(openr1_ds, tokenizer):
    """Process the dataset and filter out instances that don't meet the criteria."""
    processed_data = []

    for instance in tqdm(openr1_ds, desc="Processing dataset"):
        # Filter out answers that cannot be directly converted to int
        # if not can_convert_to_int(instance["answer"]):
        #     continue

        if contains_chinese(instance["problem"]):
            continue

        # Keep only generations that are verified correct and short enough
        generations_kept = []
        correctness_list = instance.get("correctness_math_verify", [])
        generations = instance.get("generations", [])

        for i, generation in enumerate(generations):
            # Check correctness_math_verify
            if i >= len(correctness_list) or not correctness_list[i]:
                continue

            # Check token length
            tokenized_length = len(tokenizer.tokenize(generation))
            if tokenized_length > MAX_TOKEN_LENGTH:
                continue

            generations_kept.append(generation)

        # Add to processed data if any generations were kept
        if generations_kept:
            processed_data.append(
                {
                    "problem": instance["problem"],
                    "answer": instance["answer"],
                    "generations": generations_kept,
                }
            )

    return processed_data


def create_sft_dataset(data, sample_size):
    """Create the SFT dataset in chat-messages format."""
    sft_messages = []

    # Randomly sample the specified number of instances
    sampled_data = random.sample(data, min(sample_size, len(data)))

    for instance in sampled_data:
        # Randomly select one generation as the response
        generation = random.choice(instance["generations"])

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": instance["problem"]},
            {"role": "assistant", "content": generation},
        ]
        sft_messages.append({"messages": messages})

    return sft_messages, sampled_data


def create_rl_dataset(data, sample_size, output_dir):
    """Create the RL dataset in HuggingFace format."""

    filtered_data = [d for d in data if can_convert_to_int(d["answer"])]
    print(f"Number of instances convertible to int: {len(filtered_data)}")

    # Keep instances with at least 2 generations
    filtered_data = [d for d in filtered_data if len(d["generations"]) >= 2]
    print(f"Number of instances with >= 2 generations: {len(filtered_data)}")

    # or apply no filter:
    # filtered_data = data

    # Random sample
    sampled_data = random.sample(filtered_data, min(sample_size, len(filtered_data)))

    # Prepare data in HuggingFace format (only problem and answer)
    rl_data = []
    for instance in sampled_data:
        rl_data.append({"problem": instance["problem"], "answer": instance["answer"]})

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Save as JSONL for HuggingFace datasets
    train_file = os.path.join(output_dir, "train.jsonl")
    with open(train_file, "w", encoding="utf-8") as f:
        for item in rl_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    # Create dataset_dict.json for the HuggingFace format
    dataset_info = {
        "citation": "",
        "description": "OpenR1 RLVR dataset subset",
        "splits": {"train": {"name": "train", "num_examples": len(rl_data)}},
    }

    with open(os.path.join(output_dir, "dataset_dict.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)

    print(f"Saved RL dataset to {output_dir}")
    print(f"Total instances: {len(rl_data)}")

    return sampled_data  # Return sampled data with generations for reference


def save_json(data, filename):
    """Save data to a JSON file."""
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
    print(f"Saved {len(data)} instances to {filename}")


def main():
    # Load dataset from HuggingFace
    print("Loading dataset from HuggingFace...")
    openr1_ds = load_dataset("open-r1/OpenR1-Math-220k", "default", split="train").to_list()

    print(f"Original dataset size: {len(openr1_ds)}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL_PATH, use_fast=False)

    # Process dataset with filtering
    print("Processing dataset with filters...")
    processed_data = process_dataset(openr1_ds, tokenizer)
    print(f"Processed dataset size: {len(processed_data)}")

    # Create SFT dataset
    print(f"\nCreating SFT dataset (sampling {SFT_SAMPLE_SIZE} instances)...")
    sft_dataset, sampled_for_sft = create_sft_dataset(processed_data, SFT_SAMPLE_SIZE)
    save_json(sft_dataset, "openr1_sft_dataset.json")

    # Create RL dataset from the remaining data
    print("\nCreating RL dataset...")
    # Remove instances already used for SFT
    remaining_data = [d for d in processed_data if d not in sampled_for_sft]
    print(f"Remaining data after SFT sampling: {len(remaining_data)}")

    # Create RL dataset in HuggingFace format
    rl_output_dir = "openr1_rl_dataset"
    sampled_rl_data = create_rl_dataset(remaining_data, RL_SAMPLE_SIZE, rl_output_dir)

    # Optionally save the full RL data with generations for reference
    # save_json(sampled_rl_data, "openr1_rl_dataset_with_generations.json")

    print("\n" + "=" * 50)
    print("Dataset generation completed!")
    print(f"SFT dataset: {len(sft_dataset)} instances")
    print(f"RL dataset: {len(sampled_rl_data)} instances")
    print("=" * 50)


if __name__ == "__main__":
    main()
87 changes: 87 additions & 0 deletions examples/mix_chord/mix_chord.yaml
project: "mix_chord"
name: "test_mix_chord"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: mix_chord
  repeat_times: 8  # or 16 for better performance in math-related tasks
  kl_loss_fn_args:
    kl_coef: 0.0
  sample_strategy_args:
    expert_data_ratio: 0.20
  policy_loss_fn_args:  # feel free to change; we encourage you to try different hyperparameters
    mu_warmup_steps: 200  # 0 for chord-mu and chord-phi
    mu_decay_steps: 400  # 200 for chord-mu and 0 for chord-phi
    mu_peak: 0.5  # 0.9 for chord-mu and 0.1 for chord-phi
    mu_valley: 0.02  # 0.05 for chord-mu and 0.1 for chord-phi
    enable_phi_function: true  # false for chord-mu and true for chord-phi
    clip_range: 0.2
    use_token_level_loss_in_sft: true
  use_dynamic_bsz: true
  ppo_mini_batch_size: 320  # 320 = 256 + 64; if you set repeat_times = 16, it should be 32 * 16 + 64
  ppo_micro_batch_size_per_gpu: 4
  ngpus_trainer: 4
  train_batch_size_expert: 64
  train_batch_size_usual: 256  # (40 batch_size * (1 - 0.2 expert_data_ratio)) * 8 repeat_times
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 10240
  max_model_len: 11264
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 4
  batch_size: 40
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: openr1_data_filtered_int
      storage_type: file
      path: /PATH/TO/RL_DATASET
      format:
        prompt_key: 'problem'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
        logprobs: 0
      workflow_args:
        with_think: true
    eval_tasksets: []  # you can add your own eval tasksets here
    default_workflow_type: 'math_boxed_workflow'
  trainer_input:
    experience_buffer:
      name: math_buffer
      storage_type: queue
      path: 'sqlite:///test_mix_chord.db'
    sft_warmup_dataset:
      total_epochs: 25
      name: SFT_data
      storage_type: file
      algorithm_type: sft
      path: /PATH/TO/SFT_DATASET
      split: 'train'
      format:
        prompt_type: messages
        messages_key: 'messages'
explorer:
  eval_interval: 10
  runner_num: 16
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/mix_chord/train_mix_chord.yaml'
  save_interval: 50
monitor:
  monitor_type: wandb