Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 88 additions & 38 deletions tests/explorer/step_wise_workflow_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# -*- coding: utf-8 -*-
"""Test for the general step-wise workflow module"""
import asyncio
import unittest
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest.mock import MagicMock

from parameterized import parameterized
from torch import Tensor

from tests.tools import get_unittest_dataset_config
from trinity.common.experience import EID, Experience
from trinity.common.workflows.step_wise_workflow import (
AsyncRewardPropagationWorkflow,
AsyncStepWiseRewardWorkflow,
RewardPropagationWorkflow,
StepWiseRewardWorkflow,
)
Expand Down Expand Up @@ -46,6 +50,26 @@ def max_step_num(self):
return self.max_env_steps


class DummyAsyncStepWiseRewardWorkflow(AsyncStepWiseRewardWorkflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.repeat_times = task.repeat_times
self.max_env_steps = task.workflow_args.get("max_env_steps", 1)
self.actual_steps = task.workflow_args.get("actual_steps", 1)

async def step_async(self, step_num: int):
await asyncio.sleep(0.1)
return step_num < self.actual_steps - 1

async def reward_async(self, exps: list[Experience], step_num: int):
await asyncio.sleep(0.1)
return 0.1 * step_num

@property
def max_step_num(self):
return self.max_env_steps


class DummyRewardPropagationWorkflow(RewardPropagationWorkflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
Expand All @@ -64,6 +88,34 @@ def max_step_num(self):
return self.max_env_steps


class DummyAsyncRewardPropagationWorkflow(AsyncRewardPropagationWorkflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.repeat_times = task.repeat_times
self.max_env_steps = task.workflow_args.get("max_env_steps", 1)
self.actual_steps = task.workflow_args.get("actual_steps", 1)

async def step_async(self, step_num: int):
await asyncio.sleep(0.1)
return step_num < self.actual_steps - 1

async def reward_async(self, exps: list[Experience]):
await asyncio.sleep(0.1)
return 0.1 * len(exps)

@property
def max_step_num(self):
return self.max_env_steps


_dummy_workflows = [
DummyStepWiseRewardWorkflow,
DummyAsyncStepWiseRewardWorkflow,
DummyRewardPropagationWorkflow,
DummyAsyncRewardPropagationWorkflow,
]


class WorkflowTest(unittest.TestCase):
def setUp(self) -> None:
self.model = MagicMock()
Expand All @@ -73,14 +125,18 @@ def setUp(self) -> None:
]
self.taskset_config = get_unittest_dataset_config("countdown")

def test_step_wise_reward_workflow(self) -> None:
@parameterized.expand([(DummyStepWiseRewardWorkflow,), (DummyAsyncStepWiseRewardWorkflow,)])
def test_step_wise_reward_workflow(self, workflow_cls) -> None:
task = Task(
workflow=DummyStepWiseRewardWorkflow,
workflow=workflow_cls,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 10, "actual_steps": 5},
)
workflow = task.to_workflow(model=self.model)
experiences = workflow.run()
if workflow.asynchronous:
experiences = asyncio.run(workflow.run_async())
else:
experiences = workflow.run()

self.assertEqual(len(experiences), 5)
actual_steps = [exp.eid.step for exp in experiences]
Expand All @@ -90,14 +146,20 @@ def test_step_wise_reward_workflow(self) -> None:
for actual, expected in zip(actual_rewards, expected_rewards):
self.assertAlmostEqual(actual, expected) # type: ignore

def test_reward_propagation_workflow(self) -> None:
@parameterized.expand(
[(DummyRewardPropagationWorkflow,), (DummyAsyncRewardPropagationWorkflow,)]
)
def test_reward_propagation_workflow(self, workflow_cls) -> None:
task = Task(
workflow=DummyRewardPropagationWorkflow,
workflow=workflow_cls,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 10, "actual_steps": 5},
)
workflow = task.to_workflow(model=self.model)
experiences = workflow.run()
if workflow.asynchronous:
experiences = asyncio.run(workflow.run_async())
else:
experiences = workflow.run()

self.assertEqual(len(experiences), 5)
actual_steps = [exp.eid.step for exp in experiences]
Expand All @@ -107,38 +169,26 @@ def test_reward_propagation_workflow(self) -> None:
self.assertAlmostEqual(exp.reward, expected_reward) # type: ignore

def test_workflows_stop_at_max_env_steps(self) -> None:
task = Task(
workflow=DummyStepWiseRewardWorkflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 3, "actual_steps": 100}, # actual > max
)
workflow = task.to_workflow(model=self.model)
experiences = workflow.run()
self.assertEqual(len(experiences), 3)

task = Task(
workflow=DummyRewardPropagationWorkflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 3, "actual_steps": 100}, # actual > max
)
workflow = task.to_workflow(model=self.model)
experiences = workflow.run()
self.assertEqual(len(experiences), 3)
for workflow in _dummy_workflows:
task = Task(
workflow=workflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 3, "actual_steps": 100}, # actual > max
)
workflow = task.to_workflow(model=self.model)
if workflow.asynchronous:
experiences = asyncio.run(workflow.run_async()) # type: ignore
else:
experiences = workflow.run()
self.assertEqual(len(experiences), 3)

def test_workflows_raise_error(self) -> None:
self.model.enable_history = False
task = Task(
workflow=DummyStepWiseRewardWorkflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 10, "actual_steps": 5},
)
with self.assertRaises(AssertionError):
task.to_workflow(model=self.model)

