Add CHORD algorithm example #194
Merged
Commits (6):
- 0414330 feat: migrate CHORD to new version with phi function support (garyzhang99)
- 83e5c40 fix old mix config bug (garyzhang99)
- 382c38e old mix algorithm config bug fix (garyzhang99)
- 68b7d31 fix config (garyzhang99)
- b32a220 resolve comments (garyzhang99)
- e191b11 fix more config promblems (garyzhang99)
File 1 of 3: the example README (new file, 89 lines)

# Example: CHORD Algorithm

Below we show an example of implementing the [CHORD](https://arxiv.org/pdf/2508.11408) algorithm.

Here we provide a basic runnable example demonstrating the core functionality of CHORD. The hyperparameters used in our experiments may not be optimal across different datasets; we encourage researchers to build upon this implementation and explore further improvements.

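To make the idea concrete, here is a minimal sketch of CHORD's dynamic weighting, assuming the loss is a convex combination of the SFT and RL objectives with a piecewise-linear schedule for the global coefficient mu. The function names and the schedule shape are our illustration, not the Trinity-RFT implementation; the real knobs are `mu_warmup_steps`, `mu_decay_steps`, `mu_peak`, and `mu_valley` in [`mix_chord.yaml`](mix_chord.yaml):

```python
def mu_schedule(step: int, warmup: int = 200, decay: int = 400,
                peak: float = 0.5, valley: float = 0.02) -> float:
    """Piecewise-linear schedule for the SFT weight mu (shape assumed)."""
    if step < warmup:
        # Linear warmup from 0 to peak
        return peak * step / max(warmup, 1)
    if step < warmup + decay:
        # Linear decay from peak to valley
        frac = (step - warmup) / max(decay, 1)
        return peak + frac * (valley - peak)
    return valley


def chord_loss(sft_loss: float, rl_loss: float, step: int) -> float:
    """Mix the off-policy (expert/SFT) and on-policy (RL) losses via mu."""
    mu = mu_schedule(step)
    return mu * sft_loss + (1.0 - mu) * rl_loss
```

The chord-phi variant additionally applies a token-wise weighting function inside the SFT term (toggled by `enable_phi_function` in the config); that part is not sketched here.
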
If you are interested in implementing your own algorithm, you may refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) for guidance.

## How to run

### Install Trinity-RFT

First, install Trinity-RFT. Follow the guide in [README.md](../../README.md) to install the dependencies and set up the environment.

### Prepare the models and datasets

Next, prepare the models and datasets, and fill their paths into the configuration file.

First, download the model you want from Hugging Face or ModelScope, for example:

```bash
# Using Hugging Face
huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen/Qwen2.5-1.5B-Instruct

# Using ModelScope
modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
```

For the dataset, you need to prepare both the SFT dataset and the RL dataset. Below we provide a script for processing the dataset into our required format.

Before running the dataset processing script, fill in the tokenizer path in the script; it is used to filter out SFT data that is too long. You can also change the sample sizes if you want.

```python
TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH"
MAX_TOKEN_LENGTH = 8196
SFT_SAMPLE_SIZE = 5000
RL_SAMPLE_SIZE = 20000
```

Then just run the script:

```bash
python examples/mix_chord/get_openr1_data.py
```

This may take a while to run.

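For reference, these are the record shapes the script produces (values are illustrative; see `get_openr1_data.py` below):

```python
# One SFT record in openr1_sft_dataset.json
sft_record = {
    "messages": [
        {"role": "system", "content": "<SYSTEM_PROMPT>"},
        {"role": "user", "content": "<math problem>"},
        {"role": "assistant", "content": "<verified solution>"},
    ]
}

# One RL record, i.e. one line of openr1_rl_dataset/train.jsonl
rl_record = {"problem": "<math problem>", "answer": "42"}
```

With the script's default paths, a successful run leaves behind `openr1_sft_dataset.json` plus an `openr1_rl_dataset/` directory containing `train.jsonl` and `dataset_dict.json`.
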
> **Note**: Here we provide scripts for sampling SFT and RL data from the OpenR1 dataset. Unfortunately, since our original experiments did not use a fixed random seed, the data selection and ordering may differ from the paper.

### Modify the running script

Fill in the config files [`mix_chord.yaml`](mix_chord.yaml) and [`train_mix_chord.yaml`](train_mix_chord.yaml).

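At minimum, the following placeholders in `mix_chord.yaml` (shown in full later in this PR) need real values: `checkpoint_root_dir` (`/PATH/TO/CHECKPOINT/`), `model_path` (`/PATH/TO/MODEL/`, the model downloaded above), the RL taskset path (`/PATH/TO/RL_DATASET`, e.g. the `openr1_rl_dataset` directory produced above), and the SFT warmup dataset path (`/PATH/TO/SFT_DATASET`, e.g. the SFT json produced above).
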
### Run the script

```bash
# Stop existing ray processes
ray stop

# Start ray
ray start --head

# Run Trinity
trinity run --config examples/mix_chord/mix_chord.yaml
```

## Citation

If you find this code useful, please consider citing our papers:

```bibtex
@misc{TrinityRFT,
  title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models},
  author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou},
  year={2025},
  eprint={2505.17826},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2505.17826},
}

@misc{MIXCHORD,
  title={On-Policy RL Meets Off-Policy Experts: Harmonizing Supervised Fine-Tuning and Reinforcement Learning via Dynamic Weighting},
  author={Wenhao Zhang and Yuexiang Xie and Yuchang Sun and Yanxi Chen and Guoyin Wang and Yaliang Li and Bolin Ding and Jingren Zhou},
  year={2025},
  eprint={2508.11408},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2508.11408},
}
```
File 2 of 3: `examples/mix_chord/get_openr1_data.py` (new file, 205 lines)

| """ | ||
| We provide scripts for generating the SFT and RL dataset. | ||
| """ | ||
|
|
||
| import json | ||
| import os | ||
| import random | ||
|
|
||
| from datasets import load_dataset | ||
| from tqdm import tqdm | ||
| from transformers import AutoTokenizer | ||
|
|
||
| # import re | ||
|
|
||
| # Set random seed for reproducibility | ||
| random.seed(42) | ||
|
|
||
| # Configuration parameters | ||
| TOKENIZER_MODEL_PATH = "YOUR MODEL TOKENIZER PATH" | ||
| MAX_TOKEN_LENGTH = 8196 | ||
| SFT_SAMPLE_SIZE = 5000 | ||
| RL_SAMPLE_SIZE = 20000 | ||
| SYSTEM_PROMPT = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as closed-form results.""" | ||
|
|
||
|
|
||
def can_convert_to_int(answer):
    """Check if answer can be directly converted to an integer."""
    try:
        int(answer)
        return True
    except (ValueError, TypeError):
        return False


def contains_chinese(text):
    """There are many incorrectly translated problems in OpenR1; we may want to filter them out."""
    for char in text:
        if "\u4e00" <= char <= "\u9fff":  # CJK Unified Ideographs
            return True
        if "\u3400" <= char <= "\u4dbf":  # CJK Extension A
            return True
        # NOTE: code points above U+FFFF need \U escapes; "\u20000" would be
        # parsed as U+2000 followed by the character "0" and match far too much
        # (e.g. common math symbols), so use the 8-digit form for Extension B.
        if "\U00020000" <= char <= "\U0002A6DF":  # CJK Extension B
            return True
    return False


def process_dataset(openr1_ds, tokenizer):
    """Process dataset and filter out instances that don't meet the criteria."""
    processed_data = []

    for instance in tqdm(openr1_ds, desc="Processing dataset"):
        # Drop problems containing (likely mistranslated) Chinese text
        if contains_chinese(instance["problem"]):
            continue

        # Keep only generations that are verified correct and short enough
        kept_generations = []
        correctness_list = instance.get("correctness_math_verify", [])
        generations = instance.get("generations", [])

        for i, generation in enumerate(generations):
            # Check correctness_math_verify
            if i >= len(correctness_list) or not correctness_list[i]:
                continue

            # Check token length
            tokenized_length = len(tokenizer.tokenize(generation))
            if tokenized_length > MAX_TOKEN_LENGTH:
                continue

            kept_generations.append(generation)

        # Add to processed data only if some generations survived
        if kept_generations:
            processed_data.append(
                {
                    "problem": instance["problem"],
                    "answer": instance["answer"],
                    "generations": kept_generations,
                }
            )

    return processed_data


def create_sft_dataset(data, sample_size):
    """Create SFT dataset in message format."""
    sft_messages = []

    # Randomly sample the specified number of instances
    sampled_data = random.sample(data, min(sample_size, len(data)))

    for instance in sampled_data:
        # Randomly select one generation as the response
        generation = random.choice(instance["generations"])

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": instance["problem"]},
            {"role": "assistant", "content": generation},
        ]
        sft_messages.append({"messages": messages})

    return sft_messages, sampled_data


def create_rl_dataset(data, sample_size, output_dir):
    """Create RL dataset in HuggingFace format."""

    # Keep only instances whose answer can be parsed as an integer
    filtered_data = [d for d in data if can_convert_to_int(d["answer"])]
    print(f"Number of instances with integer-convertible answers: {len(filtered_data)}")

    # Keep only instances with at least 2 generations
    filtered_data = [d for d in filtered_data if len(d["generations"]) >= 2]
    print(f"Number of instances with >= 2 generations: {len(filtered_data)}")

    # Alternatively, skip the filters entirely:
    # filtered_data = data

    # Random sample
    sampled_data = random.sample(filtered_data, min(sample_size, len(filtered_data)))

    # Prepare data in HuggingFace format (only problem and answer)
    rl_data = [{"problem": d["problem"], "answer": d["answer"]} for d in sampled_data]

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Save as JSONL for HuggingFace datasets
    train_file = os.path.join(output_dir, "train.jsonl")
    with open(train_file, "w", encoding="utf-8") as f:
        for item in rl_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    # Create dataset_dict.json for the HuggingFace format
    dataset_info = {
        "citation": "",
        "description": "OpenR1 RLVR dataset subset",
        "splits": {"train": {"name": "train", "num_examples": len(rl_data)}},
    }

    with open(os.path.join(output_dir, "dataset_dict.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)

    print(f"Saved RL dataset to {output_dir}")
    print(f"Total instances: {len(rl_data)}")

    return sampled_data  # Return the sampled data with generations for reference


def save_json(data, filename):
    """Save data to a JSON file."""
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
    print(f"Saved {len(data)} instances to {filename}")


def main():
    # Load the dataset from HuggingFace
    print("Loading dataset from HuggingFace...")
    openr1_ds = load_dataset("open-r1/OpenR1-Math-220k", "default", split="train").to_list()
    print(f"Original dataset size: {len(openr1_ds)}")

    # Load the tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL_PATH, use_fast=False)

    # Process the dataset with filtering
    print("Processing dataset with filters...")
    processed_data = process_dataset(openr1_ds, tokenizer)
    print(f"Processed dataset size: {len(processed_data)}")

    # Create the SFT dataset
    print(f"\nCreating SFT dataset (sampling {SFT_SAMPLE_SIZE} instances)...")
    sft_dataset, sampled_for_sft = create_sft_dataset(processed_data, SFT_SAMPLE_SIZE)
    save_json(sft_dataset, "openr1_sft_dataset.json")

    # Create the RL dataset from the remaining data
    print("\nCreating RL dataset...")
    # Remove instances already used for SFT. Match by object identity: the dicts
    # are not hashable, and random.sample returns references into processed_data.
    sampled_ids = {id(d) for d in sampled_for_sft}
    remaining_data = [d for d in processed_data if id(d) not in sampled_ids]
    print(f"Remaining data after SFT sampling: {len(remaining_data)}")

    # Create the RL dataset in HuggingFace format
    rl_output_dir = "openr1_rl_dataset"
    sampled_rl_data = create_rl_dataset(remaining_data, RL_SAMPLE_SIZE, rl_output_dir)

    # Optionally save the full RL data with generations for reference
    # save_json(sampled_rl_data, "openr1_rl_dataset_with_generations.json")

    print("\n" + "=" * 50)
    print("Dataset generation completed!")
    print(f"SFT dataset: {len(sft_dataset)} instances")
    print(f"RL dataset: {len(sampled_rl_data)} instances")
    print("=" * 50)


if __name__ == "__main__":
    main()
```
File 3 of 3: `examples/mix_chord/mix_chord.yaml` (new file, 87 lines)

```yaml
project: "mix_chord"
name: "test_mix_chord"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: mix_chord
  repeat_times: 8  # or 16 for better performance in math related tasks
  kl_loss_fn_args:
    kl_coef: 0.0
  sample_strategy_args:
    expert_data_ratio: 0.20
  policy_loss_fn_args:  # feel free to change, we encourage you to try out different hyperparameters
    mu_warmup_steps: 200  # 0 for chord-mu and chord-phi
    mu_decay_steps: 400  # 200 for chord-mu and 0 for chord-phi
    mu_peak: 0.5  # 0.9 for chord-mu and 0.1 for chord-phi
    mu_valley: 0.02  # 0.05 for chord-mu and 0.1 for chord-phi
    enable_phi_function: true  # false for chord-mu and true for chord-phi
    clip_range: 0.2
    use_token_level_loss_in_sft: true
    use_dynamic_bsz: true
    ppo_mini_batch_size: 320  # 320 = 256 + 64; if you set repeat_times = 16, then it should be 32 * 16 + 64
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    train_batch_size_expert: 64
    train_batch_size_usual: 256  # (40 batch_size * (1 - 0.2 expert_data_ratio)) * 8 repeat_times
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 10240
  max_model_len: 11264
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 4
  batch_size: 40
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: openr1_data_filtered_int
      storage_type: file
      path: /PATH/TO/RL_DATASET
      format:
        prompt_key: 'problem'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
        logprobs: 0
      workflow_args:
        with_think: true
    eval_tasksets: []  # you can add your own eval tasksets here
    default_workflow_type: 'math_boxed_workflow'
  trainer_input:
    experience_buffer:
      name: math_buffer
      storage_type: queue
      path: 'sqlite:///test_mix_chord.db'
    sft_warmup_dataset:
      total_epochs: 25
      name: SFT_data
      storage_type: file
      algorithm_type: sft
      path: /PATH/TO/SFT_DATASET
      split: 'train'
      format:
        prompt_type: messages
        messages_key: 'messages'
explorer:
  eval_interval: 10
  runner_num: 16
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/mix_chord/train_mix_chord.yaml'
  save_interval: 50
monitor:
  monitor_type: wandb
```
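
As a sanity check on the batch arithmetic implied by the comments above (all numbers come from this config):

```python
# Batch-size bookkeeping from mix_chord.yaml (illustrative check only)
batch_size = 40
expert_data_ratio = 0.20
repeat_times = 8

train_batch_size_usual = int(batch_size * (1 - expert_data_ratio) * repeat_times)  # 256
train_batch_size_expert = int(batch_size * expert_data_ratio * repeat_times)       # 64
ppo_mini_batch_size = train_batch_size_usual + train_batch_size_expert             # 320

assert (train_batch_size_usual, train_batch_size_expert, ppo_mini_batch_size) == (256, 64, 320)
```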