From 5ef4961ce99ad52c6f538e8d43a0b35670282750 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 27 Jan 2026 17:02:27 +0800 Subject: [PATCH 1/3] add thread level workflow runner --- .../source/tutorial/trinity_configs.md | 6 + .../source_zh/tutorial/trinity_configs.md | 7 + tests/explorer/scheduler_test.py | 73 ++++++++- tests/explorer/workflow_test.py | 106 ++++++++++++- trinity/common/config.py | 15 +- trinity/common/models/model.py | 23 ++- trinity/explorer/workflow_runner.py | 147 +++++++++++++++--- 7 files changed, 342 insertions(+), 35 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index e1f78a7c3b..048a52bb1b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -390,6 +390,8 @@ Controls the rollout models and workflow execution. explorer: name: explorer runner_per_model: 8 + concurrent_mode: sequential + max_repeat_times_per_runner: null max_timeout: 900 max_retry_times: 2 env_vars: {} @@ -414,6 +416,10 @@ explorer: - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. - `runner_per_model`: Number of parallel workflow runners per each rollout model. +- `concurrent_mode`: Concurrency mode for executing a set of tasks (e.g., multiple repeats of a task in GRPO algorithm). Supported options: + - `sequential`: Executes tasks sequentially (default). One task is executed at a time, waiting for its completion before executing the next. This requires the least from the workflow implementation but has the worst throughput. + - `asynchronous`: Executes tasks asynchronously. All tasks are submitted at once, and results are collected as they complete. Requires the workflow to correctly implement asynchronous call interfaces and have no shared state between workflows to avoid race conditions. Throughput is better than sequential execution, but the performance depends on the workflow implementation. + - `multi-threading`: Executes tasks using multi-threading. Multiple tasks are executed simultaneously using threads. Requires the workflow implementation to be thread-safe to avoid race conditions. Throughput is usually better than sequential execution but may be lower than asynchronous execution, depending on the workflow implementation and system resources. - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 49301d638a..321a89ba9c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -387,6 +387,8 @@ buffer: explorer: name: explorer runner_per_model: 8 + concurrent_mode: sequential + max_repeat_times_per_runner: null max_timeout: 900 max_retry_times: 2 env_vars: {} @@ -411,6 +413,11 @@ explorer: - `name`: explorer 的名称。该名称将用作 Ray actor 的名称,因此必须唯一。 - `runner_per_model`: 每个推理引擎实例所服务的 WorkflowRunner 数量。 +- `concurrent_mode`: 执行一组任务(例如 GRPO 算法中对一个任务的多次重复执行)的并发模式。支持如下选项: + - `sequential`: 顺序执行(默认)。每次执行一个任务,等待其完成后再执行下一个任务。对 workflow 的实现要求最低,但吞吐量也最差。 + - `asynchronous`: 异步执行。使用异步模式一次性提交所有任务,并在任务完成后收集结果。要求 workflow 正确地实现了异步调用接口,并且 workflow 之间没有共享状态,以避免竞态条件。吞吐量优于顺序执行,但具体性能受限于 workflow 的实现。 + - `multi-threading`: 多线程执行。使用多线程同时执行多个任务。需要确保 workflow 的实现是线程安全的,以避免竞态条件。吞吐量通常优于顺序执行,但可能低于异步执行,具体取决于工作流的实现和系统资源。 +- `max_repeat_times_per_runner`: 将本来需要重复执行 `algorithm.repeat_times` 次的任务切分为多个子任务,每个子任务的 `repeat_times` 不超过该值,仅适用于 GRPO 类算法。如果未设置,则不限制重复次数。推荐在 `concurrent_mode` 为 `sequential` 时使用此参数,以避免单个 WorkflowRunner 长时间占用资源。 - `max_timeout`: 等待 Workflow 完成的最大时间(秒)。 - `max_retry_times`: Workflow 失败或超时情况下的最大重试次数。 - `env_vars`: 为每个 WorkflowRunner 设置的环境变量。 diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 2cd5e9a08d..fab68123b7 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -2,7 +2,7 @@ import time import unittest from collections import defaultdict -from typing import List, Optional +from typing import Dict, List, Optional, Sequence import ray import torch @@ -169,6 +169,27 @@ async def run_async(self) -> List[Experience]: return exps +@WORKFLOWS.register_module("dummy_concurrent_workflow") +class DummyConcurrentWorkflow(Workflow): + can_repeat: bool = False + is_async: bool = True + + def __init__(self, *, task, model, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + + async def run_async(self) -> List[Experience]: + await asyncio.sleep(1) + + return [ + Experience( + eid=EID(run=self.run_id_base, step=0), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text="success", + ) + ] + + @ray.remote class DummyModel(InferenceModel): def __init__(self): @@ -200,6 +221,26 @@ def init_process_group( def get_api_server_url(self) -> Optional[str]: return None + async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Sequence[Experience]: + prompt_length = sum(len(msg["content"]) for msg in messages) + return [ + Experience( + tokens=torch.zeros(prompt_length + 10), + prompt_length=prompt_length, + logprobs=torch.zeros(10), + ) + ] + + async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]: + prompt_length = len(prompt) + return [ + Experience( + tokens=torch.zeros(prompt_length + 5), + prompt_length=prompt_length, + logprobs=torch.zeros(5), + ) + ] + @ray.remote class DummyAuxiliaryModel(InferenceModel): @@ -889,3 +930,33 @@ async def monitor_routine(): monitor_routine(), scheduler.get_results(batch_id=0), ) + + +class TestRunnerConcurrent(unittest.IsolatedAsyncioTestCase): + async def test_runner_concurrent_execution(self): + ray.init(ignore_reinit_error=True) + config = get_template_config() + config.explorer.runner_per_model = 2 + config.explorer.max_repeat_times_per_runner = None + config.check_and_update() + scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()]) + # 4 runner in side the scheduler + await scheduler.start() + + num_tasks = 8 + tasks = [ + Task( + workflow=DummyWorkflowWithState, # type: ignore[type-abstract] + workflow_args={"step_num": 2}, + repeat_times=4, + raw_task={}, + ) + for _ in range(num_tasks) + ] + scheduler.schedule(tasks, batch_id=0) + + statuses, exps = await scheduler.get_results(batch_id=0) + self.assertEqual(len(statuses), num_tasks) + self.assertEqual(len(exps), num_tasks * 4 * 2) + + await scheduler.stop() diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 982d75a5ec..30aa9549e1 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Test for the workflow module""" import asyncio +import time import unittest from collections import defaultdict from dataclasses import dataclass, field @@ -14,7 +15,12 @@ from torch import Tensor from tests.common.vllm_test import CHAT_TEMPLATE -from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config +from tests.tools import ( + RayUnittestBaseAsync, + get_model_path, + get_template_config, + get_unittest_dataset_config, +) from trinity.common.config import InferenceModelConfig from trinity.common.experience import EID, Experience from trinity.common.models import create_inference_models @@ -822,3 +828,101 @@ async def test_workflow_with_openai(self): def tearDown(self): ray.shutdown(_exiting_interpreter=True) + + +class ConcurrentTestWorkflow(Workflow): + is_async: bool = True + + def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.client = self.model.get_openai_async_client() + + async def run_async(self): + assert self.task.raw_task is not None + _ = await self.model.chat_async([{"role": "user", "content": self.task.raw_task["text"]}]) + await asyncio.sleep(1.0) + _ = await self.client.chat.completions.create( + model=self.client.model_path, + messages=[{"role": "user", "content": self.task.raw_task["text"]}], + ) + history_exps = self.model.extract_experience_from_history() + assert len(history_exps) == 2 + assert history_exps[0].prompt_length == history_exps[1].prompt_length + prompt_length = history_exps[0].prompt_length + assert ( + history_exps[0].tokens[:prompt_length].shape + == history_exps[1].tokens[:prompt_length].shape + ) + return history_exps + + +class TestConcurrentWorkflowRunner(RayUnittestBaseAsync): + async def test_concurrent_workflow_runner(self): + config = get_template_config() + config.mode = "explore" + config.model.model_path = get_model_path() + config.explorer.rollout_model.engine_num = 1 + config.explorer.rollout_model.enable_history = True + config.explorer.rollout_model.enable_openai_api = True + config.check_and_update() + engines, auxiliary_engines = create_inference_models(config) + await asyncio.gather(*[engine.prepare.remote() for engine in engines]) + + config.explorer.concurrent_mode = "sequential" + sequential_runner = WorkflowRunner( + config, + model=engines[0], + auxiliary_models=[], + runner_id=0, + ) + config.explorer.concurrent_mode = "asynchronous" + async_runner = WorkflowRunner( + config, + model=engines[0], + auxiliary_models=[], + runner_id=1, + ) + thread_runner = WorkflowRunner( + config, + model=engines[0], + auxiliary_models=[], + runner_id=2, + ) + await asyncio.gather( + sequential_runner.prepare(), + async_runner.prepare(), + thread_runner.prepare(), + ) + + task = Task( + workflow=ConcurrentTestWorkflow, + repeat_times=4, + raw_task={"text": "Hello, world!"}, + ) + # warmup + async_status, async_exps = await async_runner.run_task(task, repeat_times=2, run_id_base=0) + + st = time.time() + async_status, async_exps = await async_runner.run_task(task, repeat_times=4, run_id_base=0) + async_runtime = time.time() - st + st = time.time() + thread_status, thread_exps = await thread_runner.run_task( + task, repeat_times=4, run_id_base=0 + ) + thread_runtime = time.time() - st + st = time.time() + sequential_status, sequential_exps = await sequential_runner.run_task( + task, repeat_times=4, run_id_base=0 + ) + sequential_runtime = time.time() - st + + self.assertTrue(async_status.ok) + self.assertTrue(thread_status.ok) + self.assertTrue(sequential_status.ok) + + self.assertEqual(len(async_exps), 8) + self.assertEqual(len(thread_exps), 8) + self.assertEqual(len(sequential_exps), 8) + + self.assertLessEqual(async_runtime * 2, sequential_runtime) + self.assertLessEqual(thread_runtime * 2, sequential_runtime) diff --git a/trinity/common/config.py b/trinity/common/config.py index 3d60fd5371..acfee66d5a 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -676,9 +676,18 @@ class ExplorerConfig: max_timeout: int = 1800 # wait each task for 30 minutes at most max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout env_vars: dict = field(default_factory=dict) # environment variables for workflow runner - max_repeat_times_per_runner: Optional[ - int - ] = None # the number of time to repeat each task in a single workflow runner (for GRPO-like algorithms) + + # Workflow Runner Configs for tasks requiring group execution + # how to run a group of tasks in a single workflow runner + # "sequential": run tasks one by one, no requirements on workflow design, but have lower throughput + # "asynchronous": run tasks asynchronously, requires the workflow to be designed with async/await + # syntax, and no sharing of state between tasks + # "multi-threading": run tasks using multi-threading, requires the workflow to be thread-safe, + # and no sharing of state between tasks + concurrent_mode: str = "sequential" + # the number of time to repeat each task in a single workflow runner + # we recommend setting this only when using "sequential" concurrent_mode + max_repeat_times_per_runner: Optional[int] = None runner_num: Optional[int] = None # ! Deprecated diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index ed982fe9b9..21b235d2a1 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Base Model Class""" import asyncio +import copy import socket from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -229,7 +230,6 @@ def __init__( engine_type: str = "vllm", enable_lora: bool = False, enable_history: bool = False, - enable_thinking: Optional[bool] = None, ): """Initialize the ModelWrapper. @@ -254,7 +254,6 @@ def __init__( self.logger = get_logger(__name__) self.enable_lora = enable_lora self.enable_history = enable_history - self.enable_thinking = enable_thinking self.history = [] self.status = RunningStatus.RUNNING self.workflow_state: Dict = {} @@ -523,10 +522,12 @@ def chat_completions(*args, **kwargs): def record_chat_completions(*args, **kwargs): logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) - if self.enable_thinking is not None: + if self.config.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} - extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["chat_template_kwargs"][ + "enable_thinking" + ] = self.config.enable_thinking extra_body["return_token_ids"] = True response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs) self.history.extend(convert_api_output_to_experience(response)) @@ -580,10 +581,12 @@ async def chat_completions(*args, **kwargs): async def record_chat_completions(*args, **kwargs): logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) - if self.enable_thinking is not None: + if self.config.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} - extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["chat_template_kwargs"][ + "enable_thinking" + ] = self.config.enable_thinking extra_body["return_token_ids"] = True response = await ori_create( *args, extra_body=extra_body, logprobs=logprobs, **kwargs @@ -637,6 +640,14 @@ async def get_workflow_state(self) -> Dict: async with self.state_lock: return self.workflow_state.copy() + def clone_with_isolated_history(self) -> "ModelWrapper": + """Clone the current ModelWrapper with isolated history.""" + new_wrapper = copy.copy(self) + new_wrapper.openai_async_client = None + new_wrapper.openai_client = None + new_wrapper.history = [] + return new_wrapper + def convert_api_output_to_experience( output, diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index aeadba1ffb..fc1bc0cd02 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -83,6 +83,18 @@ def __init__( "begin_time": 0, "terminate_time": 0, } + self.concurrent_mode = config.explorer.concurrent_mode + if self.concurrent_mode == "sequential": + self.concurrent_run_fn = self._sequential_run + elif self.concurrent_mode == "asynchronous": + self.concurrent_run_fn = self._asynchronous_run + elif self.concurrent_mode == "multi-threading": + self.concurrent_run_fn = self._multi_threading_run + else: + self.logger.warning( + f"Unknown concurrent_mode {self.concurrent_mode}, defaulting to sequential." + ) + self.concurrent_run_fn = self._sequential_run async def prepare(self) -> None: """Prepare the runner.""" @@ -94,7 +106,7 @@ async def prepare(self) -> None: def is_alive(self): return True - def _create_workflow_instance(self, task: Task) -> None: + def _create_workflow_instance(self, task: Task) -> Workflow: if task.workflow is None: raise ValueError("Workflow is not set in the task.") if ( @@ -109,6 +121,7 @@ def _create_workflow_instance(self, task: Task) -> None: ) else: self.workflow_instance.reset(task) + return self.workflow_instance async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]: if workflow_instance.asynchronous: @@ -121,16 +134,16 @@ async def _run_task( self, task: Task, repeat_times: int, run_id_base: int ) -> Tuple[List[Experience], List[Dict]]: """Init workflow from the task and run it.""" - self._create_workflow_instance(task) - if self.workflow_instance.repeatable: - self.workflow_instance.set_repeat_times(repeat_times, run_id_base) + if task.workflow.can_repeat: + workflow_instance = self._create_workflow_instance(task) + workflow_instance.set_repeat_times(repeat_times, run_id_base) st = time.time() await self.model_wrapper.clean_workflow_state() self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}" self.runner_state["terminate_time"] = None self.runner_state["begin_time"] = st - exps = await self._run_workflow(self.workflow_instance) + exps = await self._run_workflow(workflow_instance) et = time.time() self.runner_state["terminate_time"] = et # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly @@ -138,25 +151,111 @@ async def _run_task( for metric in run_metrics: metric["time/run_execution"] = et - st else: - exps = [] - run_metrics = [] - for i in range(repeat_times): - st = time.time() - await self.model_wrapper.clean_workflow_state() - self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" - self.runner_state["terminate_time"] = None - self.runner_state["begin_time"] = st - new_exps = await self._run_workflow(self.workflow_instance) - et = time.time() - self.runner_state["terminate_time"] = et - run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/run_execution"] = et - st - run_metrics.append(run_metric) - for exp in new_exps: - exp.eid.run = run_id_base + i - exps.extend(new_exps) - if i < repeat_times - 1: - self._create_workflow_instance(task) + exps, run_metrics = await self.concurrent_run_fn(task, repeat_times, run_id_base) + return exps, run_metrics + + async def _sequential_run( + self, + task: Task, + repeat_times: int, + run_id_base: int, + ) -> Tuple[List[Experience], List[Dict]]: + exps = [] + run_metrics = [] + for i in range(repeat_times): + st = time.time() + workflow = self._create_workflow_instance(task) + await self.model_wrapper.clean_workflow_state() + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st + new_exps = await self._run_workflow(workflow) + et = time.time() + self.runner_state["terminate_time"] = et + run_metric = calculate_run_level_metrics(new_exps) + run_metric["time/run_execution"] = et - st + run_metrics.append(run_metric) + for exp in new_exps: + exp.eid.run = run_id_base + i + exps.extend(new_exps) + return exps, run_metrics + + async def _asynchronous_run( + self, + task: Task, + repeat_times: int, + run_id_base: int, + ) -> Tuple[List[Experience], List[Dict]]: + async def run_single(i: int) -> Tuple[List[Experience], Dict]: + st = time.time() + workflow = task.to_workflow( + self.model_wrapper.clone_with_isolated_history() + if self.config.explorer.rollout_model.enable_history + else self.model_wrapper, + self.auxiliary_model_wrappers, + ) + await self.model_wrapper.clean_workflow_state() + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st + new_exps = await self._run_workflow(workflow) + et = time.time() + self.runner_state["terminate_time"] = et + run_metric = calculate_run_level_metrics(new_exps) + run_metric["time/run_execution"] = et - st + for exp in new_exps: + exp.eid.run = run_id_base + i + return new_exps, run_metric + + tasks = [run_single(i) for i in range(repeat_times)] + results = await asyncio.gather(*tasks) + exps = [] + run_metrics = [] + for new_exps, run_metric in results: + exps.extend(new_exps) + run_metrics.append(run_metric) + return exps, run_metrics + + async def _multi_threading_run( + self, + task: Task, + repeat_times: int, + run_id_base: int, + ) -> Tuple[List[Experience], List[Dict]]: + loop = asyncio.get_event_loop() + + def run_single(i: int) -> Tuple[List[Experience], Dict]: + st = time.time() + asyncio.run(self.model_wrapper.clean_workflow_state()) + workflow = task.to_workflow( + self.model_wrapper.clone_with_isolated_history() + if self.config.explorer.rollout_model.enable_history + else self.model_wrapper, + self.auxiliary_model_wrappers, + ) + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st + new_exps = asyncio.run(self._run_workflow(workflow)) + et = time.time() + self.runner_state["terminate_time"] = et + run_metric = calculate_run_level_metrics(new_exps) + run_metric["time/run_execution"] = et - st + for exp in new_exps: + exp.eid.run = run_id_base + i + return new_exps, run_metric + + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=repeat_times) as executor: + futures = [loop.run_in_executor(executor, run_single, i) for i in range(repeat_times)] + results = await asyncio.gather(*futures) + + exps = [] + run_metrics = [] + for new_exps, run_metric in results: + exps.extend(new_exps) + run_metrics.append(run_metric) return exps, run_metrics async def get_runner_state(self) -> Dict: From 560365b04e863453e3163c3200392fdb899bc23e Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 28 Jan 2026 12:19:47 +0800 Subject: [PATCH 2/3] fix comments --- .../source/tutorial/trinity_configs.md | 1 + tests/explorer/scheduler_test.py | 30 ------------------- trinity/explorer/workflow_runner.py | 20 ++++++------- 3 files changed, 11 insertions(+), 40 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 048a52bb1b..3053412977 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -420,6 +420,7 @@ explorer: - `sequential`: Executes tasks sequentially (default). One task is executed at a time, waiting for its completion before executing the next. This requires the least from the workflow implementation but has the worst throughput. - `asynchronous`: Executes tasks asynchronously. All tasks are submitted at once, and results are collected as they complete. Requires the workflow to correctly implement asynchronous call interfaces and have no shared state between workflows to avoid race conditions. Throughput is better than sequential execution, but the performance depends on the workflow implementation. - `multi-threading`: Executes tasks using multi-threading. Multiple tasks are executed simultaneously using threads. Requires the workflow implementation to be thread-safe to avoid race conditions. Throughput is usually better than sequential execution but may be lower than asynchronous execution, depending on the workflow implementation and system resources. +- `max_repeat_times_per_runner`: Splits tasks that originally need to be repeated `algorithm.repeat_times` times into multiple subtasks, where each subtask's `repeat_times` does not exceed this value. This parameter is only applicable to GRPO-like algorithms. If not set, there is no limit on the number of repeats. It is recommended to use this parameter when `concurrent_mode` is `sequential` to reduce the end to end latency and improve throughput. - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index fab68123b7..6698063466 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -930,33 +930,3 @@ async def monitor_routine(): monitor_routine(), scheduler.get_results(batch_id=0), ) - - -class TestRunnerConcurrent(unittest.IsolatedAsyncioTestCase): - async def test_runner_concurrent_execution(self): - ray.init(ignore_reinit_error=True) - config = get_template_config() - config.explorer.runner_per_model = 2 - config.explorer.max_repeat_times_per_runner = None - config.check_and_update() - scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()]) - # 4 runner in side the scheduler - await scheduler.start() - - num_tasks = 8 - tasks = [ - Task( - workflow=DummyWorkflowWithState, # type: ignore[type-abstract] - workflow_args={"step_num": 2}, - repeat_times=4, - raw_task={}, - ) - for _ in range(num_tasks) - ] - scheduler.schedule(tasks, batch_id=0) - - statuses, exps = await scheduler.get_results(batch_id=0) - self.assertEqual(len(statuses), num_tasks) - self.assertEqual(len(exps), num_tasks * 4 * 2) - - await scheduler.stop() diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index fc1bc0cd02..640727aede 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -222,11 +222,9 @@ async def _multi_threading_run( repeat_times: int, run_id_base: int, ) -> Tuple[List[Experience], List[Dict]]: - loop = asyncio.get_event_loop() - - def run_single(i: int) -> Tuple[List[Experience], Dict]: + async def run_single(i: int) -> Tuple[List[Experience], Dict]: st = time.time() - asyncio.run(self.model_wrapper.clean_workflow_state()) + await self.model_wrapper.clean_workflow_state() workflow = task.to_workflow( self.model_wrapper.clone_with_isolated_history() if self.config.explorer.rollout_model.enable_history @@ -236,7 +234,7 @@ def run_single(i: int) -> Tuple[List[Experience], Dict]: self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" self.runner_state["terminate_time"] = None self.runner_state["begin_time"] = st - new_exps = asyncio.run(self._run_workflow(workflow)) + new_exps = await self._run_workflow(workflow) et = time.time() self.runner_state["terminate_time"] = et run_metric = calculate_run_level_metrics(new_exps) @@ -245,11 +243,13 @@ def run_single(i: int) -> Tuple[List[Experience], Dict]: exp.eid.run = run_id_base + i return new_exps, run_metric - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor(max_workers=repeat_times) as executor: - futures = [loop.run_in_executor(executor, run_single, i) for i in range(repeat_times)] - results = await asyncio.gather(*futures) + # Use asyncio.to_thread to run async tasks in threads + results = await asyncio.gather( + *( + asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc] + for i in range(repeat_times) + ) + ) exps = [] run_metrics = [] From 0bb6e29ad63f28fa7c26460ded4a936a3a1f93c4 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 28 Jan 2026 13:22:51 +0800 Subject: [PATCH 3/3] check runner_per_model --- trinity/common/config_validator.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py index 68bb0e39d1..34a082c4e6 100644 --- a/trinity/common/config_validator.py +++ b/trinity/common/config_validator.py @@ -579,6 +579,21 @@ def validate(self, config: Config) -> None: self._validate_lora(config) + # check concurrent mode + if config.explorer.concurrent_mode not in ["sequential", "asynchronous", "multi-threading"]: + raise ValueError(f"Invalid explorer.concurrent_mode: {config.explorer.concurrent_mode}") + if config.explorer.concurrent_mode in ["asynchronous", "multi-threading"]: + batch_size = config.buffer.batch_size + max_runner_per_model = math.ceil(batch_size / config.explorer.rollout_model.engine_num) + if config.explorer.runner_per_model > max_runner_per_model: + self.logger.warning( + f"explorer.runner_per_model ({config.explorer.runner_per_model}) is too large " + f"for concurrent_mode '{config.explorer.concurrent_mode}' with batch_size " + f"({batch_size}) and rollout_model.engine_num ({config.explorer.rollout_model.engine_num}). " + f"It is set to {max_runner_per_model}." + ) + config.explorer.runner_per_model = max_runner_per_model + def _validate_lora(self, config: Config) -> None: """Process and validate LoRA configuration settings.