Add benchmark #178 (Merged)

Commits (12):
- 0fc9959 Add benchmark (chenyushuo)
- 51707ff Bug fix in dlc mode && apply suggestions from gemini (chenyushuo)
- 0be9a70 apply suggestions from gemini (chenyushuo)
- 4d498e3 apply suggestions from gemini (chenyushuo)
- 1cdcc81 Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (chenyushuo)
- d8f230e rename `project` and `name` (chenyushuo)
- c6e7dfd Add readme for benchmark (chenyushuo)
- 3d3ef07 Update readme.md (chenyushuo)
- 6bfac7c Update readme.md (chenyushuo)
- e7d4ad5 fix bench.py (chenyushuo)
- fad7d13 doc fix && fix in explorer (chenyushuo)
- 40630b8 doc fix (chenyushuo)

---

README for the benchmark runner (new file, 89 lines):

# Welcome to the Trinity Benchmark Runner 🌟

This tool makes it easy to run benchmarks for **Trinity-RFT**. Whether you're testing training performance or inference speed, this CLI lets you configure and launch experiments quickly, with no complex setup required. Just pick your dataset, hardware, and model settings, and let the tool handle the rest.

---

## 🚀 What You Can Do

- **Single or Multi-Machine Training**: Run experiments on one computer or scale across multiple nodes.
- **Auto-Config**: The tool adjusts settings based on your cluster resources and inputs.
- **Flexible Datasets**: Works with datasets like `gsm8k` and `countdown`.
- **Custom Settings**: Tweak learning rates, sync intervals, and model configurations.
- **Cloud Ready**: Supports local runs *and* cloud environments like **Aliyun PAI DLC**.

---

## 🛠️ How to Use It

### 1. Basic Command Structure
```bash
python bench.py <dataset> [options]
```

### 2. Example: Run a Benchmark
```bash
python bench.py gsm8k --node_num 1 --gpu_per_node 8 --model_path /your/model/path
```

### 3. Key Options Explained

| Option | What It Does |
|--------|--------------|
| `dataset` | Choose `gsm8k` or `countdown` |
| `--dlc` | Use when running in the Aliyun PAI DLC environment |
| `--node_num` | Number of nodes in the cluster (default: 1) |
| `--gpu_per_node` | Number of GPUs per node (default: 8) |
| `--vllm_engine_num` | Number of vLLM engines to use |
| `--vllm_tp_size` | Tensor parallel size for vLLM |
| `--explorer_trainer_ratio` | Ratio of explorer engine count to trainer GPU count (default: 0.6); used when `--vllm_engine_num` is not specified |
| `--model_path` | Path to the main model checkpoint |
| `--critic_model_path` | Path to the critic model checkpoint |
| `--taskset_path` | Path to the taskset file |
| `--lr` | Learning rate for the actor model |
| `--critic_lr` | Learning rate for the critic model |
| `--sync_interval` | Synchronization interval between the Trainer and the Explorer |
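
When `--vllm_engine_num` is omitted, the tool tries every feasible trainer/explorer GPU split and keeps the one whose explorer-to-trainer ratio is closest to `--explorer_trainer_ratio` (see `set_engine_num` in `bench.py` below). A minimal standalone sketch of that search, where all concrete numbers (1 node × 8 GPUs, tp size 1, batch size 160) are hypothetical values chosen for illustration:

```python
# Sketch of the auto-selection performed by set_engine_num in bench.py below.
total_gpus, tp_size, target, batch_size = 8, 1, 0.6, 160

best = None  # (ratio_diff, explorer_engine_num)
for trainer_gpus in range(1, total_gpus):
    if batch_size % trainer_gpus != 0:
        continue  # the batch must shard evenly across trainer GPUs
    explorer_gpus = total_gpus - trainer_gpus
    if explorer_gpus % tp_size != 0:
        continue  # explorer GPUs must form whole vLLM engines
    engines = explorer_gpus // tp_size
    diff = abs(engines / trainer_gpus - target)
    if best is None or diff < best[0]:
        best = (diff, engines)

# 5 trainer GPUs + 3 engines gives ratio 3/5 = 0.6, an exact match to the target.
print(best)  # (0.0, 3)
```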

---

## 📂 What Gets Saved

After running a benchmark, results are stored in `runs/<timestamp>/`:
- `config.yaml`: The exact settings used for your run.
- `checkpoints/`: Model snapshots saved during training.
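
To audit a finished run, the saved config can be reloaded; a minimal sketch, assuming PyYAML is installed and using a hypothetical timestamp directory:

```python
import yaml

# Hypothetical run directory; substitute a timestamp from your own runs/ folder.
with open("runs/20250101-120000/config.yaml") as f:
    config = yaml.safe_load(f)

# These keys are filled in by prepare_configs in bench.py below.
print(config["checkpoint_root_dir"])
print(config["explorer"]["rollout_model"]["engine_num"])
```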

---

## 📊 Benchmark Examples

### 1. GSM8K
To reproduce this experiment:
```bash
python bench.py gsm8k --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
```

#### GSM8K Results
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).


### 2. Countdown
First generate the data, then run the benchmark:
```bash
# Step 1: Generate data
python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path
# Step 2: Run benchmark
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
```

#### Countdown Results
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).


*More benchmarks will be added soon!*

---

## ✅ Tips for Success

1. **Pre-Download Models**: Make sure all models and tasksets are available at the paths you specify.
2. **Multi-Node Setup**: If using multiple nodes, ensure they can communicate and share storage.
3. **vLLM Users**: Check that your vLLM installation supports the features you need (like tensor parallelism).
4. **Aliyun Users**: Don't forget the `--dlc` flag when running in PAI DLC!

---

The launcher script, `bench.py` (new file, 197 lines):

```python
# Benchmark launcher: build a run config from a dataset template and the CLI
# flags below, then hand off to `trinity.cli.launcher run`.
import argparse
import os
import subprocess
import time

import torch
import torch.distributed as dist
import yaml

from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.utils.dlc_utils import get_dlc_env_vars


def set_engine_num(config, args):
    config["cluster"]["node_num"] = args.node_num
    config["cluster"]["gpu_per_node"] = args.gpu_per_node
    batch_size = config["buffer"]["batch_size"]
    if config["mode"] == "train":
        return

    if args.vllm_tp_size is not None:
        config["explorer"]["rollout_model"]["tensor_parallel_size"] = args.vllm_tp_size
    tensor_parallel_size = config["explorer"]["rollout_model"]["tensor_parallel_size"]

    if args.vllm_engine_num is not None:
        config["explorer"]["rollout_model"]["engine_num"] = args.vllm_engine_num
    else:  # auto set engine_num: pick the trainer/explorer GPU split whose
        # explorer-to-trainer ratio is closest to args.explorer_trainer_ratio
        opt_explorer_num, opt_ratio_diff = None, float("inf")
        total_gpu_num = args.node_num * args.gpu_per_node

        def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
            if batch_size % trainer_gpu_num != 0:
                return opt_explorer_num, opt_ratio_diff  # batch must shard evenly
            explorer_gpu_num = total_gpu_num - trainer_gpu_num
            if explorer_gpu_num % tensor_parallel_size != 0:
                return opt_explorer_num, opt_ratio_diff  # engines must be whole
            explorer_num = explorer_gpu_num // tensor_parallel_size
            ratio = explorer_num / trainer_gpu_num
            if opt_ratio_diff > abs(ratio - args.explorer_trainer_ratio):
                return explorer_num, abs(ratio - args.explorer_trainer_ratio)
            return opt_explorer_num, opt_ratio_diff

        if args.node_num == 1:  # single node: any trainer GPU count is a candidate
            for trainer_gpu_num in range(1, args.gpu_per_node):
                opt_explorer_num, opt_ratio_diff = update_opt_explorer_num(
                    trainer_gpu_num, opt_explorer_num, opt_ratio_diff
                )
        else:  # multi node: split at whole-node granularity
            assert (
                args.gpu_per_node % tensor_parallel_size == 0
            ), "Please adjust the value of `tensor_parallel_size` so that it is a divisor of `gpu_per_node`."
            for trainer_node_num in range(1, args.node_num):
                trainer_gpu_num = args.gpu_per_node * trainer_node_num
                opt_explorer_num, opt_ratio_diff = update_opt_explorer_num(
                    trainer_gpu_num, opt_explorer_num, opt_ratio_diff
                )
        assert (
            opt_explorer_num is not None
        ), "Cannot find a suitable explorer number. Please check the value of `train_batch_size`."
        config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num


def prepare_configs(args, rank, current_time):
    base_path = os.path.dirname(os.path.abspath(__file__))

    current_time_str = time.strftime("%Y%m%d-%H%M%S", time.localtime(current_time))
    run_path = os.path.join(base_path, "runs", current_time_str)
    config_path = os.path.join(run_path, "config.yaml")
    if rank == 0:
        os.makedirs(run_path)

    with open(os.path.join(base_path, "config", f"{args.dataset}-template.yaml")) as f:
        config = yaml.safe_load(f)

    config["name"] += f"-{current_time_str}"
    config["checkpoint_root_dir"] = os.path.join(run_path, "checkpoints")
    set_engine_num(config, args)
    # Resolution order: CLI flag > environment variable > template > default.
    config["model"]["model_path"] = (
        args.model_path
        or os.environ.get("MODEL_PATH")
        or config["model"]["model_path"]
        or "Qwen/Qwen2.5-1.5B-Instruct"
    )
    if ALGORITHM_TYPE.get(config["algorithm"]["algorithm_type"]).use_critic:
        config["model"]["critic_model_path"] = (
            args.critic_model_path
            or config["model"].get("critic_model_path")
            or config["model"]["model_path"]
        )
        if args.critic_lr:
            config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
    config["buffer"]["explorer_input"]["taskset"]["path"] = (
        args.taskset_path
        or os.environ.get("TASKSET_PATH")
        or config["buffer"]["explorer_input"]["taskset"]["path"]
    )
    assert (
        config["buffer"]["explorer_input"]["taskset"]["path"] is not None
    ), "Please specify taskset path."
    if args.lr:
        config["trainer"]["trainer_config"]["actor_rollout_ref"]["actor"]["optim"][
            "lr"
        ] = args.lr
    if args.sync_interval:
        config["synchronizer"]["sync_interval"] = args.sync_interval

    with open(config_path, "w") as f:
        yaml.dump(config, f, allow_unicode=True, sort_keys=False)
    return config_path


def setup_dlc():
    # Broadcast rank 0's timestamp so every node resolves the same run directory.
    envs = get_dlc_env_vars()
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=envs["WORLD_SIZE"],
        rank=envs["RANK"],
    )
    if envs["RANK"] == 0:
        current_time = time.time()
        time_tensor = torch.tensor([current_time], device="cpu")
    else:
        time_tensor = torch.tensor([0.0], device="cpu")
    dist.broadcast(time_tensor, src=0)
    return envs["RANK"], time_tensor.item()


def main(args):
    if args.dlc:
        rank, current_time = setup_dlc()
    else:
        rank, current_time = 0, time.time()
    config_path = prepare_configs(args, rank, current_time)
    cmd_list = [
        "python",
        "-m",
        "trinity.cli.launcher",
        "run",
        "--config",
        config_path,
    ]
    if args.dlc:
        # Wait until every rank has finished preparing configs before launching.
        dist.barrier()
        dist.destroy_process_group()
        cmd_list.append("--dlc")
    subprocess.run(cmd_list, check=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"])
    parser.add_argument(
        "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
    )
    parser.add_argument("--node_num", type=int, default=1, help="Specify the number of nodes.")
    parser.add_argument(
        "--gpu_per_node", type=int, default=8, help="Specify the number of GPUs per node."
    )
    parser.add_argument(
        "--vllm_engine_num", type=int, default=None, help="Specify the number of vLLM engines."
    )
    parser.add_argument(
        "--vllm_tp_size", type=int, default=None, help="Specify the vLLM tensor parallel size."
    )
    parser.add_argument(
        "--explorer_trainer_ratio",
        type=float,
        default=0.6,
        help="Specify the ratio of explorer engine num to trainer gpu num.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Specify the path to the model checkpoint.",
    )
    parser.add_argument(
        "--critic_model_path",
        type=str,
        default=None,
        help="Specify the path to the critic model checkpoint.",
    )
    parser.add_argument(
        "--taskset_path", type=str, default=None, help="Specify the path to the taskset."
    )
    parser.add_argument(
        "--lr", type=float, default=None, help="Specify the learning rate for the actor model."
    )
    parser.add_argument(
        "--critic_lr", type=float, default=None, help="Specify the learning rate for the critic model."
    )
    parser.add_argument(
        "--sync_interval", type=int, default=None, help="Specify the sync interval."
    )
    args = parser.parse_args()
    main(args)
```
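
As a quick sanity check of the auto-selection, `set_engine_num` can be exercised on a minimal config dict. A sketch under stated assumptions: `bench.py` must be importable (so its `trinity` and `torch` imports must resolve), and every value below is hypothetical:

```python
from argparse import Namespace

from bench import set_engine_num  # assumes bench.py is on the import path

# Only the keys set_engine_num reads; all values are hypothetical.
config = {
    "mode": "both",  # any mode other than "train" triggers engine selection
    "cluster": {},
    "buffer": {"batch_size": 96},
    "explorer": {"rollout_model": {"tensor_parallel_size": 1}},
}
args = Namespace(
    node_num=1,
    gpu_per_node=8,
    vllm_tp_size=None,
    vllm_engine_num=None,
    explorer_trainer_ratio=0.6,
)
set_engine_num(config, args)
# With batch_size 96, trainer GPU counts of 1, 2, 3, 4, and 6 divide evenly;
# 6 trainer GPUs + 2 engines (ratio 2/6 ≈ 0.33) is the closest to 0.6, so this prints 2.
print(config["explorer"]["rollout_model"]["engine_num"])
```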