diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md index b40c87801c..0d5c632421 100644 --- a/docs/sphinx_doc/source/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md @@ -205,20 +205,20 @@ __all__ = [ ##### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. -In this case, you can implement the `resettable` property and `reset` method to avoid re-initialization. +In this case, you can set the `can_reset` property and implement `reset` method to avoid re-initialization. + +The `can_reset` is a class property that indicates whether the workflow supports resetting. -The `resettable` property returns a boolean indicating whether the workflow supports resetting. The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task. ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + # some code # ... - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") @@ -227,20 +227,18 @@ class ExampleWorkflow(Workflow): ##### Support Batch Inference In many popular RL algorithms, multiple runs of the same task are required (e.g., GRPO). In such scenarios, you can directly use batch inference to obtain multiple responses for a single question to improve efficiency. -For this case, you can implement the `repeatable` property and `set_repeat_times` method. +For this case, you can implement the `can_repeat` property and `set_repeat_times` method. + +The `can_repeat` is a class property that indicates whether the workflow supports multiple executions within the `run` method. -The `repeatable` property returns a boolean indicating whether the workflow supports multiple executions within the `run` method. The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored). ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_repeat: bool = True # some code - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -279,6 +277,8 @@ class ExampleWorkflow(Workflow): ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) @@ -319,18 +319,10 @@ class ExampleWorkflow(Workflow): ) return experiences - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -364,15 +356,13 @@ trinity run --config #### Async Support -The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can implement the `asynchronous` property and set it to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected. +The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain changed. ```python @WORKFLOWS.register_module("example_workflow_async") class ExampleWorkflowAsync(Workflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # your async code here @@ -458,7 +448,7 @@ explorer: Note that each auxiliary model will independently occupy `tensor_parallel_size * engine_num` GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (`rollout_model`). -The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` instances to the `auxiliary_models` parameter of the `Workflow` initialization method. For example: +The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` or `openai.AsyncOpenAI` instances (depending on the `is_async` setting) to the `auxiliary_models` parameter of the `Workflow` initialization method. For example: ```python class MyWorkflow(Workflow): diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md index 3cb85787e4..366834ea53 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md @@ -200,22 +200,20 @@ __all__ = [ ##### 避免重复初始化 对于较为复杂的工作流,每次重新初始化会带来额外计算开销。 -此时,你可以实现 `resettable` 和 `reset` 方法以避免重复初始化。 +此时,你可以设置 `can_reset` 属性并实现 `reset` 方法以避免重复初始化。 -`resettable` 方法返回一个布尔值,指示工作流是否支持轻量化重置。 +`can_reset` 是一个类属性,表示工作流是否支持轻量化重置。 `reset` 方法接受一个新的 `Task` 实例,并使用该实例更新工作流的状态。 ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + # some code # ... - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") @@ -224,21 +222,18 @@ class ExampleWorkflow(Workflow): ##### 批量运行推理任务 当前流行的很多 RL 算法需要多次运行同一个任务(例如 GRPO)。该场景下一些简单任务可以直接通过模型批量推理来获得一个问题的多个回复以提升效率。 -针对该情况,你可以实现 `repeatable` 属性以及 `set_repeat_times` 方法。 +针对该情况,你可以设置 `can_repeat` 属性并实现 `set_repeat_times` 方法。 -`repeatable` 属性返回一个布尔值,指示工作流是否支持在 `run` 方法内多次执行。 +`can_repeat` 是一个类属性,指示工作流是否支持在 `run` 方法内多次执行。 `set_repeat_times` 方法接受两个参数:`repeat_times` 指定了在 `run` 方法内需要执行的次数,`run_id_base` 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。 ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_repeat: bool = True # some code - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -277,6 +272,8 @@ class ExampleWorkflow(Workflow): ```python @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) @@ -317,18 +314,10 @@ class ExampleWorkflow(Workflow): ) return experiences - @property - def resettable(self): - return True - def reset(self, task: Task): self.question = task.raw_task.get("question") self.answer = task.raw_task.get("answer") - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -362,15 +351,13 @@ trinity run --config #### async 支持 -本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以实现 `asynchronous` 属性并将其设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。 +本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,并且初始化参数 `auxiliary_models` 也会自动变为 `List[openai.AsyncOpenAI]` 类型,其余方法和属性保持不变。 ```python @WORKFLOWS.register_module("example_workflow_async") class ExampleWorkflowAsync(Workflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # your async code here @@ -458,7 +445,7 @@ explorer: 请注意,每个辅助模型会独立占用 `tensor_parallel_size * engine_num` 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(`rollout_model`)所占用的 GPU 数量。 -配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 实例传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如: +配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 或 `openai.AsyncOpenAI` 实例 (取决于 `is_async`) 传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如: ```python class MyWorkflow(Workflow): diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 36c73063a9..446b191e6b 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -18,6 +18,8 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): + can_repeat: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.step_num = task.workflow_args.get("step_num", 1) @@ -30,10 +32,6 @@ def __init__(self, *, task, model, auxiliary_models): else: self.seconds = 10 - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -63,19 +61,13 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("dummy_nonrepeat_workflow") class DummyNonRepeatWorkflow(Workflow): + can_reset: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.reset_flag = False self.step_num = task.workflow_args.get("step_num", 1) - @property - def resettable(self): - return True - - @property - def repeatable(self): - return False - def reset(self, task: Task): self.task = task self.reset_flag = True @@ -95,18 +87,13 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("dummy_async_workflow") class DummyAsyncWorkflow(Workflow): + can_repeat: bool = True + is_async: bool = True + def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.step_num = task.workflow_args.get("step_num", 1) - @property - def asynchronous(self): - return True - - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index e5dab0004b..81e1afa916 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -4,8 +4,10 @@ import unittest from dataclasses import dataclass, field from typing import Dict, Optional +from unittest import mock from unittest.mock import MagicMock +import openai import ray from parameterized import parameterized, parameterized_class from torch import Tensor @@ -25,6 +27,7 @@ Workflow, ) from trinity.common.workflows.workflow import MultiTurnWorkflow, Task +from trinity.explorer.workflow_runner import WorkflowRunner @dataclass @@ -40,19 +43,17 @@ class MockResponse: class DummyWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - - @property - def resettable(self): - return True - - @property - def repeatable(self): - return True + if auxiliary_models is not None: + for model in auxiliary_models: + assert isinstance(model, openai.OpenAI) def reset(self, task: Task): self.obj = task.raw_task @@ -63,36 +64,40 @@ def set_repeat_times(self, repeat_times, run_id_base): self.run_id_base = run_id_base def run(self): + exps = [] if self.output_format == "json": import json - return [json.dumps(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = json.dumps(self.obj) + exps.append(exp) + return exps elif self.output_format == "yaml": import yaml - return [yaml.safe_dump(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = yaml.safe_dump(self.obj) + exps.append(exp) + return exps else: raise ValueError("Invalid output format") class DummyAsyncWorkflow(Workflow): + can_reset: bool = True + can_repeat: bool = True + is_async: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - - @property - def resettable(self): - return True - - @property - def repeatable(self): - return True - - @property - def asynchronous(self): - return True + if auxiliary_models is not None: + for model in auxiliary_models: + assert isinstance(model, openai.AsyncOpenAI) def reset(self, task: Task): self.obj = task.raw_task @@ -104,14 +109,23 @@ def set_repeat_times(self, repeat_times, run_id_base): async def run_async(self): await asyncio.sleep(0.1) + exps = [] if self.output_format == "json": import json - return [json.dumps(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = json.dumps(self.obj) + exps.append(exp) + return exps elif self.output_format == "yaml": import yaml - return [yaml.safe_dump(self.obj)] * self.repeat_times + for i in range(self.repeat_times): + exp = Experience(tokens=Tensor([0, 1, 2, 3]), prompt_length=1) + exp.response_text = yaml.safe_dump(self.obj) + exps.append(exp) + return exps else: raise ValueError("Invalid output format") @@ -133,14 +147,12 @@ def run(self): class DummyAsyncMultiTurnWorkflow(MultiTurnWorkflow): + is_async: bool = True + def __init__(self, model, task: Task, auxiliary_models=None): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.contents = task.raw_task["contents"] # type: ignore - @property - def asynchronous(self): - return True - async def run_async(self): memory = [{"role": "system", "content": "You are a helpful assistant."}] experience_list = [] @@ -428,13 +440,13 @@ def test_workflow_resettable(self, workflow_cls) -> None: answer = asyncio.run(workflow.run_async()) else: answer = workflow.run() - self.assertEqual(answer[0], '{"a": 1}') + self.assertEqual(answer[0].response_text, '{"a": 1}') workflow.reset(yaml_task) if workflow.asynchronous: answer = asyncio.run(workflow.run_async()) else: answer = workflow.run() - self.assertEqual(answer[0], "a: 1\n") + self.assertEqual(answer[0].response_text, "a: 1\n") @parameterized.expand([(DummyWorkflow,), (DummyAsyncWorkflow,)]) def test_workflow_repeatable(self, workflow_cls) -> None: @@ -532,3 +544,63 @@ async def as_workflow_func(task, model) -> float: self.assertEqual(result[0].prompt_length, 1) self.assertEqual(result[1].reward, 0.1) self.assertEqual(result[1].prompt_length, 2) + + +class DummyModelWrapper: + def __init__(self, model, engine_type="vllm", **kwargs): + pass + + async def prepare(self): + return None + + def get_openai_client(self): + return openai.OpenAI(api_key="EMPTY") + + def get_openai_async_client(self): + return openai.AsyncOpenAI(api_key="EMPTY") + + @property + async def model_version_async(self): + return 0 + + +class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): + async def test_workflow_runner(self): + config = get_template_config() + + with mock.patch( + "trinity.explorer.workflow_runner.ModelWrapper", + DummyModelWrapper, + ): + runner = WorkflowRunner( + config, + model=MagicMock(), + auxiliary_models=[MagicMock(), MagicMock()], + runner_id=0, + ) + await runner.prepare() + + task = Task( + workflow=DummyWorkflow, + repeat_times=3, + raw_task={"a": 1}, + workflow_args={"output_format": "json"}, + ) + + status, exps = await runner.run_task(task, repeat_times=3, run_id_base=0) + + self.assertTrue(status.ok) + self.assertIsInstance(exps, list) + self.assertEqual(len(exps), 3) + + task = Task( + workflow=DummyAsyncWorkflow, + repeat_times=2, + raw_task={"a": 1}, + workflow_args={"output_format": "yaml"}, + ) + + status, exps = await runner.run_task(task, repeat_times=2, run_id_base=0) + self.assertTrue(status.ok) + self.assertIsInstance(exps, list) + self.assertEqual(len(exps), 2) diff --git a/trinity/common/workflows/agentscope/react/react_workflow.py b/trinity/common/workflows/agentscope/react/react_workflow.py index 247ba085d9..a6dbca28e3 100644 --- a/trinity/common/workflows/agentscope/react/react_workflow.py +++ b/trinity/common/workflows/agentscope/react/react_workflow.py @@ -16,6 +16,8 @@ @WORKFLOWS.register_module("as_react_workflow") class AgentScopeReActWorkflow(Workflow): + is_async: bool = True + def __init__( self, *, @@ -55,8 +57,8 @@ def __init__( openai_client=self.model_client, system_prompt=template.system_prompt, generate_kwargs={ - "temperature": self.rollout_args.get("temperature", 1.0), - "max_tokens": self.rollout_args.get("max_tokens", 4096), + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, response_structure=template.response_structure, ) @@ -95,13 +97,3 @@ def construct_experiences(self, reward: Union[float, Dict[str, float]]) -> List[ if isinstance(reward, dict): exp.metrics.update(reward) return exps - - @property - def asynchronous(self): - """AgentScope's ReAct agent only supports asynchronous calls, so we set this to True.""" - return True - - @property - def repeatable(self): - """This workflow is not repeatable.""" - return False diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 0cbce6e01a..359afa5b65 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -11,6 +11,8 @@ class AgentScopeWorkflowAdapter(Workflow): """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow.""" + is_async: bool = True + def __init__( self, *, @@ -45,21 +47,6 @@ def __init__( model.get_openai_async_client(), ) - @property - def asynchronous(self) -> bool: - """This workflow runs asynchronously.""" - return True - - @property - def repeatable(self) -> bool: - """This workflow is not repeatable.""" - return False - - @property - def resetable(self) -> bool: - """This workflow cannot be reset.""" - return False - def construct_experiences( self, reward: float, diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index f76983a324..65bed1ac0d 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -95,9 +95,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_boxed_workflow") class AsyncMathBoxedWorkflow(MathBoxedWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: if not self.use_base: diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index 1e31e492dd..f723c0e347 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -215,6 +215,8 @@ class ToolCallWorkflow(SimpleWorkflow): Only support qwen model for now. You can change the prompt construction and reward calculation by yourself for other models. """ + is_async: bool = True + def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt @@ -227,10 +229,6 @@ def reset(self, task: Task): self.workflow_args = task.workflow_args self.reward_fn_args = task.reward_fn_args - @property - def asynchronous(self): - return True - def format_prompt(self): raw_task = self.raw_task messages = construct_prompt(raw_task) diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index e2c1fe8f7b..a44acedf1f 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -19,6 +19,8 @@ class AgentScopeV0ReactMathWorkflow(Workflow): We use the AgentScope V0 version here. The code will be deprecated soon. """ + can_reset: bool = True + def __init__( self, *, @@ -44,9 +46,6 @@ def __init__( self.openai_client = model.get_openai_client() self.model_name = self.openai_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) - agentscope.init( model_configs=[ { @@ -55,8 +54,8 @@ def __init__( "model_name": self.model_name, "api_key": "EMPTY", "generate_args": { - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, "use_openai_formatter": True, } @@ -72,10 +71,6 @@ def __init__( ) self.reset(task) - @property - def resettable(self): - return True - def reset(self, task: Task): self.system_prompt = """ You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}. @@ -115,10 +110,6 @@ def reset(self, task: Task): # we use the boxed format to evaluate the answer self.reward_fn = MathBoxedRewardFn() - @property - def repeatable(self): - return False - def run(self): # make sure that we have the correct import try: diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py index a20c88063c..dbc73150bc 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py @@ -17,6 +17,9 @@ class AgentScopeReactMathWorkflow(Workflow): We use the AgentScope V1 version here. """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -42,25 +45,19 @@ def __init__( self.openai_async_client = model.get_openai_async_client() self.model_name = self.openai_async_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) self.agent_model = OpenAIChatModel( api_key="EMPTY", model_name=self.model_name, stream=False, generate_kwargs={ - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, ) self.agent_model.client = self.openai_async_client self.agent_model_formatter = OpenAIChatFormatter() self.reset(task) - @property - def resettable(self): - return True - def reset(self, task: Task): self.system_prompt = """ You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}. @@ -104,15 +101,6 @@ def reset(self, task: Task): # we use the boxed format to evaluate the answer self.reward_fn = MathBoxedRewardFn() - @property - def repeatable(self): - return False - - @property - def asynchronous(self): - """Whether the workflow runs in async mode.""" - return True - async def run_async(self): # make sure that we have the correct import try: diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py index b32eeefe9a..2585113303 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -17,6 +17,9 @@ class AgentScopeV1ReactSearchWorkflow(Workflow): This workflow serves as an example of how to use the agentscope framework within the trinity workflow. """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -42,15 +45,13 @@ def __init__( self.openai_async_client = model.get_openai_async_client() self.model_name = self.openai_async_client.model_path - temperature = self.rollout_args.get("temperature", 1.0) - max_tokens = self.rollout_args.get("max_tokens", 4096) self.agent_model = OpenAIChatModel( api_key="EMPTY", model_name=self.model_name, stream=False, generate_kwargs={ - "temperature": temperature, - "max_tokens": max_tokens, + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, ) self.agent_model.client = self.openai_async_client @@ -58,19 +59,6 @@ def __init__( self.reset(task) - @property - def resettable(self): - return True - - @property - def asynchronous(self): - """Whether the workflow runs in async mode.""" - return True - - @property - def repeatable(self): - return False - def reset(self, task: Task): self.workflow_args = task.workflow_args self.max_turns = int(self.workflow_args.get("max_turns", 10)) diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index 0b2e6fc93e..e9e3f30206 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -26,6 +26,10 @@ class RAFTAlfworldWorkflow(Workflow): 2. Generate SFT data from successful attempt """ + can_reset: bool = True + can_repeat: bool = True + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -55,10 +59,6 @@ def __init__( ) self.reset(task) - @property - def asynchronous(self): - return True - def reset(self, task: Task): """Reset the workflow with a new task""" self.game_file_path = task.task_desc or task.raw_task.get("game_file", "") @@ -220,14 +220,6 @@ async def run_async(self) -> List[Experience]: ) return [experience] - def resettable(self) -> bool: - """Indicate that this workflow can be reset to avoid re-initialization""" - return True - - @property - def repeatable(self) -> bool: - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index bf8ff85b09..81aed995d9 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -59,10 +59,6 @@ def __init__( f"Initializing RAFTReflectAlfworldWorkflow with RAFT learning, temperature={self.temperature}" ) - @property - def asynchronous(self): - return True - async def construct_sft_data( self, first_trajectory: List[Dict[str, str]], diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 36370b55f1..f722e82ce6 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -97,6 +97,8 @@ def parse_action(response): class AlfworldWorkflow(MultiTurnWorkflow): """A workflow for alfworld task.""" + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -111,10 +113,6 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = task.workflow_args.get("max_env_steps", 30) - @property - def asynchronous(self): - return True - async def get_model_response(self, messages): responses = await self.model.chat_async(messages, n=1) return responses diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 2fe97f13f6..737f9fc279 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -27,6 +27,9 @@ class EmailSearchWorkflow(Workflow): Multi-turn Email Search workflow (ReAct-style tool use). """ + can_reset: bool = True + is_async: bool = True + def __init__( self, *, @@ -45,18 +48,6 @@ def __init__( self.reset(task) - @property - def repeatable(self) -> bool: - return False - - @property - def resettable(self): - return True - - @property - def asynchronous(self): - return True - def reset(self, task: Task): self.query = QueryModel.model_validate(task.raw_task) self.task_desc = task.task_desc # question @@ -87,8 +78,8 @@ def reset(self, task: Task): openai_client=self.openai_client, system_prompt=self.system_prompt, generate_kwargs={ - "temperature": self.rollout_args.get("temperature", 1.0), - "max_tokens": self.rollout_args.get("max_tokens", 4096), + "temperature": self.task.rollout_args.temperature, + "max_tokens": self.task.rollout_args.max_tokens or 4096, }, response_structure=AnswerModel, ) diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index da313e72ca..4c2a7417a9 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -59,6 +59,8 @@ def parse_action(response): class SciWorldWorkflow(MultiTurnWorkflow): """A workflow for sciworld task.""" + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -73,10 +75,6 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = 30 # should be less than 100 - @property - def asynchronous(self): - return True - async def get_model_response(self, messages): responses = await self.model.chat_async(messages, n=1) return responses diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 5b18ee48ae..4cef8fc456 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -181,6 +181,9 @@ def validate_action(action, available_actions): class WebShopWorkflow(MultiTurnWorkflow): """A workflow for webshop task.""" + can_reset: bool = True + is_async: bool = True + def __init__( self, model: ModelWrapper, @@ -209,10 +212,6 @@ def __init__( "WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True ) - @property - def resettable(self): - return True - def reset(self, task: Task): self.task_desc = task.task_desc or "0" self.repeat_times = task.repeat_times diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 03348d6bc2..55ea78e74e 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -42,14 +42,6 @@ def __init__( # TODO: customize the config in the yaml self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1)) - @property - def resettable(self): - return False - - @property - def repeatable(self): - return False - def format_messages(self): """Format message for the evaluation of qwen_boxed type.""" if not self.raw_task or "question" not in self.raw_task: @@ -89,9 +81,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_eval_workflow") class AsyncMathEvalWorkflow(MathEvalWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: messages = self.format_messages() diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index f498acf2ed..aca5586717 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -55,9 +55,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_math_rm_workflow") class AsyncMathRMWorkflow(MathRMWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: messages = self.format_messages() diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py index 9eb1462652..fec1bc72a4 100644 --- a/trinity/common/workflows/math_ruler_workflow.py +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -160,9 +160,7 @@ def get_ruler_scores( @WORKFLOWS.register_module("async_math_ruler_workflow") class AsyncMathRULERWorkflow(MathRULERWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: """Modified from SimpleWorkflow.run""" diff --git a/trinity/common/workflows/simple_mm_workflow.py b/trinity/common/workflows/simple_mm_workflow.py index c24fffe19c..2bb3f857e2 100644 --- a/trinity/common/workflows/simple_mm_workflow.py +++ b/trinity/common/workflows/simple_mm_workflow.py @@ -82,9 +82,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_simple_mm_workflow") class AsyncSimpleMMWorkflow(SimpleMMWorkflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # TODO: test generate_mm diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py index 608eea3bf2..1042866063 100644 --- a/trinity/common/workflows/step_wise_workflow.py +++ b/trinity/common/workflows/step_wise_workflow.py @@ -67,17 +67,11 @@ def max_step_num(self): """Return the maximum number of steps in the task.""" raise NotImplementedError - @property - def repeatable(self): - return False - class AsyncStepWiseRewardWorkflow(StepWiseRewardWorkflow): """Async version of `StepWiseRewardWorkflow`.""" - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" @@ -184,17 +178,11 @@ def max_step_num(self): """Return the maximum number of steps in the task.""" raise NotImplementedError - @property - def repeatable(self): - return False - class AsyncRewardPropagationWorkflow(RewardPropagationWorkflow): """Async version of `RewardPropagationWorkflow`.""" - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index c5694f595a..26a555d7cc 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -80,6 +80,10 @@ class Workflow: A workflow is a runnable object which generates a list of experiences. """ + can_reset: bool = False # whether the workflow can be reset with a new task. If true, `reset()` must be implemented. + can_repeat: bool = False # whether the workflow can be repeated multiple times. If true, `set_repeat_times()` must be implemented. + is_async: bool = False # whether the workflow runs in async mode. If true, `run_async()` must be implemented, else `run()` must be implemented. + def __init__( self, *, @@ -95,21 +99,21 @@ def __init__( @property def resettable(self): - return False + """Deprecated, use cls.can_reset instead.""" + return self.__class__.can_reset @property def repeatable(self): - """A workflow is repeatable if it can be run multiple times within the run() or run_async() method.""" - return False + """Deprecated, use cls.can_repeat instead. + A workflow is repeatable if it can be run multiple times within the run() or run_async() method. + """ + return self.__class__.can_repeat @property def asynchronous(self): - """Whether the workflow runs in async mode.""" - return False - - @property - def rollout_args(self): - return asdict(self.task.rollout_args) + """Deprecated, use cls.is_async instead. + Whether the workflow runs in async mode.""" + return self.__class__.is_async def reset(self, task: Task): """Reset the workflow.""" @@ -140,6 +144,8 @@ class MultiTurnWorkflow(Workflow): The base workflow class for concatenated multi-turn tasks. """ + can_repeat: bool = True + def __init__( self, *, @@ -153,10 +159,6 @@ def __init__( auxiliary_models=auxiliary_models, ) - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base @@ -190,6 +192,9 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc class SimpleWorkflow(Workflow): """A workflow for simple single-round task.""" + can_reset: bool = True + can_repeat: bool = True + def __init__( self, *, @@ -204,10 +209,6 @@ def __init__( auxiliary_models=auxiliary_models, ) - @property - def resettable(self): - return True - def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt @@ -224,15 +225,15 @@ def reset(self, task: Task): else: raise ValueError("`reward_fn` must be a subclass of `RewardFn`") - @property - def repeatable(self): - return True - def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.task.rollout_args.n = repeat_times self.run_id_base = run_id_base + @property + def rollout_args(self): + return asdict(self.task.rollout_args) + def format_messages(self): """Format messages for the instruct model.""" messages = [] @@ -270,9 +271,7 @@ def run(self) -> List[Experience]: @WORKFLOWS.register_module("async_simple_workflow") class AsyncSimpleWorkflow(Workflow): - @property - def asynchronous(self): - return True + is_async: bool = True async def run_async(self) -> List[Experience]: # TODO: Optimize the generate function diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 7a473137f0..187d0d5adf 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -51,6 +51,7 @@ def __init__( for model in (auxiliary_models or []) ] self.auxiliary_model_clients = [] + self.auxiliary_model_async_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id @@ -62,7 +63,9 @@ async def prepare(self) -> None: ) for model in self.auxiliary_models: api_client = model.get_openai_client() + async_api_client = model.get_openai_async_client() self.auxiliary_model_clients.append(api_client) + self.auxiliary_model_async_clients.append(async_api_client) def is_alive(self): return True @@ -76,7 +79,12 @@ def _create_workflow_instance(self, task: Task) -> None: or not self.workflow_instance.resettable ): self.workflow_instance = task.to_workflow( - self.model_wrapper, self.auxiliary_model_clients + self.model_wrapper, + ( + self.auxiliary_model_async_clients + if task.workflow.is_async + else self.auxiliary_model_clients + ), ) else: self.workflow_instance.reset(task)