Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_multi_turn.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ and include it in the init file `trinity/common/workflows/__init__.py`
```diff
# -*- coding: utf-8 -*-
"""Workflow module"""
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from trinity.common.workflows.envs.alfworld.alfworld_workflow import AlfworldWorkflow

__all__ = [
"WORKFLOWS",
Expand Down
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_step_wise.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ and include it in the init file `trinity/common/workflows/__init__.py`
```diff
# -*- coding: utf-8 -*-
"""Workflow module"""
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from .envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow
from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from trinity.common.workflows.envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow

__all__ = [
"WORKFLOWS",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ For workflows that are prepared to be contributed to Trinity-RFT project, you ne

```python
# existing import lines
from .example_workflow import ExampleWorkflow
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
# existing __all__ lines
Expand Down
5 changes: 5 additions & 0 deletions trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ def model_version(self) -> int:
"""Get the version of the model."""
return ray.get(self.model.get_model_version.remote())

@property
async def model_version_async(self) -> int:
    """Asynchronously fetch the version of the model.

    This is an ``async`` property, so reading the attribute yields a
    coroutine; obtain the value with ``await obj.model_version_async``.

    Returns:
        The result of awaiting ``self.model.get_model_version.remote()``.
    """
    # Issue the remote call first, then suspend until its result is ready.
    pending_version = self.model.get_model_version.remote()
    return await pending_version

def _get_api_server_address(self) -> str:
"""Get the address of the API server."""
if self.api_address:
Expand Down
12 changes: 6 additions & 6 deletions trinity/common/rewards/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
"""Reward functions for RFT"""

# isort: off
from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn

from .accuracy_reward import AccuracyReward
from .countdown_reward import CountDownRewardFn
from .dapo_reward import MathDAPORewardFn
from .format_reward import FormatReward
from .math_reward import MathBoxedRewardFn, MathRewardFn
from trinity.common.rewards.accuracy_reward import AccuracyReward
from trinity.common.rewards.countdown_reward import CountDownRewardFn
from trinity.common.rewards.dapo_reward import MathDAPORewardFn
from trinity.common.rewards.format_reward import FormatReward
from trinity.common.rewards.math_reward import MathBoxedRewardFn, MathRewardFn

# isort: on

Expand Down
73 changes: 57 additions & 16 deletions trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,85 @@
# -*- coding: utf-8 -*-
"""Workflow module"""
from .customized_math_workflows import MathBoxedWorkflow
from .customized_toolcall_workflows import ToolCallWorkflow
from .envs.agentscope.agentscopev0_react_workflow import ( # will be deprecated soon
from trinity.common.workflows.customized_math_workflows import (
AsyncMathBoxedWorkflow,
MathBoxedWorkflow,
)
from trinity.common.workflows.customized_toolcall_workflows import ToolCallWorkflow
from trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow import ( # will be deprecated soon
AgentScopeV0ReactMathWorkflow,
)
from .envs.agentscope.agentscopev1_react_workflow import AgentScopeReactMathWorkflow
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow, StepWiseAlfworldWorkflow
from .envs.alfworld.RAFT_alfworld_workflow import RAFTAlfworldWorkflow
from .envs.alfworld.RAFT_reflect_alfworld_workflow import RAFTReflectAlfworldWorkflow
from .envs.email_searcher.workflow import EmailSearchWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .eval_workflow import MathEvalWorkflow
from .math_rm_workflow import MathRMWorkflow
from .math_ruler_workflow import MathRULERWorkflow
from .math_trainable_ruler_workflow import MathTrainableRULERWorkflow
from .simple_mm_workflow import SimpleMMWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
from trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow import (
AgentScopeReactMathWorkflow,
)
from trinity.common.workflows.envs.alfworld.alfworld_workflow import (
AlfworldWorkflow,
StepWiseAlfworldWorkflow,
)
from trinity.common.workflows.envs.alfworld.RAFT_alfworld_workflow import (
RAFTAlfworldWorkflow,
)
from trinity.common.workflows.envs.alfworld.RAFT_reflect_alfworld_workflow import (
RAFTReflectAlfworldWorkflow,
)
from trinity.common.workflows.envs.email_searcher.workflow import EmailSearchWorkflow
from trinity.common.workflows.envs.sciworld.sciworld_workflow import SciWorldWorkflow
from trinity.common.workflows.envs.webshop.webshop_workflow import WebShopWorkflow
from trinity.common.workflows.eval_workflow import (
AsyncMathEvalWorkflow,
MathEvalWorkflow,
)
from trinity.common.workflows.math_rm_workflow import (
AsyncMathRMWorkflow,
MathRMWorkflow,
)
from trinity.common.workflows.math_ruler_workflow import (
AsyncMathRULERWorkflow,
MathRULERWorkflow,
)
from trinity.common.workflows.math_trainable_ruler_workflow import (
MathTrainableRULERWorkflow,
)
from trinity.common.workflows.simple_mm_workflow import (
AsyncSimpleMMWorkflow,
SimpleMMWorkflow,
)
from trinity.common.workflows.workflow import (
WORKFLOWS,
AsyncMathWorkflow,
AsyncSimpleWorkflow,
MathWorkflow,
SimpleWorkflow,
Task,
Workflow,
)

__all__ = [
"Task",
"Workflow",
"WORKFLOWS",
"AsyncSimpleWorkflow",
"SimpleWorkflow",
"AsyncMathWorkflow",
"MathWorkflow",
"WebShopWorkflow",
"AlfworldWorkflow",
"StepWiseAlfworldWorkflow",
"RAFTAlfworldWorkflow",
"RAFTReflectAlfworldWorkflow",
"SciWorldWorkflow",
"AsyncMathBoxedWorkflow",
"MathBoxedWorkflow",
"AsyncMathRMWorkflow",
"MathRMWorkflow",
"ToolCallWorkflow",
"AsyncMathEvalWorkflow",
"MathEvalWorkflow",
"AgentScopeV0ReactMathWorkflow", # will be deprecated soon
"AgentScopeReactMathWorkflow",
"EmailSearchWorkflow",
"AsyncMathRULERWorkflow",
"MathRULERWorkflow",
"MathTrainableRULERWorkflow",
"AsyncSimpleMMWorkflow",
"SimpleMMWorkflow",
]
45 changes: 45 additions & 0 deletions trinity/common/workflows/customized_math_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,48 @@ def run(self) -> List[Experience]:
f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
)
return responses


@WORKFLOWS.register_module("async_math_boxed_workflow")
class AsyncMathBoxedWorkflow(MathBoxedWorkflow):
    """Asynchronous variant of ``MathBoxedWorkflow``.

    Registered under ``"async_math_boxed_workflow"``. Generation goes through
    the model's ``chat_async`` / ``generate_async`` coroutines; the reward
    function is called directly (not awaited) for each response.
    """

    @property
    def asynchronous(self):
        """Report that this workflow is asynchronous (always ``True``)."""
        return True

    async def run_async(self) -> List[Experience]:
        """Generate responses for the current task and score each of them.

        When ``self.use_base`` is false the task is formatted as chat messages
        and sent through ``chat_async``; otherwise a plain prompt is sent
        through ``generate_async``.

        Returns:
            The responses produced by the model, each updated in place with
            its reward metrics, its total reward and its run id.
        """
        if self.use_base:
            prompt_text = self.format_prompt()
            self.logger.debug("start chat")
            responses = await self.model.generate_async([prompt_text], **self.rollout_args)
        else:
            messages = self.format_messages()
            self.logger.debug("start chat")
            responses = await self.model.chat_async(messages, **self.rollout_args)

        for run_offset, response in enumerate(responses):
            # Only the generated part of the token sequence is passed to the
            # reward function: everything after the prompt.
            generated_tokens = response.tokens[response.prompt_length :]
            reward_parts = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
                with_think=self.with_think,
                format_score_coef=self.format_score_coef,
                response_token=generated_tokens,
            )

            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(reward_parts)

            # The total reward is the sum of every component returned above.
            reward = sum(reward_parts.values())
            response.reward = reward
            response.eid.run = run_offset + self.run_id_base

            if self.use_base:
                self.logger.debug(
                    f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
                )
            else:
                self.logger.debug(
                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
                )
        return responses
8 changes: 6 additions & 2 deletions trinity/common/workflows/customized_toolcall_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,20 @@ def reset(self, task: Task):
self.workflow_args = task.workflow_args
self.reward_fn_args = task.reward_fn_args

@property
def asynchronous(self):
    """Report that this workflow is asynchronous (always ``True``).

    NOTE(review): presumably signals the caller to use ``run_async``
    instead of ``run`` -- confirm against the workflow runner.
    """
    return True

def format_prompt(self):
    """Build the prompt messages for the current task.

    Returns:
        Whatever ``construct_prompt`` produces for ``self.raw_task``.
    """
    return construct_prompt(self.raw_task)

def run(self) -> List[Experience]:
async def run_async(self) -> List[Experience]:
messages = self.format_prompt()

self.logger.debug("start chat")
responses = self.model.chat(messages, **self.rollout_args)
responses = await self.model.chat_async(messages, **self.rollout_args)

for i, response in enumerate(responses):
reward = 0.0
Expand Down
22 changes: 13 additions & 9 deletions trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def __init__(
)
self.reset(task)

@property
def asynchronous(self):
    """Report that this workflow is asynchronous (always ``True``).

    NOTE(review): presumably signals the caller to use ``run_async``
    instead of ``run`` -- confirm against the workflow runner.
    """
    return True

def reset(self, task: Task):
"""Reset the workflow with a new task"""
self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
Expand All @@ -64,7 +68,7 @@ def create_environment(self, game_file):
"""Create alfworld environment"""
return create_alfworld_environment(game_file)

def run_single_rollout(
async def run_single_rollout(
self, env
) -> tuple[List[Dict[str, str]], float, bool, int, List[Dict[str, str]]]:
"""Run a single rollout with RAFT-guided actions"""
Expand All @@ -82,7 +86,7 @@ def run_single_rollout(
trajectory.append({"role": "user", "content": format_observation(observation)})

# Get model response with RAFT guidance
responses = self.model.chat(
responses = await self.model.chat_async(
trajectory,
n=1,
temperature=self.temperature,
Expand Down Expand Up @@ -134,12 +138,12 @@ def run_single_rollout(
# If timeout, return the last reward from environment instead of fixed value
return trajectory, last_reward, False, self.max_env_steps, parsed_steps

def _execute_first_attempt(self) -> tuple:
async def _execute_first_attempt(self) -> tuple:
"""Execute the first attempt and return results"""
env = self.create_environment(self.game_file_path)

try:
trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env)
trajectory, reward, done, steps, parsed_steps = await self.run_single_rollout(env)
except Exception as e:
print(f"Single rollout failed: {e}")
env.close()
Expand All @@ -151,11 +155,11 @@ def _execute_first_attempt(self) -> tuple:

return trajectory, reward, done, steps, parsed_steps, success, traj_format_valid

def eval_alfworld(self) -> List[Experience]:
async def eval_alfworld(self) -> List[Experience]:
"""Evaluate a single alfworld trajectory"""
env = self.create_environment(self.game_file_path)
try:
trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env)
trajectory, reward, done, steps, parsed_steps = await self.run_single_rollout(env)
except Exception as e:
print(f"Single rollout failed during eval: {e}")
env.close()
Expand All @@ -175,11 +179,11 @@ def eval_alfworld(self) -> List[Experience]:

return [experience]

def run(self) -> List[Experience]:
async def run_async(self) -> List[Experience]:
"""Run the RAFT alfworld workflow and return experiences"""

if self.is_eval:
return self.eval_alfworld()
return await self.eval_alfworld()

# Execute first attempt
try:
Expand All @@ -191,7 +195,7 @@ def run(self) -> List[Experience]:
parsed_steps,
success,
traj_format_valid,
) = self._execute_first_attempt()
) = await self._execute_first_attempt()
except Exception as e:
return [generate_default_empty_experience(f"Training rollout failed: {str(e)}")]

Expand Down
Loading