Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
from typing import Dict, Optional
from unittest.mock import MagicMock

import ray
from parameterized import parameterized, parameterized_class
from torch import Tensor

from tests.common.vllm_test import CHAT_TEMPLATE
from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
from trinity.common.experience import EID
from trinity.common.experience import EID, Experience
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
WORKFLOWS,
MathBoxedWorkflow,
MathEvalWorkflow,
MathRMWorkflow,
Expand Down Expand Up @@ -489,3 +491,44 @@ def test_multi_turn_workflow(self):
else:
answer = workflow.run()
self.assertEqual(len(answer), 2)

    def tearDown(self):
        """Shut down Ray after each test so the local cluster is released.

        NOTE(review): ``_exiting_interpreter=True`` makes Ray skip some
        graceful-cleanup steps, as it does on interpreter exit — presumably
        used here to avoid teardown hangs between tests; confirm against
        Ray's shutdown semantics.
        """
        ray.shutdown(_exiting_interpreter=True)


class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
    """Tests for the ``agentscope_workflow_adapter`` registered workflow."""

    @unittest.skip("Waiting for agentscope>=0.1.6")
    async def test_adapter(self):
        try:
            from agentscope.model import TrinityChatModel
        except ImportError:
            self.skipTest("agentscope >= 0.1.6 is not installed")

        async def reward_only_workflow(task, model) -> float:
            # The adapter must pass through the raw task dict and a chat model.
            self.assertIsInstance(task, dict)
            self.assertIsInstance(model, TrinityChatModel)
            return task["reward"]

        # Mock the inference model: a fake OpenAI client plus a canned history.
        mock_client = MagicMock()
        mock_client.model_path = "Qwen/Qwen3-8B"
        mock_model = MagicMock()
        mock_model.get_openai_async_client.return_value = mock_client
        mock_model.extract_experience_from_history.return_value = [
            Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])),
            Experience(tokens=Tensor([3, 4, 5]), prompt_length=2, logprobs=Tensor([0.3])),
        ]

        adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter")
        adapter = adapter_cls(
            task=Task(
                raw_task={"reward": 0.1},
                workflow_args={"workflow_func": reward_only_workflow},
            ),
            model=mock_model,
        )
        experiences = await adapter.run_async()

        # Every extracted experience carries the workflow's reward; prompt
        # lengths come straight from the mocked history.
        self.assertEqual(len(experiences), 2)
        for exp, expected_prompt_length in zip(experiences, (1, 2)):
            self.assertEqual(exp.reward, 0.1)
            self.assertEqual(exp.prompt_length, expected_prompt_length)
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from trinity.common.workflows.agentscope.react.react_workflow import (
AgentScopeReActWorkflow,
)
from trinity.common.workflows.agentscope_workflow import AgentScopeWorkflowAdapter
from trinity.common.workflows.customized_math_workflows import (
AsyncMathBoxedWorkflow,
MathBoxedWorkflow,
Expand Down Expand Up @@ -92,4 +93,5 @@
"AsyncSimpleMMWorkflow",
"SimpleMMWorkflow",
"RubricJudgeWorkflow",
"AgentScopeWorkflowAdapter",
]
83 changes: 83 additions & 0 deletions trinity/common/workflows/agentscope_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Awaitable, Callable, Dict, List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow


@WORKFLOWS.register_module("agentscope_workflow_adapter")
class AgentScopeWorkflowAdapter(Workflow):
    """Adapter to wrap an AgentScope trainable workflow function into a Trinity Workflow.

    The wrapped coroutine receives the raw task dict and a ``TrinityChatModel``
    bound to this workflow's inference model, and returns a scalar reward.
    That reward is then assigned to every experience extracted from the
    model's interaction history.
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        """Initialize the adapter with the task and model.

        Args:
            task: The Trinity task. ``task.workflow_args["workflow_func"]``
                must hold an async callable ``(raw_task, chat_model) -> float``.
            model: The inference model wrapper used for the rollout.
            auxiliary_models: Optional extra OpenAI clients, forwarded to the
                base class (not used directly here).

        Raises:
            ImportError: If ``agentscope`` >= 0.1.6 is not installed.
            ValueError: If ``workflow_func`` is missing from ``workflow_args``.
        """
        # Import lazily so Trinity remains usable without agentscope installed.
        try:
            from agentscope.model import TrinityChatModel
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible.
            raise ImportError(
                "This workflow requires agentscope >= 0.1.6, please install "
                "it via `pip install agentscope>=0.1.6`",
            ) from err

        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.workflow_func: Callable[
            [Dict, TrinityChatModel], Awaitable[float]
        ] = task.workflow_args.get("workflow_func")

        if self.workflow_func is None:
            raise ValueError(
                "The 'workflow_func' is not provided.",
            )

        # The chat model exposes the async OpenAI API backed by our rollout model.
        self.chat_model: TrinityChatModel = TrinityChatModel(
            model.get_openai_async_client(),
        )

    @property
    def asynchronous(self) -> bool:
        """This workflow runs asynchronously."""
        return True

    @property
    def repeatable(self) -> bool:
        """This workflow is not repeatable."""
        return False

    @property
    def resetable(self) -> bool:
        """This workflow cannot be reset."""
        return False

    def construct_experiences(
        self,
        reward: float,
    ) -> List[Experience]:
        """Construct experiences from the agent's interaction history.

        Args:
            reward (float): The reward value to assign to each experience.

        Returns:
            List[Experience]: The extracted experiences with ``reward`` set.
        """
        exps = self.model.extract_experience_from_history()
        for exp in exps:
            exp.reward = reward
        return exps

    async def run_async(self) -> List[Experience]:
        """Run the wrapped workflow function and return rewarded experiences."""
        reward = await self.workflow_func(self.task.raw_task, self.chat_model)  # type: ignore [arg-type]
        return self.construct_experiences(reward)