diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 1710ce9f38..e5dab0004b 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -6,16 +6,18 @@
 from typing import Dict, Optional
 from unittest.mock import MagicMock
 
+import ray
 from parameterized import parameterized, parameterized_class
 from torch import Tensor
 
 from tests.common.vllm_test import CHAT_TEMPLATE
 from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
-from trinity.common.experience import EID
+from trinity.common.experience import EID, Experience
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
+    WORKFLOWS,
     MathBoxedWorkflow,
     MathEvalWorkflow,
     MathRMWorkflow,
@@ -489,3 +491,44 @@ def test_multi_turn_workflow(self):
         else:
             answer = workflow.run()
         self.assertEqual(len(answer), 2)
+
+    def tearDown(self):
+        ray.shutdown(_exiting_interpreter=True)
+
+
+class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
+    @unittest.skip("Waiting for agentscope>=0.1.6")
+    async def test_adapter(self):
+        try:
+            from agentscope.model import TrinityChatModel
+        except ImportError:
+            self.skipTest("agentscope >= 0.1.6 is not installed")
+
+        async def as_workflow_func(task, model) -> float:
+            self.assertIsInstance(task, dict)
+            self.assertIsInstance(model, TrinityChatModel)
+            return task["reward"]
+
+        model = MagicMock()
+        openai_client = MagicMock()
+        openai_client.model_path = "Qwen/Qwen3-8B"
+        model.get_openai_async_client.return_value = openai_client
+        model.extract_experience_from_history.return_value = [
+            Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])),
+            Experience(tokens=Tensor([3, 4, 5]), prompt_length=2, logprobs=Tensor([0.3])),
+        ]
+
+        as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter")
+        as_adapter = as_adapter_cls(
+            task=Task(
+                raw_task={"reward": 0.1},
+                workflow_args={"workflow_func": as_workflow_func},
+            ),
+            model=model,
+        )
+        result = await as_adapter.run_async()
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0].reward, 0.1)
+        self.assertEqual(result[0].prompt_length, 1)
+        self.assertEqual(result[1].reward, 0.1)
+        self.assertEqual(result[1].prompt_length, 2)
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 492f8b3a60..1277f041f9 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -3,6 +3,7 @@
 from trinity.common.workflows.agentscope.react.react_workflow import (
     AgentScopeReActWorkflow,
 )
+from trinity.common.workflows.agentscope_workflow import AgentScopeWorkflowAdapter
 from trinity.common.workflows.customized_math_workflows import (
     AsyncMathBoxedWorkflow,
     MathBoxedWorkflow,
@@ -92,4 +93,5 @@
     "AsyncSimpleMMWorkflow",
     "SimpleMMWorkflow",
     "RubricJudgeWorkflow",
+    "AgentScopeWorkflowAdapter",
 ]
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
new file mode 100644
index 0000000000..0cbce6e01a
--- /dev/null
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -0,0 +1,83 @@
+from typing import Awaitable, Callable, Dict, List, Optional
+
+import openai
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("agentscope_workflow_adapter")
+class AgentScopeWorkflowAdapter(Workflow):
+    """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow."""
+
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
+        """Initialize the adapter with the task and model."""
+        try:
+            from agentscope.model import TrinityChatModel
+        except ImportError:
+            raise ImportError(
+                "This workflow requires agentscope >= 0.1.6, please install "
+                "it via `pip install agentscope>=0.1.6`",
+            )
+
+        super().__init__(
+            task=task,
+            model=model,
+            auxiliary_models=auxiliary_models,
+        )
+        self.workflow_func: Callable[
+            [Dict, TrinityChatModel], Awaitable[float]
+        ] = task.workflow_args.get("workflow_func", None)
+
+        if self.workflow_func is None:
+            raise ValueError(
+                "The 'workflow_func' is not provided.",
+            )
+
+        self.chat_model: TrinityChatModel = TrinityChatModel(
+            model.get_openai_async_client(),
+        )
+
+    @property
+    def asynchronous(self) -> bool:
+        """This workflow runs asynchronously."""
+        return True
+
+    @property
+    def repeatable(self) -> bool:
+        """This workflow is not repeatable."""
+        return False
+
+    @property
+    def resetable(self) -> bool:
+        """This workflow cannot be reset."""
+        return False
+
+    def construct_experiences(
+        self,
+        reward: float,
+    ) -> List[Experience]:
+        """Construct experiences from the agent's interaction history.
+
+        Args:
+            reward (float): The reward value to assign to each experience.
+
+        Returns:
+            List: A list of Experience objects.
+        """
+        exps = self.model.extract_experience_from_history()
+        for exp in exps:
+            exp.reward = reward
+        return exps
+
+    async def run_async(self) -> List[Experience]:
+        """Run the workflow asynchronously and return experiences."""
+        reward = await self.workflow_func(self.task.raw_task, self.chat_model)  # type: ignore [arg-type]
+        return self.construct_experiences(reward)