diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md index 6cc28690e6..aa20529e4f 100644 --- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md +++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md @@ -139,8 +139,8 @@ and include it in the init file `trinity/common/workflows/__init__.py` ```diff # -*- coding: utf-8 -*- """Workflow module""" - from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow -+from .envs.alfworld.alfworld_workflow import AlfworldWorkflow + from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow ++from trinity.common.workflows.envs.alfworld.alfworld_workflow import AlfworldWorkflow __all__ = [ "WORKFLOWS", diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md index 73d6042de2..3dd9706802 100644 --- a/docs/sphinx_doc/source/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md @@ -62,8 +62,8 @@ and include it in the init file `trinity/common/workflows/__init__.py` ```diff # -*- coding: utf-8 -*- """Workflow module""" - from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow -+from .envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow + from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow ++from trinity.common.workflows.envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow __all__ = [ "WORKFLOWS", diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index c87ec8779f..6e98df5b46 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -212,7 +212,7 @@ For workflows that are prepared to be contributed to Trinity-RFT project, you ne ```python # existing import lines -from .example_workflow import ExampleWorkflow +from trinity.common.workflows.example_workflow import ExampleWorkflow __all__ = [ # existing __all__ lines diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 82ac0c898d..be0a5ed2fe 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -167,6 +167,11 @@ def model_version(self) -> int: """Get the version of the model.""" return ray.get(self.model.get_model_version.remote()) + @property + async def model_version_async(self) -> int: + """Get the version of the model.""" + return await self.model.get_model_version.remote() + def _get_api_server_address(self) -> str: """Get the address of the API server.""" if self.api_address: diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py index c723788908..ad36b8103a 100644 --- a/trinity/common/rewards/__init__.py +++ b/trinity/common/rewards/__init__.py @@ -2,13 +2,13 @@ """Reward functions for RFT""" # isort: off -from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn -from .accuracy_reward import AccuracyReward -from .countdown_reward import CountDownRewardFn -from .dapo_reward import MathDAPORewardFn -from .format_reward import FormatReward -from .math_reward import MathBoxedRewardFn, MathRewardFn +from trinity.common.rewards.accuracy_reward import AccuracyReward +from trinity.common.rewards.countdown_reward import CountDownRewardFn +from trinity.common.rewards.dapo_reward import MathDAPORewardFn +from trinity.common.rewards.format_reward import FormatReward +from trinity.common.rewards.math_reward import MathBoxedRewardFn, MathRewardFn # isort: on diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index fa80b850b0..54342fea24 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -1,29 +1,65 @@ # -*- coding: utf-8 -*- """Workflow module""" -from .customized_math_workflows import MathBoxedWorkflow -from .customized_toolcall_workflows import ToolCallWorkflow -from .envs.agentscope.agentscopev0_react_workflow import ( # will be deprecated soon +from trinity.common.workflows.customized_math_workflows import ( + AsyncMathBoxedWorkflow, + MathBoxedWorkflow, +) +from trinity.common.workflows.customized_toolcall_workflows import ToolCallWorkflow +from trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow import ( # will be deprecated soon AgentScopeV0ReactMathWorkflow, ) -from .envs.agentscope.agentscopev1_react_workflow import AgentScopeReactMathWorkflow -from .envs.alfworld.alfworld_workflow import AlfworldWorkflow, StepWiseAlfworldWorkflow -from .envs.alfworld.RAFT_alfworld_workflow import RAFTAlfworldWorkflow -from .envs.alfworld.RAFT_reflect_alfworld_workflow import RAFTReflectAlfworldWorkflow -from .envs.email_searcher.workflow import EmailSearchWorkflow -from .envs.sciworld.sciworld_workflow import SciWorldWorkflow -from .envs.webshop.webshop_workflow import WebShopWorkflow -from .eval_workflow import MathEvalWorkflow -from .math_rm_workflow import MathRMWorkflow -from .math_ruler_workflow import MathRULERWorkflow -from .math_trainable_ruler_workflow import MathTrainableRULERWorkflow -from .simple_mm_workflow import SimpleMMWorkflow -from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow +from trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow import ( + AgentScopeReactMathWorkflow, +) +from trinity.common.workflows.envs.alfworld.alfworld_workflow import ( + AlfworldWorkflow, + StepWiseAlfworldWorkflow, +) +from trinity.common.workflows.envs.alfworld.RAFT_alfworld_workflow import ( + RAFTAlfworldWorkflow, +) +from trinity.common.workflows.envs.alfworld.RAFT_reflect_alfworld_workflow import ( + RAFTReflectAlfworldWorkflow, +) +from trinity.common.workflows.envs.email_searcher.workflow import EmailSearchWorkflow +from trinity.common.workflows.envs.sciworld.sciworld_workflow import SciWorldWorkflow +from trinity.common.workflows.envs.webshop.webshop_workflow import WebShopWorkflow +from trinity.common.workflows.eval_workflow import ( + AsyncMathEvalWorkflow, + MathEvalWorkflow, +) +from trinity.common.workflows.math_rm_workflow import ( + AsyncMathRMWorkflow, + MathRMWorkflow, +) +from trinity.common.workflows.math_ruler_workflow import ( + AsyncMathRULERWorkflow, + MathRULERWorkflow, +) +from trinity.common.workflows.math_trainable_ruler_workflow import ( + MathTrainableRULERWorkflow, +) +from trinity.common.workflows.simple_mm_workflow import ( + AsyncSimpleMMWorkflow, + SimpleMMWorkflow, +) +from trinity.common.workflows.workflow import ( + WORKFLOWS, + AsyncMathWorkflow, + AsyncSimpleWorkflow, + MathWorkflow, + SimpleWorkflow, + Task, + Workflow, +) __all__ = [ "Task", "Workflow", "WORKFLOWS", + "AsyncSimpleWorkflow", "SimpleWorkflow", + "AsyncMathWorkflow", "MathWorkflow", "WebShopWorkflow", "AlfworldWorkflow", @@ -31,14 +67,19 @@ "RAFTAlfworldWorkflow", "RAFTReflectAlfworldWorkflow", "SciWorldWorkflow", + "AsyncMathBoxedWorkflow", "MathBoxedWorkflow", + "AsyncMathRMWorkflow", "MathRMWorkflow", "ToolCallWorkflow", + "AsyncMathEvalWorkflow", "MathEvalWorkflow", "AgentScopeV0ReactMathWorkflow", # will be deprecated soon "AgentScopeReactMathWorkflow", "EmailSearchWorkflow", + "AsyncMathRULERWorkflow", "MathRULERWorkflow", "MathTrainableRULERWorkflow", + "AsyncSimpleMMWorkflow", "SimpleMMWorkflow", ] diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index 4fdfbf2a21..f76983a324 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -91,3 +91,48 @@ def run(self) -> List[Experience]: f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}" ) return responses + + +@WORKFLOWS.register_module("async_math_boxed_workflow") +class AsyncMathBoxedWorkflow(MathBoxedWorkflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + if not self.use_base: + messages = self.format_messages() + else: + prompt_text = self.format_prompt() + + self.logger.debug("start chat") + if not self.use_base: + responses = await self.model.chat_async(messages, **self.rollout_args) + else: + responses = await self.model.generate_async([prompt_text], **self.rollout_args) + + for i, response in enumerate(responses): + reward_dict = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + truth=self.truth, + with_think=self.with_think, + format_score_coef=self.format_score_coef, + response_token=response.tokens[response.prompt_length :], + ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + response.eid.run = i + self.run_id_base + + if not self.use_base: + self.logger.debug( + f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" + ) + else: + self.logger.debug( + f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}" + ) + return responses diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index ed5dc98771..1e31e492dd 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -227,16 +227,20 @@ def reset(self, task: Task): self.workflow_args = task.workflow_args self.reward_fn_args = task.reward_fn_args + @property + def asynchronous(self): + return True + def format_prompt(self): raw_task = self.raw_task messages = construct_prompt(raw_task) return messages - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: messages = self.format_prompt() self.logger.debug("start chat") - responses = self.model.chat(messages, **self.rollout_args) + responses = await self.model.chat_async(messages, **self.rollout_args) for i, response in enumerate(responses): reward = 0.0 diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index 5b3cb8d786..34235d238f 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -55,6 +55,10 @@ def __init__( ) self.reset(task) + @property + def asynchronous(self): + return True + def reset(self, task: Task): """Reset the workflow with a new task""" self.game_file_path = task.task_desc or task.raw_task.get("game_file", "") @@ -64,7 +68,7 @@ def create_environment(self, game_file): """Create alfworld environment""" return create_alfworld_environment(game_file) - def run_single_rollout( + async def run_single_rollout( self, env ) -> tuple[List[Dict[str, str]], float, bool, int, List[Dict[str, str]]]: """Run a single rollout with RAFT-guided actions""" @@ -82,7 +86,7 @@ def run_single_rollout( trajectory.append({"role": "user", "content": format_observation(observation)}) # Get model response with RAFT guidance - responses = self.model.chat( + responses = await self.model.chat_async( trajectory, n=1, temperature=self.temperature, @@ -134,12 +138,12 @@ def run_single_rollout( # If timeout, return the last reward from environment instead of fixed value return trajectory, last_reward, False, self.max_env_steps, parsed_steps - def _execute_first_attempt(self) -> tuple: + async def _execute_first_attempt(self) -> tuple: """Execute the first attempt and return results""" env = self.create_environment(self.game_file_path) try: - trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env) + trajectory, reward, done, steps, parsed_steps = await self.run_single_rollout(env) except Exception as e: print(f"Single rollout failed: {e}") env.close() @@ -151,11 +155,11 @@ def _execute_first_attempt(self) -> tuple: return trajectory, reward, done, steps, parsed_steps, success, traj_format_valid - def eval_alfworld(self) -> List[Experience]: + async def eval_alfworld(self) -> List[Experience]: """Evaluate a single alfworld trajectory""" env = self.create_environment(self.game_file_path) try: - trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env) + trajectory, reward, done, steps, parsed_steps = await self.run_single_rollout(env) except Exception as e: print(f"Single rollout failed during eval: {e}") env.close() @@ -175,11 +179,11 @@ def eval_alfworld(self) -> List[Experience]: return [experience] - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: """Run the RAFT alfworld workflow and return experiences""" if self.is_eval: - return self.eval_alfworld() + return await self.eval_alfworld() # Execute first attempt try: @@ -191,7 +195,7 @@ def run(self) -> List[Experience]: parsed_steps, success, traj_format_valid, - ) = self._execute_first_attempt() + ) = await self._execute_first_attempt() except Exception as e: return [generate_default_empty_experience(f"Training rollout failed: {str(e)}")] diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index 8c7ff101a8..bf8ff85b09 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import os from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -59,7 +59,11 @@ def __init__( f"Initializing RAFTReflectAlfworldWorkflow with RAFT learning, temperature={self.temperature}" ) - def construct_sft_data( + @property + def asynchronous(self): + return True + + async def construct_sft_data( self, first_trajectory: List[Dict[str, str]], success: bool, @@ -75,7 +79,7 @@ def construct_sft_data( new_success, new_steps, new_parsed_steps, - ) = self.re_explore_with_context(first_trajectory, reward, success, original_steps) + ) = await self.re_explore_with_context(first_trajectory, reward, success, original_steps) # Consider improvement if reward is higher OR same reward with fewer steps reward_improved = new_reward > reward @@ -92,7 +96,7 @@ def construct_sft_data( new_parsed_steps, ) - def re_explore_with_context( + async def re_explore_with_context( self, first_trajectory: List[Dict[str, str]], original_reward: float, @@ -130,7 +134,7 @@ def re_explore_with_context( # Add to clean SFT trajectory sft_trajectory.append({"role": "user", "content": format_observation(observation)}) - responses = self.model.chat( + responses = await self.model.chat_async( context_messages, n=1, temperature=self.temperature, @@ -183,18 +187,18 @@ def _handle_invalid_format_success( ) return [experience] - def _execute_second_attempt( + async def _execute_second_attempt( self, trajectory: list, success: bool, reward: float, steps: int - ) -> tuple: + ) -> Union[tuple[List[Dict[str, str]], Dict[str, Any], List[Dict[str, str]],], Exception]: """Execute second attempt and return SFT data""" try: - sft_messages, re_explore_info, new_parsed_steps = self.construct_sft_data( + sft_messages, re_explore_info, new_parsed_steps = await self.construct_sft_data( trajectory, success, reward, steps ) - return sft_messages, re_explore_info, new_parsed_steps, None + return sft_messages, re_explore_info, new_parsed_steps except Exception as e: print(f"SFT data construction failed: {e}") - return None, None, None, e + return e def _build_metrics( self, reward: float, steps: int, new_parsed_steps: list, re_explore_info: dict @@ -220,11 +224,11 @@ def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Ex """Generate experience from SFT messages""" return process_messages_to_experience(self.model, sft_messages, info=metrics) - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: """Run the RAFT alfworld workflow and return experiences""" if self.is_eval: - return self.eval_alfworld() + return await self.eval_alfworld() # Generate unique task ID using timestamp task_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") @@ -239,7 +243,7 @@ def run(self) -> List[Experience]: parsed_steps, success, traj_format_valid, - ) = self._execute_first_attempt() + ) = await self._execute_first_attempt() except Exception as e: return [generate_default_empty_experience(f"Training rollout failed: {str(e)}")] @@ -256,14 +260,15 @@ def run(self) -> List[Experience]: print(f"Task result: done={done}, reward={reward:.3f}, steps={steps}, success={success}") # Execute second attempt - sft_messages, re_explore_info, new_parsed_steps, error = self._execute_second_attempt( - trajectory, success, reward, steps - ) - if error: + ret = await self._execute_second_attempt(trajectory, success, reward, steps) + if isinstance(ret, Exception): + error = ret default_experience = generate_default_empty_experience( f"SFT data construction failed: {str(error)}", ) return [default_experience] + else: + sft_messages, re_explore_info, new_parsed_steps = ret # Validate second attempt and build metrics second_success = re_explore_info["new_reward"] >= 1 diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 97ebd8238d..36370b55f1 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -111,14 +111,18 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = task.workflow_args.get("max_env_steps", 30) - def get_model_response(self, messages): - responses = self.model.chat(messages, n=1) + @property + def asynchronous(self): + return True + + async def get_model_response(self, messages): + responses = await self.model.chat_async(messages, n=1) return responses - def get_model_response_text(self, messages): - return self.get_model_response(messages)[0].response_text + async def get_model_response_text(self, messages): + return (await self.get_model_response(messages))[0].response_text - def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: + async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: # TODO: Make this parallel print("Generating env inference samples...") experience_list = [] @@ -130,7 +134,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: for r in range(self.max_env_steps): format_obs = format_observation(observation) memory = memory + [{"role": "user", "content": format_obs}] - response_text = self.get_model_response_text(memory) + response_text = await self.get_model_response_text(memory) memory.append({"role": "assistant", "content": response_text}) action = parse_action(response_text) observation, reward, done, info = env.step(action) @@ -145,7 +149,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: env.close() return experience_list - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: # assume the task_description is the game_file_path generated. # see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py game_file_path = self.task_desc @@ -177,7 +181,7 @@ def create_environment(game_file): error_message = f"Error importing AlfworldTWEnv {str(e)}. Please make sure you have installed the alfworld package successfully, following the instructions in https://github.com/alfworld/alfworld" raise ImportError(error_message) env = create_environment(game_file_path) - return self.generate_env_inference_samples(env, rollout_n) + return await self.generate_env_inference_samples(env, rollout_n) @WORKFLOWS.register_module("step_wise_alfworld_workflow") diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 30e1eb76d4..9262da7c62 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -5,11 +5,15 @@ import openai from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.envs.email_searcher.react_agent import EmailSearchAgent +from trinity.common.workflows.envs.email_searcher.utils import ( + AnswerModel, + FinalRubric, + QueryModel, + judge_correctness, +) from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow -from .react_agent import EmailSearchAgent -from .utils import AnswerModel, FinalRubric, QueryModel, judge_correctness - SYSTEM_PROMPT = """You are an email search agent. You are given a user query and a list of tools you can use to search the user's email. Use the tools to search the user's emails and find the answer to the user's query. You may take up to {max_turns} turns to find the answer, so if your first seach doesn't find the answer, you can try with different keywords. Always describe what you see and plan your next steps clearly. When taking actions, explain what you're doing and why. When the answer to the task is found, call `generate_response` to finish the process. Only call `generate_response` when answer is found. You should not respond any next steps in `generate_response`. Complete all steps and then call `generate_response`. diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index c851ae465a..da313e72ca 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -73,14 +73,18 @@ def __init__( self.repeat_times = task.repeat_times self.max_env_steps = 30 # should be less than 100 - def get_model_response(self, messages): - responses = self.model.chat(messages, n=1) + @property + def asynchronous(self): + return True + + async def get_model_response(self, messages): + responses = await self.model.chat_async(messages, n=1) return responses - def get_model_response_text(self, messages): - return self.get_model_response(messages)[0].response_text + async def get_model_response_text(self, messages): + return (await self.get_model_response(messages))[0].response_text - def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: + async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: # TODO: Make this parallel print("Generating env inference samples...") golden_rounds = len(env.get_gold_action_sequence()) @@ -97,7 +101,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: for r in range(self.max_env_steps): format_obs = format_observation(observation) memory = memory + [{"role": "user", "content": format_obs}] - response_text = self.get_model_response_text(memory) + response_text = await self.get_model_response_text(memory) memory.append({"role": "assistant", "content": response_text}) action = parse_action(response_text) observation, reward, done, info = env.step(action) @@ -116,7 +120,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: env.close() return experience_list - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: # assume the task_description is the json object containing task index and the var_num # see Trinity-RFT/script/data_prepare/get_scriworld_data.py task_desc = self.task_desc @@ -141,4 +145,4 @@ def create_environment(task_config): error_message = f"Error importing SciWorldTWEnv {str(e)}. Please make sure you have installed the sciworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld" raise ImportError(error_message) env = create_environment(task_config) - return self.generate_env_inference_samples(env, rollout_n) + return await self.generate_env_inference_samples(env, rollout_n) diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index eea564fcad..5b18ee48ae 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -217,14 +217,16 @@ def reset(self, task: Task): self.task_desc = task.task_desc or "0" self.repeat_times = task.repeat_times - def get_model_response(self, messages): - responses = self.model.chat(messages, n=1) + async def get_model_response(self, messages): + responses = await self.model.chat_async(messages, n=1) return responses - def get_model_response_text(self, messages): - return self.get_model_response(messages)[0].response_text + async def get_model_response_text(self, messages): + return (await self.get_model_response(messages))[0].response_text - def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[Experience]: + async def generate_env_inference_samples( + self, env, session_id, rollout_num + ) -> List[Experience]: # TODO: Make this parallel print("Generating env inference samples...") experience_list = [] @@ -238,7 +240,7 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E available_actions = env.get_available_actions() format_obs = format_observation(observation) memory = memory + [{"role": "user", "content": format_obs}] - response_text = self.get_model_response_text(memory) + response_text = await self.get_model_response_text(memory) memory.append({"role": "assistant", "content": response_text}) action = parse_action(response_text) action_valid, error_msg = validate_action(action, available_actions) @@ -266,8 +268,8 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E experience_list.append(experience) return experience_list - def run(self) -> List[Experience]: + async def run_async(self) -> List[Experience]: # assume the task_description is the session_id generated. session_id = int(self.task_desc) rollout_n = self.repeat_times - return self.generate_env_inference_samples(self.env, session_id, rollout_n) + return await self.generate_env_inference_samples(self.env, session_id, rollout_n) diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 2d97a5325b..03348d6bc2 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -85,3 +85,30 @@ def run(self) -> List[Experience]: response.metrics.update(acc_metrics) return responses + + +@WORKFLOWS.register_module("async_math_eval_workflow") +class AsyncMathEvalWorkflow(MathEvalWorkflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + messages = self.format_messages() + + responses: List[Experience] = await self.model.chat_async(messages, **self.eval_gen_args) + + for response in responses: + if response.response_text is None or self.task.truth is None: + continue + + accuracy, _ = verify_math_answer( + response_text=response.response_text, ground_truth=self.task.truth + ) + + acc_metrics = {"accuracy": accuracy} + if response.metrics is None: + response.metrics = {} + response.metrics.update(acc_metrics) + + return responses diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index b46f5b7713..f498acf2ed 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -51,3 +51,34 @@ def run(self) -> List[Experience]: f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) return responses + + +@WORKFLOWS.register_module("async_math_rm_workflow") +class AsyncMathRMWorkflow(MathRMWorkflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + messages = self.format_messages() + + self.logger.debug("start chat") + responses = await self.model.chat_async(messages, **self.rollout_args) + for i, response in enumerate(responses): + reward_dict = self.reward_fn( # type: ignore + response, + messages, + ground_truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + response.eid.run = i + self.run_id_base + + self.logger.debug( + f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" + ) + return responses diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py index 21bdebdb1c..9eb1462652 100644 --- a/trinity/common/workflows/math_ruler_workflow.py +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -156,3 +156,49 @@ def get_ruler_scores( "Unable to parse the list in judger response, set scores to all zero." ) return False, [0.0 for _ in range(num_responses)] + + +@WORKFLOWS.register_module("async_math_ruler_workflow") +class AsyncMathRULERWorkflow(MathRULERWorkflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + """Modified from SimpleWorkflow.run""" + + messages = self.format_messages() + + self.logger.debug("start chat") + responses = await self.model.chat_async(messages, **self.rollout_args) + + for i, response in enumerate(responses): + gold_reward_dict = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + + response.metrics.update(gold_reward_dict) + gold_reward = sum(gold_reward_dict.values()) + response.metrics.update({"gold_reward": gold_reward}) + response.eid.run = i + self.run_id_base + + self.logger.debug( + f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}" + ) + + # === RULER scores as rewards === + assert ( + self.auxiliary_models is not None + ), "Current implementation of RULER requires that auxiliary_models is not None." + judge_success, ruler_scores = self.get_ruler_scores( + responses=responses, judger=self.auxiliary_models[0] + ) + for i, response in enumerate(responses): + response.reward = ruler_scores[i] + response.metrics.update({"judge_success": float(judge_success)}) + + return responses diff --git a/trinity/common/workflows/simple_mm_workflow.py b/trinity/common/workflows/simple_mm_workflow.py index 95ea213ebb..bb8fc50c09 100644 --- a/trinity/common/workflows/simple_mm_workflow.py +++ b/trinity/common/workflows/simple_mm_workflow.py @@ -74,3 +74,37 @@ def run(self) -> List[Experience]: self.logger.debug(f"Generated {len(responses)} responses") return responses + + +@WORKFLOWS.register_module("async_simple_mm_workflow") +class AsyncSimpleMMWorkflow(SimpleMMWorkflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + messages = self.format_messages() + + # TODO: test generate_mm + self.logger.debug("start chat") + if self.raw_mm_data: + responses = await self.model.chat_mm_async( + messages, self.raw_mm_data, **self.rollout_args + ) + else: + responses = await self.model.chat_async(messages, **self.rollout_args) + for i, response in enumerate(responses): + reward_dict = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + response.eid.run = i + self.run_id_base + + self.logger.debug(f"Generated {len(responses)} responses") + return responses diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py index 20dd294a21..23ae76fbf0 100644 --- a/trinity/common/workflows/step_wise_workflow.py +++ b/trinity/common/workflows/step_wise_workflow.py @@ -76,6 +76,56 @@ def repeatable(self): return False +class AsyncStepWiseRewardWorkflow(StepWiseRewardWorkflow): + """Async version of `StepWiseRewardWorkflow`.""" + + @property + def asynchronous(self): + return True + + async def run_async(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" + experiences = [] + for step in range(self.max_step_num): + # Run a single step of the agent application + continue_run = await self.step_async(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # Calculate the reward for the current step + reward = await self.reward_async(exps, step_num=step) + for exp in exps: + exp.reward = reward + # set the step number in each experience + exp.eid.step = step + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + + return experiences + + @abstractmethod + async def step_async(self, step_num: int) -> bool: + """Run a single step of your agent application asynchronously. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. + + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. + """ + pass + + @abstractmethod + async def reward_async(self, exps: list[Experience], step_num: int) -> float: + """Calculate the reward for the given experiences at the specified step asynchronously.""" + pass + + class RewardPropagationWorkflow(Workflow): """A workflow that propagates rewards across multiple turns.""" @@ -145,3 +195,55 @@ def max_step_num(self): @property def repeatable(self): return False + + +class AsyncRewardPropagationWorkflow(RewardPropagationWorkflow): + """Async version of `RewardPropagationWorkflow`.""" + + @property + def asynchronous(self): + return True + + async def run_async(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards asynchronously.""" + experiences = [] + for step in range(self.max_step_num): + # Run a single step of the agent application + continue_run = await self.step_async(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # set the step number in each experience + for exp in exps: + exp.eid.step = step + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + reward = self.reward(experiences) + for exp in experiences: + exp.reward = reward + if exp.metrics is None: + exp.metrics = {} + exp.metrics["actual_env_steps"] = step + 1 # +1 because step starts from 0 + return experiences + + @abstractmethod + async def step_async(self, step_num: int) -> bool: + """Run a single step of your agent application asynchronously. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. + + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. + """ + pass + + @abstractmethod + async def reward_async(self, exps: list[Experience]) -> float: + """Calculate the reward for the given experiences of the entire run asynchronously.""" + pass diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 7bb4d184d5..b5067d5bd8 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -265,6 +265,37 @@ def run(self) -> List[Experience]: return responses +@WORKFLOWS.register_module("async_simple_workflow") +class AsyncSimpleWorkflow(Workflow): + @property + def asynchronous(self): + return True + + async def run_async(self) -> List[Experience]: + # TODO: Optimize the generate function + messages = self.format_messages() + + self.logger.debug("start chat") + responses = await self.model.chat_async(messages, **self.rollout_args) + for i, response in enumerate(responses): + reward_dict = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + response.eid.run = i + self.run_id_base + + self.logger.debug( + f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" + ) + return responses + + @WORKFLOWS.register_module("math_workflow") class MathWorkflow(SimpleWorkflow): """A workflow for math tasks as introduced in DeepSeek-R1.""" @@ -293,3 +324,8 @@ def reset(self, task: Task): """ # call the SimpleWorkflow.reset super().reset(task) + + +@WORKFLOWS.register_module("async_math_workflow") +class AsyncMathWorkflow(AsyncSimpleWorkflow, MathWorkflow): + pass diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index b6bb628b6a..4a6ed67e30 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -103,6 +103,7 @@ async def run_task( exps = await self._run_task(task, repeat_times, run_id_base) assert exps is not None and len(exps) > 0, "An empty experience is generated" metrics: dict[str, List[float]] = defaultdict(list) + model_version = await self.model_wrapper.model_version_async # set eid for each experience for i, exp in enumerate(exps): exp.eid.batch = task.batch_id @@ -111,7 +112,7 @@ async def run_task( exp.eid.task = task.task_id if not hasattr(exp, "info") or exp.info is None: exp.info = {} - exp.info["model_version"] = self.model_wrapper.model_version + exp.info["model_version"] = model_version exp.info["use_count"] = 0 if not hasattr(exp, "metrics") or exp.metrics is None: diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 3242e39143..ddeb5cf72d 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -299,7 +299,10 @@ def _build_model_optimizer( # noqa: C901 _apply_liger_kernel_to_instance, ) - _apply_liger_kernel_to_instance(model=actor_module) + fused_linear_cross_entropy = not use_fused_kernels + _apply_liger_kernel_to_instance( + model=actor_module, fused_linear_cross_entropy=fused_linear_cross_entropy + ) fused_kernel_options = self.config.model.get("fused_kernel_options", None) fused_kernels_backend = ( diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2eb09dca4d..4f6b33c3de 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -14,7 +14,6 @@ from verl.trainer.ppo.metric_utils import ( compute_throughout_metrics, compute_timing_metrics, - reduce_metrics, ) from verl.trainer.ppo.ray_trainer import ( RayClassWithInitArgs, @@ -27,6 +26,7 @@ from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.debug import marked_timer from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.metric import reduce_metrics from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY from trinity.algorithm.algorithm import ALGORITHM_TYPE