diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000000..675653003e --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,89 @@ +# Welcome to the Trinity Benchmark Runner 🌟 + +This tool makes it easy to run benchmarks for **Trinity-RFT**. Whether you're testing training performance or inference speed, this CLI lets you configure and launch experiments quickly, with no complex setup required. Just pick your dataset, hardware, and model settings, and let the tool handle the rest. + +--- + +## 🚀 What You Can Do + +- **Single or Multi-Machine Training**: Run experiments on a single machine or scale across multiple nodes. +- **Auto-Config**: The tool adjusts settings based on your cluster resources and inputs. +- **Flexible Datasets**: Works with datasets like `gsm8k` and `countdown`. +- **Custom Settings**: Tweak learning rates, sync intervals, and model configurations. +- **Cloud Ready**: Supports local runs *and* cloud environments like **Aliyun PAI DLC**. + +--- + +## 🛠️ How to Use It + +### 1. Basic Command Structure +```bash +python bench.py [options] +``` + +### 2. Example: Run a Benchmark +```bash +python bench.py gsm8k --node_num 1 --gpu_per_node 8 --model_path /your/model/path +``` + +### 3. Key Options Explained +| Option | What It Does | +|--------|--------------| +| `dataset` | Choose `gsm8k` or `countdown` | +| `--dlc` | Use when running in the Aliyun PAI DLC environment | +| `--node_num` | Number of nodes in the cluster (default: 1) | +| `--gpu_per_node` | Number of GPUs per node (default: 8) | +| `--vllm_engine_num` | Number of vLLM engines to use | +| `--vllm_tp_size` | Tensor parallel size for vLLM | +| `--explorer_trainer_ratio` | Ratio of explorer engine number to trainer GPU number (default: 0.6); used when `--vllm_engine_num` is not specified | +| `--model_path` | Path to the main model checkpoint | +| `--critic_model_path` | Path to the critic model checkpoint | +| `--taskset_path` | Path to the taskset file | +| `--lr` | Learning rate for the actor model | +| `--critic_lr` | Learning rate for the critic model | +| `--sync_interval` | Synchronization interval between the Trainer and the Explorer |
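+For example, a two-node run in Aliyun PAI DLC that pins the vLLM engine count and tensor parallel size explicitly might look like the following (the paths and values are illustrative, not tuned recommendations):
+
+```bash
+python bench.py gsm8k --dlc \
+    --node_num 2 --gpu_per_node 8 \
+    --vllm_engine_num 4 --vllm_tp_size 2 \
+    --model_path /your/model/path \
+    --sync_interval 5
+```
+
+If `--vllm_engine_num` is omitted, the tool picks an engine count automatically so that the ratio of explorer engines to trainer GPUs stays close to `--explorer_trainer_ratio`.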
+ + +--- + +## 📂 What Gets Saved + +After running a benchmark, results are stored in `runs/<timestamp>/`: +- `config.yaml`: The exact settings used for your run. +- `checkpoints/`: Model snapshots saved during training. + +--- + +## 📊 Benchmark Examples + +### 1. GSM8K +To reproduce this experiment: +```bash +python bench.py gsm8k --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct +``` +#### GSM8K Results +The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d). +![View Results](../docs/sphinx_doc/assets/gsm8k-bench.png) + +### 2. Countdown +First generate data, then run the benchmark: +```bash +# Step 1: Generate data +python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path +# Step 2: Run benchmark +python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path +``` +#### Countdown Results +The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d). +![View Results](../docs/sphinx_doc/assets/countdown-bench.png) + +*More benchmarks will be added soon!* + +--- + +## ✅ Tips for Success + +1. **Pre-Download Models**: Make sure all models and tasksets are ready at the paths you specify. +2. **Multi-Node Setup**: If using multiple nodes, ensure they can communicate and share storage. +3. **vLLM Users**: Check that your vLLM installation supports the features you need (like tensor parallelism). +4. **Aliyun Users**: Don't forget the `--dlc` flag when running in PAI DLC! diff --git a/benchmark/bench.py b/benchmark/bench.py new file mode 100644 index 0000000000..1aafd2b55c --- /dev/null +++ b/benchmark/bench.py @@ -0,0 +1,197 @@ +import argparse +import os +import subprocess +import time + +import torch +import torch.distributed as dist +import yaml + +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.utils.dlc_utils import get_dlc_env_vars + + +def set_engine_num(config, args): + # Fill in the cluster size and decide how many vLLM engines the explorer gets. + config["cluster"]["node_num"] = args.node_num + config["cluster"]["gpu_per_node"] = args.gpu_per_node + batch_size = config["buffer"]["batch_size"] + if config["mode"] == "train": + return + + if args.vllm_tp_size is not None: + config["explorer"]["rollout_model"]["tensor_parallel_size"] = args.vllm_tp_size + tensor_parallel_size = config["explorer"]["rollout_model"]["tensor_parallel_size"] + + if args.vllm_engine_num is not None: + config["explorer"]["rollout_model"]["engine_num"] = args.vllm_engine_num + else: # auto set engine_num + opt_explorer_num, opt_ratio_diff = None, float("inf") + total_gpu_num = args.node_num * args.gpu_per_node + + def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff): + if batch_size % trainer_gpu_num != 0: + return opt_explorer_num, opt_ratio_diff + explorer_gpu_num = total_gpu_num - trainer_gpu_num + if explorer_gpu_num % tensor_parallel_size != 0: + return opt_explorer_num, opt_ratio_diff + explorer_num = explorer_gpu_num // tensor_parallel_size + ratio = explorer_num / trainer_gpu_num + if opt_ratio_diff > abs(ratio - args.explorer_trainer_ratio): + return explorer_num, abs(ratio - args.explorer_trainer_ratio) + return opt_explorer_num, opt_ratio_diff + + if args.node_num == 1: # single node + for trainer_gpu_num in range(1, args.gpu_per_node): + opt_explorer_num, opt_ratio_diff = update_opt_explorer_num( + trainer_gpu_num, opt_explorer_num, opt_ratio_diff + ) + else: # multi node + assert ( + args.gpu_per_node % tensor_parallel_size == 0 + ), "Please adjust the value of `tensor_parallel_size` so that it is a divisor of `gpu_per_node`." + for trainer_node_num in range(1, args.node_num): + trainer_gpu_num = args.gpu_per_node * trainer_node_num + opt_explorer_num, opt_ratio_diff = update_opt_explorer_num( + trainer_gpu_num, opt_explorer_num, opt_ratio_diff + ) + assert ( + opt_explorer_num is not None + ), "Cannot find a suitable explorer number. Please check the value of `buffer.batch_size`."
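+        # Worked example (for illustration only, using the values shipped in the templates):
+        # with 1 node x 8 GPUs, tensor_parallel_size=1, batch_size=96, and the default
+        # explorer_trainer_ratio=0.6, the loop above tries trainer_gpu_num = 1..7, keeps only
+        # divisors of 96 (1, 2, 3, 4, 6), and among those the closest ratio to 0.6 is
+        # 2 explorer engines / 6 trainer GPUs ~= 0.33, so engine_num is set to 2 below.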
+ config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num + + +def prepare_configs(args, rank, current_time): + base_path = os.path.dirname(os.path.abspath(__file__)) + + current_time_str = time.strftime("%Y%m%d-%H%M%S", time.localtime(current_time)) + run_path = os.path.join(base_path, "runs", current_time_str) + config_path = os.path.join(run_path, "config.yaml") + if rank == 0: + os.makedirs(run_path) + + with open(os.path.join(base_path, "config", f"{args.dataset}-template.yaml")) as f: + config = yaml.safe_load(f) + + config["name"] += f"-{current_time_str}" + config["checkpoint_root_dir"] = os.path.join(run_path, "checkpoints") + set_engine_num(config, args) + config["model"]["model_path"] = ( + args.model_path + or os.environ.get("MODEL_PATH") + or config["model"]["model_path"] + or "Qwen/Qwen2.5-1.5B-Instruct" + ) + if ALGORITHM_TYPE.get(config["algorithm"]["algorithm_type"]).use_critic: + config["model"]["critic_model_path"] = ( + args.critic_model_path + or config["model"].get("critic_model_path") + or config["model"]["model_path"] + ) + if args.critic_lr: + config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr + config["buffer"]["explorer_input"]["taskset"]["path"] = ( + args.taskset_path + or os.environ.get("TASKSET_PATH") + or config["buffer"]["explorer_input"]["taskset"]["path"] + ) + assert ( + config["buffer"]["explorer_input"]["taskset"]["path"] is not None + ), "Please specify taskset path." + if args.lr: + config["trainer"]["trainer_config"]["actor_rollout_ref"]["actor"]["optim"][ + "lr" + ] = args.lr + if args.sync_interval: + config["synchronizer"]["sync_interval"] = args.sync_interval + + with open(config_path, "w") as f: + yaml.dump(config, f, allow_unicode=True, sort_keys=False) + return config_path + + +def setup_dlc(): + envs = get_dlc_env_vars() + dist.init_process_group( + backend="gloo", + init_method="env://", + world_size=envs["WORLD_SIZE"], + rank=envs["RANK"], + ) + if envs["RANK"] == 0: + current_time = time.time() + time_tensor = torch.tensor([current_time], device="cpu") + else: + time_tensor = torch.tensor([0.0], device="cpu") + dist.broadcast(time_tensor, src=0) + return envs["RANK"], time_tensor.item() + + +def main(args): + if args.dlc: + rank, current_time = setup_dlc() + else: + rank, current_time = 0, time.time() + config_path = prepare_configs(args, rank, current_time) + cmd_list = [ + "python", + "-m", + "trinity.cli.launcher", + "run", + "--config", + config_path, + ] + if args.dlc: + dist.barrier() + dist.destroy_process_group() + cmd_list.append("--dlc") + subprocess.run(cmd_list, check=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"]) + parser.add_argument( + "--dlc", action="store_true", help="Set this flag when running in Aliyun PAI DLC." + ) + parser.add_argument("--node_num", type=int, default=1, help="Specify the number of nodes.") + parser.add_argument( + "--gpu_per_node", type=int, default=8, help="Specify the number of GPUs per node." + ) + parser.add_argument( + "--vllm_engine_num", type=int, default=None, help="Specify the number of vLLM engines." + ) + parser.add_argument( + "--vllm_tp_size", type=int, default=None, help="Specify the vLLM tensor parallel size."
+ ) + parser.add_argument( + "--explorer_trainer_ratio", + type=float, + default=0.6, + help="Specify the ratio of explorer engine num to trainer gpu num.", + ) + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Specify the path to the model checkpoint.", + ) + parser.add_argument( + "--critic_model_path", + type=str, + default=None, + help="Specify the path to the critic model checkpoint.", + ) + parser.add_argument( + "--taskset_path", type=str, default=None, help="Specify the path to the taskset." + ) + parser.add_argument( + "--lr", type=float, default=None, help="Specify the learning rate for actor model." + ) + parser.add_argument( + "--critic_lr", type=float, default=None, help="Specify the learning rate for critic model." + ) + parser.add_argument( + "--sync_interval", type=int, default=None, help="Specify the sync interval." + ) + args = parser.parse_args() + main(args) diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml new file mode 100644 index 0000000000..c213e024e2 --- /dev/null +++ b/benchmark/config/countdown-template.yaml @@ -0,0 +1,187 @@ +mode: both +project: Trinity-RFT +group: countdown-bench +name: countdown-qwen2.5-1.5B +checkpoint_root_dir: placeholder +algorithm: + algorithm_type: ppo + repeat_times: 5 + advantage_fn: ppo +data_processor: {} +model: + model_path: Qwen/Qwen2.5-1.5B-Instruct + max_prompt_tokens: 256 + max_response_tokens: 1024 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 20 + explorer_input: + taskset: + name: taskset + storage_type: file + path: null + split: train + subset_name: null + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + default_reward_fn_type: countdown_reward + system_prompt: null + reply_prefix: null + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + path: '' + use_priority_queue: true + replay_buffer_kwargs: + priority_fn: linear_decay + decay: 0.1 + sft_warmup_steps: 0 + max_retry_times: 3 + max_retry_interval: 1 +explorer: + runner_num: 32 + max_timeout: 900 + max_retry_times: 2 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + use_v1: true + enforce_eager: true + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_openai_api: false + auxiliary_models: [] + eval_interval: 1000 + eval_on_startup: false + bench_on_latest_checkpoint: true +trainer: + trainer_type: verl + save_interval: 100 + enable_preview: true + actor_grad_clip: 1.0 + trainer_config: + actor_rollout_ref: + hybrid_engine: true + model: + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + use_remove_padding: true + actor: + strategy: fsdp + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 6400 + ppo_epochs: 1 + shuffle: false + ulysses_sequence_parallel_size: 1 + checkpoint: + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + optim: + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant + total_training_steps: -1 + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + fsdp_size: -1 + ref: + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + fsdp_size: 
-1 + log_prob_micro_batch_size_per_gpu: 8 + log_prob_use_dynamic_bsz: true + log_prob_max_token_len_per_gpu: 6400 + ulysses_sequence_parallel_size: 1 + custom_reward_function: + path: null + name: compute_score + algorithm: + kl_penalty: low_var_kl + kl_ctrl: + type: fixed + kl_coef: 0.001 + trainer: + balance_batch: true + total_training_steps: 1000 + resume_mode: auto + resume_from_path: '' + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: false + del_local_ckpt_after_load: false + val_before_train: false + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + critic: + strategy: fsdp + optim: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant + total_training_steps: -1 + model: + override_config: {} + external_lib: null + enable_gradient_checkpointing: true + use_remove_padding: true + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + fsdp_size: -1 + ppo_micro_batch_size_per_gpu: 8 + forward_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 12800 + forward_max_token_len_per_gpu: 12800 + ulysses_sequence_parallel_size: 1 + ppo_epochs: 1 + shuffle: false + grad_clip: 1.0 + cliprange_value: 0.5 + checkpoint: + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra +monitor: + monitor_type: wandb +synchronizer: + sync_method: nccl + sync_style: fixed + sync_interval: 10 + sync_timeout: 1200 diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml new file mode 100644 index 0000000000..5e2ee99b77 --- /dev/null +++ b/benchmark/config/gsm8k-template.yaml @@ -0,0 +1,147 @@ +mode: both +project: Trinity-RFT +group: gsm8k-bench +name: gsm8k-qwen2.5-1.5B +checkpoint_root_dir: placeholder +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: warmup + policy_loss_fn: ppo + advantage_fn: grpo + kl_penalty_fn: none + kl_loss_fn: k2 + entropy_loss_fn: default +data_processor: {} +model: + model_path: Qwen/Qwen2.5-1.5B-Instruct + max_prompt_tokens: 256 + max_response_tokens: 1024 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 20 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + default_reward_fn_type: math_reward + system_prompt: null + reply_prefix: null + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + path: '' + use_priority_queue: true + replay_buffer_kwargs: + priority_fn: linear_decay + decay: 0.1 + sft_warmup_steps: 0 + max_retry_times: 3 + max_retry_interval: 1 +explorer: + runner_per_model: 8 + max_timeout: 900 + max_retry_times: 2 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + use_v1: true + enforce_eager: false + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_openai_api: false + auxiliary_models: [] + eval_interval: 1000 + bench_on_latest_checkpoint: true +trainer: + trainer_type: verl + save_interval: 100 + enable_preview: true + actor_grad_clip: 1.0 + trainer_config: + actor_rollout_ref: + model: + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + use_remove_padding: true + 
use_fused_kernels: false + # fused_kernel_options: + # impl_backend: triton + actor: + strategy: fsdp + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 10240 + ppo_epochs: 1 + shuffle: false + ulysses_sequence_parallel_size: 1 + checkpoint: + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + optim: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant + total_training_steps: -1 + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + fsdp_size: -1 + ref: + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + fsdp_size: -1 + log_prob_micro_batch_size_per_gpu: 8 + log_prob_use_dynamic_bsz: true + log_prob_max_token_len_per_gpu: 10240 + ulysses_sequence_parallel_size: 1 + trainer: + balance_batch: true + total_training_steps: 100 + resume_mode: auto + resume_from_path: '' + default_hdfs_dir: null + remove_previous_ckpt_in_save: false + del_local_ckpt_after_load: false + val_before_train: false + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null +monitor: + monitor_type: wandb +synchronizer: + sync_method: nccl + sync_style: fixed + sync_interval: 2 + sync_timeout: 1200 diff --git a/benchmark/scripts/gen-countdown-data.py b/benchmark/scripts/gen-countdown-data.py new file mode 100644 index 0000000000..ffaf41f00e --- /dev/null +++ b/benchmark/scripts/gen-countdown-data.py @@ -0,0 +1,115 @@ +""" +Modified from https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py +Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target +""" + +import argparse +import json +import os +from random import randint, seed +from typing import List, Tuple + +from datasets import load_dataset +from tqdm import tqdm +from verl.utils.hdfs_io import copy, makedirs + + +def gen_dataset( + num_samples: int, + num_operands: int = 6, + max_target: int = 1000, + min_number: int = 1, + max_number: int = 100, + operations: List[str] = ["+", "-", "*", "/"], + seed_value: int = 42, +) -> List[Tuple]: + """Generate dataset for countdown task. + + Args: + num_samples: Number of samples to generate + num_operands: Number of numbers provided in each sample + max_target: Maximum value for target number + min_number: Minimum value for provided numbers + max_number: Maximum value for provided numbers + operations: List of allowed operations + seed_value: Random seed for reproducibility + + Returns: + List of tuples containing (target, numbers) + """ + seed(seed_value) + samples = [] + + for _ in tqdm(range(num_samples)): + # Generate random target + target = randint(1, max_target) + + # Generate random numbers + numbers = [randint(min_number, max_number) for _ in range(num_operands)] + + samples.append((target, numbers)) + + return samples + + +def make_prefix(dp): + target = dp["target"] + numbers = dp["nums"] + system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.""" + task_desc = f"""User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n""" + final_prompt = f"{system_prompt}\n{task_desc}" + return final_prompt
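+# Note (illustrative, not part of the original script): each record written below pairs the
+# rendered prompt with a JSON-encoded answer, roughly:
+#   {"question": "A conversation between User and Assistant. ... User: Using the numbers
+#                 [44, 19, 35], create an equation that equals 98. ...",
+#    "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"}
+# The benchmark template reads these fields via `prompt_key: question` and `response_key: answer`.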
+ + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/countdown") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--num_samples", type=int, default=100000) + parser.add_argument("--num_operands", type=int, default=6) + parser.add_argument("--max_target", type=int, default=1000) + parser.add_argument("--min_number", type=int, default=1) + parser.add_argument("--max_number", type=int, default=100) + parser.add_argument("--train_size", type=int, default=320000) + parser.add_argument("--test_size", type=int, default=7680) + + args = parser.parse_args() + + data_source = "countdown" + TRAIN_SIZE = args.train_size + TEST_SIZE = args.test_size + + raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train") + + assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE + train_dataset = raw_dataset.select(range(TRAIN_SIZE)) + test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE)) + + def make_map_fn(split): + def process_fn(example, idx): + question = make_prefix(example) + data = { + "question": question, + "answer": json.dumps( + { + "numbers": example["nums"], + "target": example["target"], + } + ), + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_json(os.path.join(local_dir, "train.jsonl")) + test_dataset.to_json(os.path.join(local_dir, "test.jsonl")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) diff --git a/docs/sphinx_doc/assets/countdown-bench.png b/docs/sphinx_doc/assets/countdown-bench.png new file mode 100644 index 0000000000..d8735ce835 Binary files /dev/null and b/docs/sphinx_doc/assets/countdown-bench.png differ diff --git a/docs/sphinx_doc/assets/gsm8k-bench.png b/docs/sphinx_doc/assets/gsm8k-bench.png new file mode 100644 index 0000000000..e79c36b5dc Binary files /dev/null and b/docs/sphinx_doc/assets/gsm8k-bench.png differ diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py index 8105b692ce..4cb137bf73 100644 --- a/tests/utils/eval_utils_test.py +++ b/tests/utils/eval_utils_test.py @@ -32,7 +32,7 @@ def test_extract_answer(self): ] for i, (input_str, expected_output, description) in enumerate(test_cases): - with self.subTest(f"Case {i+1}: {description}"): + with self.subTest(f"Case {i + 1}: {description}"): actual_output = extract_answer(input_str) self.assertEqual( actual_output, @@ -58,7 +58,7 @@ def test_verify_math_answer(self): ] for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases): - with self.subTest(f"Case {i+1}: {description}"): + with self.subTest(f"Case {i + 1}: {description}"): accuracy, details = verify_math_answer(response, ground_truth) is_correct = accuracy == 1.0 self.assertEqual( @@ -88,7 +88,7 @@ def test_is_equiv(self): ] for i, (str1, str2, expected_output, description) in enumerate(test_cases): - with self.subTest(f"Case {i+1}: {description}"): + with self.subTest(f"Case {i + 1}: {description}"): actual_output = is_equiv(str1, str2) self.assertEqual( actual_output, diff --git a/trinity/explorer/explorer.py
b/trinity/explorer/explorer.py index ba422bf5dd..c8bb61752d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -322,9 +322,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: self.logger.info("Waiting for all tasks to complete") await self.scheduler.wait_all() self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") - log_task = asyncio.create_task( - self._finish_steps(self.last_sync_step + 1, self.explore_step_num, self.model_version) - ) + await self._finish_steps(self.last_sync_step + 1, self.explore_step_num, self.model_version) if sync_weight: # sync weights @@ -337,9 +335,6 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}." ) - # overlay log and weight sync - await log_task - # save explore checkpoint self.cache.save_explorer( current_step=self.explore_step_num, diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 6d3eeb679e..e6b9667c8f 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -321,4 +321,5 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N .remote(config, module_ref=module_ref) ) synchronizer.add_module.remote(module_ref) + return synchronizer return ray.get_actor("synchronizer", namespace=namespace)