task = Task(
workflow=DummyRewardPropagationWorkflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 10, "actual_steps": 5},
)
with self.assertRaises(AssertionError):
task.to_workflow(model=self.model)
for workflow in _dummy_workflows:
task = Task(
workflow=workflow,
repeat_times=self.taskset_config.repeat_times,
workflow_args={"max_env_steps": 10, "actual_steps": 5},
)
with self.assertRaises(AssertionError):
task.to_workflow(model=self.model)
153 changes: 143 additions & 10 deletions tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# -*- coding: utf-8 -*-
"""Test for the workflow module"""
import asyncio
import unittest
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest.mock import MagicMock

from parameterized import parameterized, parameterized_class
from torch import Tensor

from tests.tools import get_unittest_dataset_config
from tests.common.vllm_test import CHAT_TEMPLATE
from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
from trinity.common.experience import EID
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
MathBoxedWorkflow,
Expand All @@ -17,7 +22,7 @@
MathWorkflow,
Workflow,
)
from trinity.common.workflows.workflow import Task
from trinity.common.workflows.workflow import MultiTurnWorkflow, Task


@dataclass
Expand Down Expand Up @@ -68,6 +73,84 @@ def run(self):
raise ValueError("Invalid output format")


class DummyAsyncWorkflow(Workflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]
self.repeat_times = task.rollout_args.n

@property
def resettable(self):
return True

@property
def repeatable(self):
return True

@property
def asynchronous(self):
return True

def reset(self, task: Task):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base

async def run_async(self):
await asyncio.sleep(0.1)
if self.output_format == "json":
import json

return [json.dumps(self.obj)] * self.repeat_times
elif self.output_format == "yaml":
import yaml

return [yaml.safe_dump(self.obj)] * self.repeat_times
else:
raise ValueError("Invalid output format")


class DummyMultiTurnWorkflow(MultiTurnWorkflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.contents = task.raw_task["contents"] # type: ignore

def run(self):
memory = [{"role": "system", "content": "You are a helpful assistant."}]
experience_list = []
for content in self.contents:
memory.append({"role": "user", "content": content})
memory.append({"role": "assistant", "content": content.upper()})
experience = self.process_messages_to_experience(memory, 0, {})
experience_list.append(experience)
return experience_list


class DummyAsyncMultiTurnWorkflow(MultiTurnWorkflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.contents = task.raw_task["contents"] # type: ignore

@property
def asynchronous(self):
return True

async def run_async(self):
memory = [{"role": "system", "content": "You are a helpful assistant."}]
experience_list = []
for content in self.contents:
await asyncio.sleep(0.1)
memory.append({"role": "user", "content": content})
memory.append({"role": "assistant", "content": content.upper()})
experience = self.process_messages_to_experience(memory, 0, {})
experience_list.append(experience)
return experience_list


class WorkflowTest(unittest.TestCase):
def test_math_workflow(self) -> None:
model = MagicMock()
Expand Down Expand Up @@ -323,37 +406,87 @@ def test_math_eval_workflow(self) -> None:
assert exp.metrics is not None, f"Metrics for response {i} should not be None"
self.assertEqual(exp.metrics["accuracy"], expected_acc)

def test_workflow_resettable(self) -> None:
@parameterized.expand([(DummyWorkflow,), (DummyAsyncWorkflow,)])
def test_workflow_resettable(self, workflow_cls) -> None:
model = MagicMock()
json_task = Task(
workflow=DummyWorkflow,
workflow=workflow_cls,
repeat_times=1,
raw_task={"a": 1},
workflow_args={"output_format": "json"},
)
yaml_task = Task(
workflow=DummyWorkflow,
workflow=workflow_cls,
repeat_times=1,
raw_task={"a": 1},
workflow_args={"output_format": "yaml"},
)
workflow = json_task.to_workflow(model)
answer = workflow.run()
if workflow.asynchronous:
answer = asyncio.run(workflow.run_async())
else:
answer = workflow.run()
self.assertEqual(answer[0], '{"a": 1}')
workflow.reset(yaml_task)
answer = workflow.run()
if workflow.asynchronous:
answer = asyncio.run(workflow.run_async())
else:
answer = workflow.run()
self.assertEqual(answer[0], "a: 1\n")

def test_workflow_repeatable(self) -> None:
@parameterized.expand([(DummyWorkflow,), (DummyAsyncWorkflow,)])
def test_workflow_repeatable(self, workflow_cls) -> None:
model = MagicMock()
task = Task(
workflow=DummyWorkflow,
workflow=workflow_cls,
repeat_times=3,
raw_task={"a": 1},
workflow_args={"output_format": "json"},
)
workflow = task.to_workflow(model)
workflow.set_repeat_times(2, run_id_base=0)
self.assertEqual(workflow.repeat_times, 2)
answer = workflow.run()
if workflow.asynchronous:
answer = asyncio.run(workflow.run_async())
else:
answer = workflow.run()
self.assertEqual(len(answer), 2)


@parameterized_class(
("workflow_cls",),
[
(DummyMultiTurnWorkflow,),
(DummyAsyncMultiTurnWorkflow,),
],
)
class MultiTurnWorkflowTest(unittest.TestCase):
def setUp(self):
# configure the model
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.max_model_len = None # self.max_model_len
self.config.explorer.rollout_model.engine_num = 1 # self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = 1 # self.tensor_parallel_size
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = 2 # self.repeat_times
self.config.explorer.rollout_model.enable_history = True # self.enable_history
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

def test_multi_turn_workflow(self):
task = Task(
workflow=self.workflow_cls,
repeat_times=3,
raw_task={"contents": ["hello world!", "how are you?"]},
workflow_args={"output_format": "json"},
)
workflow = task.to_workflow(self.model_wrapper)
workflow.set_repeat_times(2, run_id_base=0)
if workflow.asynchronous:
answer = asyncio.run(workflow.run_async())
else:
answer = workflow.run()
self.assertEqual(len(answer), 2)
Loading