From 937bd2f69ebf440d0c2f1df0dfa1965d64d68ffa Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 20 Oct 2025 18:04:18 +0800 Subject: [PATCH 1/6] fix examples and tutorial --- .../source/tutorial/develop_workflow.md | 37 +++++---------- .../source_zh/tutorial/develop_workflow.md | 35 +++++--------- tests/explorer/scheduler_test.py | 27 +++-------- tests/explorer/workflow_test.py | 33 ++++--------- .../agentscope/react/react_workflow.py | 12 +---- .../common/workflows/agentscope_workflow.py | 17 +------ .../workflows/customized_math_workflows.py | 4 +- .../customized_toolcall_workflows.py | 6 +-- .../agentscope/agentscopev0_react_workflow.py | 10 +--- .../agentscope/agentscopev1_react_workflow.py | 16 ++----- .../agentscopev1_search_workflow.py | 16 ++----- .../envs/alfworld/RAFT_alfworld_workflow.py | 16 ++----- .../RAFT_reflect_alfworld_workflow.py | 4 -- .../envs/alfworld/alfworld_workflow.py | 6 +-- .../workflows/envs/email_searcher/workflow.py | 15 ++---- .../envs/sciworld/sciworld_workflow.py | 6 +-- .../envs/webshop/webshop_workflow.py | 7 ++- trinity/common/workflows/eval_workflow.py | 12 +---- trinity/common/workflows/math_rm_workflow.py | 4 +- .../common/workflows/math_ruler_workflow.py | 4 +- .../common/workflows/simple_mm_workflow.py | 4 +- .../common/workflows/step_wise_workflow.py | 16 +------ trinity/common/workflows/workflow.py | 47 +++++++++---------- trinity/explorer/workflow_runner.py | 3 ++ 24 files changed, 101 insertions(+), 256 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md index b40c87801c..f6ded5eaad 100644 --- a/docs/sphinx_doc/source/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md @@ -205,20 +205,20 @@ __all__ = [ ##### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. -In this case, you can implement the `resettable` property and `reset` method to avoid re-initialization. +In this case, you can set the `can_reset` property and implement `reset` method to avoid re-initialization. + +The `can_reset` is a class property that indicates whether the workflow supports resetting. -The `resettable` property returns a boolean indicating whether the workflow supports resetting. The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task. ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + # some code # ... - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") @@ -227,20 +227,18 @@ class ExampleWorkflow(Workflow): ##### Support Batch Inference In many popular RL algorithms, multiple runs of the same task are required (e.g., GRPO). In such scenarios, you can directly use batch inference to obtain multiple responses for a single question to improve efficiency. -For this case, you can implement the `repeatable` property and `set_repeat_times` method. +For this case, you can implement the `can_repeat` property and `set_repeat_times` method. + +The `can_repeat` is a class property that indicates whether the workflow supports multiple executions within the `run` method. -The `repeatable` property returns a boolean indicating whether the workflow supports multiple executions within the `run` method. The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored). ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_repeat: bool = True # some code - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -279,6 +277,8 @@ class ExampleWorkflow(Workflow): ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) @@ -319,18 +319,10 @@ class ExampleWorkflow(Workflow): ) return experiences - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -364,15 +356,12 @@ trinity run --config #### Async Support -The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can implement the `asynchronous` property and set it to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected. +The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected. ```python @WORKFLOWS.register_module("example_workflow_async") class ExampleWorkflowAsync(Workflow): - - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # your async code here diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md index 3cb85787e4..f178491643 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md @@ -200,22 +200,20 @@ __all__ = [ ##### 避免重复初始化 对于较为复杂的工作流,每次重新初始化会带来额外计算开销。 -此时,你可以实现 `resettable` 和 `reset` 方法以避免重复初始化。 +此时,你可以设置 `can_reset` 属性并实现 `reset` 方法以避免重复初始化。 -`resettable` 方法返回一个布尔值,指示工作流是否支持轻量化重置。 +`can_reset` 是一个类属性,表示工作流是否支持轻量化重置。 `reset` 方法接受一个新的 `Task` 实例,并使用该实例更新工作流的状态。 ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + # some code # ... - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") @@ -224,21 +222,18 @@ class ExampleWorkflow(Workflow): ##### 批量运行推理任务 当前流行的很多 RL 算法需要多次运行同一个任务(例如 GRPO)。该场景下一些简单任务可以直接通过模型批量推理来获得一个问题的多个回复以提升效率。 -针对该情况,你可以实现 `repeatable` 属性以及 `set_repeat_times` 方法。 +针对该情况,你可以设置 `can_repeat` 属性并实现 `set_repeat_times` 方法。 -`repeatable` 属性返回一个布尔值,指示工作流是否支持在 `run` 方法内多次执行。 +`can_repeat` 是一个类属性,指示工作流是否支持在 `run` 方法内多次执行。 `set_repeat_times` 方法接受两个参数:`repeat_times` 指定了在 `run` 方法内需要执行的次数,`run_id_base` 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。 ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_repeat: bool = True # some code - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -277,6 +272,8 @@ class ExampleWorkflow(Workflow): ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) @@ -317,18 +314,10 @@ class ExampleWorkflow(Workflow): ) return experiences - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -362,15 +351,13 @@ trinity run --config #### async 支持 -本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以实现 `asynchronous` 属性并将其设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。 +本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。 ```python @WORKFLOWS.register_module("example_workflow_async") class ExampleWorkflowAsync(Workflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # your async code here diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 36c73063a9..446b191e6b 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -18,6 +18,8 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): + can_repeat: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.step_num = task.workflow_args.get("step_num", 1) @@ -30,10 +32,6 @@ def __init__(self, *, task, model, auxiliary_models): else: self.seconds = 10 - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -63,19 +61,13 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("dummy_nonrepeat_workflow") class DummyNonRepeatWorkflow(Workflow): + can_reset: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.reset_flag = False self.step_num = task.workflow_args.get("step_num", 1) - @property - def resettable(self): - return True - - @property - def repeatable(self): - return False - def reset(self, task: Task): self.task = task self.reset_flag = True @@ -95,18 +87,13 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("dummy_async_workflow") class DummyAsyncWorkflow(Workflow): + can_repeat: bool = True + is_async: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.step_num = task.workflow_args.get("step_num", 1) - @property - def asynchronous(self): - return True - - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index e5dab0004b..8623500258 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -40,20 +40,15 @@ class MockResponse: class DummyWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - @property - def resettable(self): - return True - - @property - def repeatable(self): - return True - def reset(self, task: Task): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] @@ -76,24 +71,16 @@ def run(self): class DummyAsyncWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True + is_async: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - @property - def resettable(self): - return True - - @property - def repeatable(self): - return True - - @property - def asynchronous(self): - return True - def reset(self, task: Task): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] @@ -133,14 +120,12 @@ def run(self): class DummyAsyncMultiTurnWorkflow(MultiTurnWorkflow): + is_async: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.contents = task.raw_task["contents"] # type: ignore - @property - def asynchronous(self): - return True - async def run_async(self): memory = [{"role": "system", "content": "You are a helpful assistant."}] experience_list = [] diff --git a/trinity/common/workflows/agentscope/react/react_workflow.py b/trinity/common/workflows/agentscope/react/react_workflow.py index 247ba085d9..044c7935f1 100644 --- a/trinity/common/workflows/agentscope/react/react_workflow.py +++ b/trinity/common/workflows/agentscope/react/react_workflow.py @@ -16,6 +16,8 @@ @WORKFLOWS.register_module("as_react_workflow") class AgentScopeReActWorkflow(Workflow): + is_async: bool = True + def __init__( self, *, @@ -95,13 +97,3 @@ def construct_experiences(self, reward: Union[float, Dict[str, float]]) -> List[ if isinstance(reward, dict): exp.metrics.update(reward) return exps - - @property - def asynchronous(self): - """AgentScope's ReAct agent only supports asynchronous calls, so we set this to True.""" - return True - - @property - def repeatable(self): - """This workflow is not repeatable.""" - return False diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 0cbce6e01a..359afa5b65 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -11,6 +11,8 @@ class AgentScopeWorkflowAdapter(Workflow): """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow.""" + is_async: bool = True + def __init__( self, *, @@ -45,21 +47,6 @@ def __init__( model.get_openai_async_client(), ) - @property - def asynchronous(self) -> bool: - """This workflow runs asynchronously.""" - return True - - @property - def repeatable(self) -> bool: - """This workflow is not repeatable.""" - return False - - @property - def resetable(self) -> bool: - """This workflow cannot be reset.""" - return False - def construct_experiences( self, reward: float, diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index f76983a324..65bed1ac0d 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -95,9 +95,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_boxed_workflow") class AsyncMathBoxedWorkflow(MathBoxedWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: if not self.use_base: diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index 1e31e492dd..f723c0e347 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -215,6 +215,8 @@ class ToolCallWorkflow(SimpleWorkflow): Only support qwen model for now. You can change the prompt construction and reward calculation by yourself for other models. """ + is_async: bool = True + def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt @@ -227,10 +229,6 @@ def reset(self, task: Task): self.workflow_args = task.workflow_args self.reward_fn_args = task.reward_fn_args - @property - def asynchronous(self): - return True - def format_prompt(self): raw_task = self.raw_task messages = construct_prompt(raw_task) diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index e2c1fe8f7b..5c8dbe82f8 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -19,6 +19,8 @@ class AgentScopeV0ReactMathWorkflow(Workflow): We use the AgentScope V0 version here. The code will be deprecated soon. """ + can_reset: bool = True + def __init__( self, *, @@ -72,10 +74,6 @@ def __init__( ) self.reset(task) - @property - def resettable(self): - return True - def reset(self, task: Task): self.system_prompt = """ You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}. @@ -115,10 +113,6 @@ def reset(self, task: Task): # we use the boxed format to evaluate the answer self.reward_fn = MathBoxedRewardFn() - @property - def repeatable(self): - return False - def run(self): # make sure that we have the correct import try: diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py index a20c88063c..079b8a940d 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py @@ -17,6 +17,9 @@ class AgentScopeReactMathWorkflow(Workflow): We use the AgentScope V1 version here. """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -57,10 +60,6 @@ def __init__( self.agent_model_formatter = OpenAIChatFormatter() self.reset(task) - @property - def resettable(self): - return True - def reset(self, task: Task): self.system_prompt = """ You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}. @@ -104,15 +103,6 @@ def reset(self, task: Task): # we use the boxed format to evaluate the answer self.reward_fn = MathBoxedRewardFn() - @property - def repeatable(self): - return False - - @property - def asynchronous(self): - """Whether the workflow runs in async mode.""" - return True - async def run_async(self): # make sure that we have the correct import try: diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py index b32eeefe9a..57f45d9af3 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -17,6 +17,9 @@ class AgentScopeV1ReactSearchWorkflow(Workflow): This workflow serves as an example of how to use the agentscope framework within the trinity workflow. """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -58,19 +61,6 @@ def __init__( self.reset(task) - @property - def resettable(self): - return True - - @property - def asynchronous(self): - """Whether the workflow runs in async mode.""" - return True - - @property - def repeatable(self): - return False - def reset(self, task: Task): self.workflow_args = task.workflow_args self.max_turns = int(self.workflow_args.get("max_turns", 10)) diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index 0b2e6fc93e..e9e3f30206 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -26,6 +26,10 @@ class RAFTAlfworldWorkflow(Workflow): 2. Generate SFT data from successful attempt """ + can_reset: bool = True + can_repeat: bool = True + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -55,10 +59,6 @@ def __init__( ) self.reset(task) - @property - def asynchronous(self): - return True - def reset(self, task: Task): """Reset the workflow with a new task""" self.game_file_path = task.task_desc or task.raw_task.get("game_file", "") @@ -220,14 +220,6 @@ async def run_async(self) -> List[Experience]: ) return [experience] - def resettable(self) -> bool: - """Indicate that this workflow can be reset to avoid re-initialization""" - return True - - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index bf8ff85b09..81aed995d9 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -59,10 +59,6 @@ def __init__( f"Initializing RAFTReflectAlfworldWorkflow with RAFT learning, temperature={self.temperature}" ) - @property - def asynchronous(self): - return True - async def construct_sft_data( self, first_trajectory: List[Dict[str, str]], diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 36370b55f1..f722e82ce6 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -97,6 +97,8 @@ def parse_action(response): class AlfworldWorkflow(MultiTurnWorkflow): """A workflow for alfworld task.""" + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -111,10 +113,6 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = task.workflow_args.get("max_env_steps", 30) - @property - def asynchronous(self): - return True - async def get_model_response(self, messages): responses = await self.model.chat_async(messages, n=1) return responses diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 2fe97f13f6..05702bd31a 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -27,6 +27,9 @@ class EmailSearchWorkflow(Workflow): Multi-turn Email Search workflow (ReAct-style tool use). """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -45,18 +48,6 @@ def __init__( self.reset(task) - @property - def repeatable(self) -> bool: - return False - - @property - def resettable(self): - return True - - @property - def asynchronous(self): - return True - def reset(self, task: Task): self.query = QueryModel.model_validate(task.raw_task) self.task_desc = task.task_desc # question diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index da313e72ca..4c2a7417a9 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -59,6 +59,8 @@ def parse_action(response): class SciWorldWorkflow(MultiTurnWorkflow): """A workflow for sciworld task.""" + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -73,10 +75,6 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = 30 # should be less than 100 - @property - def asynchronous(self): - return True - async def get_model_response(self, messages): responses = await self.model.chat_async(messages, n=1) return responses diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 5b18ee48ae..4cef8fc456 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -181,6 +181,9 @@ def validate_action(action, available_actions): class WebShopWorkflow(MultiTurnWorkflow): """A workflow for webshop task.""" + can_reset: bool = True + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -209,10 +212,6 @@ def __init__( "WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True ) - @property - def resettable(self): - return True - def reset(self, task: Task): self.task_desc = task.task_desc or "0" self.repeat_times = task.repeat_times diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 03348d6bc2..55ea78e74e 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -42,14 +42,6 @@ def __init__( # TODO: customize the config in the yaml self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1)) - @property - def resettable(self): - return False - - @property - def repeatable(self): - return False - def format_messages(self): """Format message for the evaluation of qwen_boxed type.""" if not self.raw_task or "question" not in self.raw_task: @@ -89,9 +81,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_eval_workflow") class AsyncMathEvalWorkflow(MathEvalWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: messages = self.format_messages() diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index f498acf2ed..aca5586717 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -55,9 +55,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_rm_workflow") class AsyncMathRMWorkflow(MathRMWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: messages = self.format_messages() diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py index 9eb1462652..fec1bc72a4 100644 --- a/trinity/common/workflows/math_ruler_workflow.py +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -160,9 +160,7 @@ def get_ruler_scores( @WORKFLOWS.register_module("async_math_ruler_workflow") class AsyncMathRULERWorkflow(MathRULERWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: """Modified from SimpleWorkflow.run""" diff --git a/trinity/common/workflows/simple_mm_workflow.py b/trinity/common/workflows/simple_mm_workflow.py index c24fffe19c..2bb3f857e2 100644 --- a/trinity/common/workflows/simple_mm_workflow.py +++ b/trinity/common/workflows/simple_mm_workflow.py @@ -82,9 +82,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_simple_mm_workflow") class AsyncSimpleMMWorkflow(SimpleMMWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # TODO: test generate_mm diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py index 608eea3bf2..1042866063 100644 --- a/trinity/common/workflows/step_wise_workflow.py +++ b/trinity/common/workflows/step_wise_workflow.py @@ -67,17 +67,11 @@ def max_step_num(self): """Return the maximum number of steps in the task.""" raise NotImplementedError - @property - def repeatable(self): - return False - class AsyncStepWiseRewardWorkflow(StepWiseRewardWorkflow): """Async version of `StepWiseRewardWorkflow`.""" - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" @@ -184,17 +178,11 @@ def max_step_num(self): """Return the maximum number of steps in the task.""" raise NotImplementedError - @property - def repeatable(self): - return False - class AsyncRewardPropagationWorkflow(RewardPropagationWorkflow): """Async version of `RewardPropagationWorkflow`.""" - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index c5694f595a..26a555d7cc 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -80,6 +80,10 @@ class Workflow: A workflow is a runnable object which generates a list of experiences. """ + can_reset: bool = False # whether the workflow can be reset with a new task. If true, `reset()` must be implemented. + can_repeat: bool = False # whether the workflow can be repeated multiple times. If true, `set_repeat_times()` must be implemented. + is_async: bool = False # whether the workflow runs in async mode. If true, `run_async()` must be implemented, else `run()` must be implemented. + def __init__( self, *, @@ -95,21 +99,21 @@ def __init__( @property def resettable(self): - return False + """Deprecated, use cls.can_reset instead.""" + return self.__class__.can_reset @property def repeatable(self): - """A workflow is repeatable if it can be run multiple times within the run() or run_async() method.""" - return False + """Deprecated, use cls.can_repeat instead. + A workflow is repeatable if it can be run multiple times within the run() or run_async() method. + """ + return self.__class__.can_repeat @property def asynchronous(self): - """Whether the workflow runs in async mode.""" - return False - - @property - def rollout_args(self): - return asdict(self.task.rollout_args) + """Deprecated, use cls.is_async instead. + Whether the workflow runs in async mode.""" + return self.__class__.is_async def reset(self, task: Task): """Reset the workflow.""" @@ -140,6 +144,8 @@ class MultiTurnWorkflow(Workflow): The base workflow class for concatenated multi-turn tasks. """ + can_repeat: bool = True + def __init__( self, *, @@ -153,10 +159,6 @@ def __init__( auxiliary_models=auxiliary_models, ) - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -190,6 +192,9 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc class SimpleWorkflow(Workflow): """A workflow for simple single-round task.""" + can_reset: bool = True + can_repeat: bool = True + def __init__( self, *, @@ -204,10 +209,6 @@ def __init__( auxiliary_models=auxiliary_models, ) - @property - def resettable(self): - return True - def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt @@ -224,15 +225,15 @@ def reset(self, task: Task): else: raise ValueError("`reward_fn` must be a subclass of `RewardFn`") - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.task.rollout_args.n = repeat_times self.run_id_base = run_id_base + @property + def rollout_args(self): + return asdict(self.task.rollout_args) + def format_messages(self): """Format messages for the instruct model.""" messages = [] @@ -270,9 +271,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_simple_workflow") class AsyncSimpleWorkflow(Workflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # TODO: Optimize the generate function diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 7a473137f0..5acc634390 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -51,6 +51,7 @@ def __init__( for model in (auxiliary_models or []) ] self.auxiliary_model_clients = [] + self.auxiliary_model_async_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id @@ -62,7 +63,9 @@ async def prepare(self) -> None: ) for model in self.auxiliary_models: api_client = model.get_openai_client() + async_api_client = model.get_openai_async_client() self.auxiliary_model_clients.append(api_client) + self.auxiliary_model_async_clients.append(async_api_client) def is_alive(self): return True From df89df403da9481806d2725b107df49c9033b6e9 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 20 Oct 2025 20:13:25 +0800 Subject: [PATCH 2/6] add tests --- tests/explorer/workflow_test.py | 105 +++++++++++++++++++++++++--- trinity/explorer/workflow_runner.py | 7 +- 2 files changed, 101 insertions(+), 11 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 8623500258..3105292389 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -2,16 +2,19 @@ """Test for the workflow module""" import asyncio import unittest +from unittest import mock from dataclasses import dataclass, field from typing import Dict, Optional from unittest.mock import MagicMock +import openai import ray from parameterized import parameterized, parameterized_class from torch import Tensor from tests.common.vllm_test import CHAT_TEMPLATE from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config +from trinity.explorer.workflow_runner import WorkflowRunner from trinity.common.experience import EID, Experience from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper @@ -48,6 +51,9 @@ def __init__(self, model, task: Task, auxiliary_models=None): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n + if auxiliary_models is not None: + for model in auxiliary_models: + assert isinstance(model, openai.OpenAI) def reset(self, task: Task): self.obj = task.raw_task @@ -58,14 +64,21 @@ def set_repeat_times(self, repeat_times, run_id_base): self.run_id_base = run_id_base def run(self): + exps = [] if self.output_format == "json": import json - - return [json.dumps(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = json.dumps(self.obj) + exps.append(exp) + return exps elif self.output_format == "yaml": import yaml - - return [yaml.safe_dump(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = yaml.safe_dump(self.obj) + exps.append(exp) + return exps else: raise ValueError("Invalid output format") @@ -80,6 +93,9 @@ def __init__(self, model, task: Task, auxiliary_models=None): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n + if auxiliary_models is not None: + for model in auxiliary_models: + assert isinstance(model, openai.AsyncOpenAI) def reset(self, task: Task): self.obj = task.raw_task @@ -91,14 +107,21 @@ def set_repeat_times(self, repeat_times, run_id_base): async def run_async(self): await asyncio.sleep(0.1) + exps = [] if self.output_format == "json": import json - - return [json.dumps(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = json.dumps(self.obj) + exps.append(exp) + return exps elif self.output_format == "yaml": import yaml - - return [yaml.safe_dump(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = yaml.safe_dump(self.obj) + exps.append(exp) + return exps else: raise ValueError("Invalid output format") @@ -413,13 +436,13 @@ def test_workflow_resettable(self, workflow_cls) -> None: answer = asyncio.run(workflow.run_async()) else: answer = workflow.run() - self.assertEqual(answer[0], '{"a": 1}') + self.assertEqual(answer[0].response_text, '{"a": 1}') workflow.reset(yaml_task) if workflow.asynchronous: answer = asyncio.run(workflow.run_async()) else: answer = workflow.run() - self.assertEqual(answer[0], "a: 1\n") + self.assertEqual(answer[0].response_text, "a: 1\n") @parameterized.expand([(DummyWorkflow,), (DummyAsyncWorkflow,)]) def test_workflow_repeatable(self, workflow_cls) -> None: @@ -517,3 +540,65 @@ async def as_workflow_func(task, model) -> float: self.assertEqual(result[0].prompt_length, 1) self.assertEqual(result[1].reward, 0.1) self.assertEqual(result[1].prompt_length, 2) + + +class DummyModelWrapper: + + def __init__(self, model, engine_type="vllm", **kwargs): + pass + + async def prepare(self): + return None + + def get_openai_client(self): + return openai.OpenAI(api_key="EMPTY") + + def get_openai_async_client(self): + return openai.AsyncOpenAI(api_key="EMPTY") + + @property + async def model_version_async(self): + return 0 + +class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): + + async def test_workflow_runner(self): + + config = get_template_config() + + with mock.patch( + "trinity.explorer.workflow_runner.ModelWrapper", + DummyModelWrapper, + ): + runner = WorkflowRunner( + config, + model=MagicMock(), + auxiliary_models=[MagicMock(), MagicMock()], + runner_id=0, + ) + await runner.prepare() + + task = Task( + workflow=DummyWorkflow, + repeat_times=3, + raw_task={"a": 1}, + workflow_args={"output_format": "json"}, + ) + + status, exps = await runner.run_task(task, repeat_times=3, run_id_base=0) + + self.assertTrue(status.ok) + self.assertIsInstance(exps, list) + self.assertEqual(len(exps), 3) + + task = Task( + workflow=DummyAsyncWorkflow, + repeat_times=2, + raw_task={"a": 1}, + workflow_args={"output_format": "yaml"}, + ) + + status, exps = await runner.run_task(task, repeat_times=2, run_id_base=0) + self.assertTrue(status.ok) + self.assertIsInstance(exps, list) + self.assertEqual(len(exps), 2) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 5acc634390..187d0d5adf 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -79,7 +79,12 @@ def _create_workflow_instance(self, task: Task) -> None: or not self.workflow_instance.resettable ): self.workflow_instance = task.to_workflow( - self.model_wrapper, self.auxiliary_model_clients + self.model_wrapper, + ( + self.auxiliary_model_async_clients + if task.workflow.is_async + else self.auxiliary_model_clients + ), ) else: self.workflow_instance.reset(task) From f9e4892c2e8956d6adbe57ce2403c9d19fc3fae1 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 20 Oct 2025 20:13:43 +0800 Subject: [PATCH 3/6] fix pre-commit --- tests/explorer/workflow_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 3105292389..81e1afa916 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -2,9 +2,9 @@ """Test for the workflow module""" import asyncio import unittest -from unittest import mock from dataclasses import dataclass, field from typing import Dict, Optional +from unittest import mock from unittest.mock import MagicMock import openai @@ -14,7 +14,6 @@ from tests.common.vllm_test import CHAT_TEMPLATE from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config -from trinity.explorer.workflow_runner import WorkflowRunner from trinity.common.experience import EID, Experience from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper @@ -28,6 +27,7 @@ Workflow, ) from trinity.common.workflows.workflow import MultiTurnWorkflow, Task +from trinity.explorer.workflow_runner import WorkflowRunner @dataclass @@ -67,6 +67,7 @@ def run(self): exps = [] if self.output_format == "json": import json + for i in range(self.repeat_times): exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) exp.response_text = json.dumps(self.obj) @@ -74,6 +75,7 @@ def run(self): return exps elif self.output_format == "yaml": import yaml + for i in range(self.repeat_times): exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) exp.response_text = yaml.safe_dump(self.obj) @@ -110,6 +112,7 @@ async def run_async(self): exps = [] if self.output_format == "json": import json + for i in range(self.repeat_times): exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) exp.response_text = json.dumps(self.obj) @@ -117,6 +120,7 @@ async def run_async(self): return exps elif self.output_format == "yaml": import yaml + for i in range(self.repeat_times): exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) exp.response_text = yaml.safe_dump(self.obj) @@ -543,7 +547,6 @@ async def as_workflow_func(task, model) -> float: class DummyModelWrapper: - def __init__(self, model, engine_type="vllm", **kwargs): pass @@ -560,10 +563,9 @@ def get_openai_async_client(self): async def model_version_async(self): return 0 -class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): +class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): async def test_workflow_runner(self): - config = get_template_config() with mock.patch( From f421336ab20ee970717825c99c005942aed3582c Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 20 Oct 2025 20:27:31 +0800 Subject: [PATCH 4/6] fix rollout args --- .../envs/agentscope/agentscopev0_react_workflow.py | 7 ++----- .../envs/agentscope/agentscopev1_react_workflow.py | 6 ++---- .../envs/agentscope/agentscopev1_search_workflow.py | 6 ++---- trinity/common/workflows/envs/email_searcher/workflow.py | 4 ++-- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index 5c8dbe82f8..a44acedf1f 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -46,9 +46,6 @@ def __init__( self.openai_client = model.get_openai_client() self.model_name = self.openai_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) - agentscope.init( model_configs=[ { @@ -57,8 +54,8 @@ def __init__( "model_name": self.model_name, "api_key": "EMPTY", "generate_args": { - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, "use_openai_formatter": True, } diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py index 079b8a940d..dbc73150bc 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py @@ -45,15 +45,13 @@ def __init__( self.openai_async_client = model.get_openai_async_client() self.model_name = self.openai_async_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) self.agent_model = OpenAIChatModel( api_key="EMPTY", model_name=self.model_name, stream=False, generate_kwargs={ - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, ) self.agent_model.client = self.openai_async_client diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py index 57f45d9af3..2585113303 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -45,15 +45,13 @@ def __init__( self.openai_async_client = model.get_openai_async_client() self.model_name = self.openai_async_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) self.agent_model = OpenAIChatModel( api_key="EMPTY", model_name=self.model_name, stream=False, generate_kwargs={ - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, ) self.agent_model.client = self.openai_async_client diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 05702bd31a..737f9fc279 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -78,8 +78,8 @@ def reset(self, task: Task): openai_client=self.openai_client, system_prompt=self.system_prompt, generate_kwargs={ - "temperature": self.rollout_args.get("temperature", 1.0), - "max_tokens": self.rollout_args.get("max_tokens", 4096), + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, response_structure=AnswerModel, ) From 20577e2d0d67361c7ed90ec32e02e011891daf77 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 20 Oct 2025 20:29:32 +0800 Subject: [PATCH 5/6] fix comments --- trinity/common/workflows/agentscope/react/react_workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/common/workflows/agentscope/react/react_workflow.py b/trinity/common/workflows/agentscope/react/react_workflow.py index 044c7935f1..a6dbca28e3 100644 --- a/trinity/common/workflows/agentscope/react/react_workflow.py +++ b/trinity/common/workflows/agentscope/react/react_workflow.py @@ -57,8 +57,8 @@ def __init__( openai_client=self.model_client, system_prompt=template.system_prompt, generate_kwargs={ - "temperature": self.rollout_args.get("temperature", 1.0), - "max_tokens": self.rollout_args.get("max_tokens", 4096), + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, response_structure=template.response_structure, ) From 8bb5ec89f7a1d7a823449a2a50b76d307596b0ec Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 21 Oct 2025 11:14:11 +0800 Subject: [PATCH 6/6] fix doc --- docs/sphinx_doc/source/tutorial/develop_workflow.md | 5 +++-- docs/sphinx_doc/source_zh/tutorial/develop_workflow.md | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md index f6ded5eaad..0d5c632421 100644 --- a/docs/sphinx_doc/source/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md @@ -356,11 +356,12 @@ trinity run --config #### Async Support -The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected. +The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain changed. ```python @WORKFLOWS.register_module("example_workflow_async") class ExampleWorkflowAsync(Workflow): + is_async: bool = True async def run_async(self) -> List[Experience]: @@ -447,7 +448,7 @@ explorer: Note that each auxiliary model will independently occupy `tensor_parallel_size * engine_num` GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (`rollout_model`). -The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` instances to the `auxiliary_models` parameter of the `Workflow` initialization method. For example: +The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` or `openai.AsyncOpenAI` instances (depending on the `is_async` setting) to the `auxiliary_models` parameter of the `Workflow` initialization method. For example: ```python class MyWorkflow(Workflow): diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md index f178491643..366834ea53 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md @@ -351,7 +351,7 @@ trinity run --config #### async 支持 -本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。 +本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,并且初始化参数 `auxiliary_models` 也会自动变为 `List[openai.AsyncOpenAI]` 类型,其余方法和属性保持不变。 ```python @WORKFLOWS.register_module("example_workflow_async") @@ -445,7 +445,7 @@ explorer: 请注意,每个辅助模型会独立占用 `tensor_parallel_size * engine_num` 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(`rollout_model`)所占用的 GPU 数量。 -配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 实例传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如: +配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 或 `openai.AsyncOpenAI` 实例 (取决于 `is_async`) 传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如: ```python class MyWorkflow(Workflow):