diff --git a/docs/sphinx_doc/assets/agentscope_websearch_reward.png b/docs/sphinx_doc/assets/agentscope_websearch_reward.png new file mode 100644 index 0000000000..11ec30dee5 Binary files /dev/null and b/docs/sphinx_doc/assets/agentscope_websearch_reward.png differ diff --git a/docs/sphinx_doc/assets/agentscope_websearch_turns.png b/docs/sphinx_doc/assets/agentscope_websearch_turns.png new file mode 100644 index 0000000000..0d3bff1ea3 Binary files /dev/null and b/docs/sphinx_doc/assets/agentscope_websearch_turns.png differ diff --git a/examples/agentscope_websearch/README.md b/examples/agentscope_websearch/README.md new file mode 100644 index 0000000000..4e5c9128db --- /dev/null +++ b/examples/agentscope_websearch/README.md @@ -0,0 +1,68 @@ +# Example of Training a Multi-Turn Web Search Agent + +This example demonstrates how to train a multi-turn web search agent based on the ReAct (Reasoning and Acting) paradigm. +It uses the **AgentScope** framework, integrated into a Trinity workflow, to equip an agent with external search tools so that it can find information on the web and answer questions. + +We use a subset of the `WebWalkerQA` dataset here. The original dataset can be found at [Hugging Face Datasets](https://huggingface.co/datasets/callanwu/WebWalkerQA). + +The config file is located in [`agentscopev1_websearch_agent.yaml`](agentscopev1_websearch_agent.yaml). + +## Key Features + +* **Multi-Turn ReAct Agent Training**: The workflow trains a `ReActAgent` from AgentScope that can reason and act over multiple steps. +* **External Tool Integration**: The agent connects to web search tools via AgentScope's Model Context Protocol (MCP) clients. It supports: + * **Tavily Search** (`tavily`) + * **SearXNG** (`searxng`) +* **LLM-based Evaluation**: The agent's final answer is evaluated by a "judge" LLM against the ground-truth answer to produce the reward signal for training. The judge is served by the auxiliary model configured under `auxiliary_models`. +* **Asynchronous Execution**: The workflow runs asynchronously, improving rollout throughput. + +## Prerequisites + +Before running this workflow, please complete the following setup steps. + +1. **Install Dependencies** + + Install the core AgentScope framework. + ```bash + pip install agentscope + ``` + > **Note**: The required MCP clients (`tavily-mcp` and `mcp-searxng`) are installed automatically via `npx` on the first run, so no manual installation is needed for them. + +2. **Configure Environment Variables** + + Set the environment variables for the search tool you plan to use. + + * For **Tavily Search**, set your API key: + ```bash + export TAVILY_API_KEY="your_tavily_api_key" + ``` + * For **SearXNG**, set the URL of your self-hosted instance: + ```bash + export SEARXNG_URL="http://your-searxng-instance.com" + ``` + +3. **Generate the Dataset** + + Run the following script to generate the training and test datasets. + ```bash + python examples/agentscope_websearch/get_webwalkerQA_data.py + ``` + * **(Optional) Filter the Dataset**: For a more focused evaluation, you can filter the dataset by difficulty. For example, you might want to remove samples that cannot be answered even by more capable models, so that you benchmark performance on a more consistent set of problems; a minimal filtering sketch is shown below.
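The generated records keep only a `problem` and an `answer` field, so difficulty-based filtering is easiest to do inside `get_webwalkerQA_data.py` (where `info["difficulty_level"]` is still available). Alternatively, you can post-process the generated `train.jsonl`, as in the minimal sketch below; the `is_too_hard` predicate is a placeholder for whatever criterion you choose and is not part of this example's code.

```python
# Hypothetical post-processing sketch for the generated train.jsonl.
# `is_too_hard` is a placeholder predicate; replace it with your own check,
# e.g. membership in a pre-computed list of questions that a strong model
# failed to answer.
import json

DATA_PATH = "examples/agentscope_websearch/webwalker_rl_dataset/train.jsonl"


def is_too_hard(record: dict) -> bool:
    # Placeholder: keep everything by default.
    return False


with open(DATA_PATH, encoding="utf-8") as f:
    records = [json.loads(line) for line in f]  # each record has "problem" and "answer"

kept = [r for r in records if not is_too_hard(r)]

with open(DATA_PATH, "w", encoding="utf-8") as f:
    for r in kept:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")
```

If you shrink the splits this way, also regenerate or update `dataset_dict.json` so the recorded `num_examples` values stay consistent with the files.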
+## Configuration + +All workflow parameters can be configured in the [`agentscopev1_websearch_agent.yaml`](agentscopev1_websearch_agent.yaml) file. Key options under `workflow_args` include: + +* `search_client_type`: The search tool to use. Must be either `"tavily"` or `"searxng"`. If you want to use other search tools, you will need to modify `trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py`. +* `max_turns`: The maximum number of reasoning/acting steps the agent can take. + + +## Results +Below are the training curves from a run using the `tavily` search tool. +Training takes around 8 hours on 8 H20 GPUs. + +Reward curve: +![](../../docs/sphinx_doc/assets/agentscope_websearch_reward.png) + +Number of ReAct agent turns: +![](../../docs/sphinx_doc/assets/agentscope_websearch_turns.png) diff --git a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml new file mode 100644 index 0000000000..5c4fb5cc07 --- /dev/null +++ b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml @@ -0,0 +1,101 @@ +project: "Trinity_Multi_Step" +name: WebQA_Search_Example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + advantage_fn: grpo +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_response_tokens: 4096 + max_model_len: 20480 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 8 + batch_size: 16 + train_batch_size: 512 # 16*8*4 + explorer_input: + taskset: + name: webqa_train + storage_type: file + path: 'examples/agentscope_websearch/webwalker_rl_dataset' + split: train + format: + prompt_key: 'problem' + response_key: 'answer' + workflow_args: + max_turns: 10 + search_client_type: tavily + rollout_args: + temperature: 1.0 + max_tokens: 4096 + enable_progress_bar: false + eval_tasksets: + - name: webqa_test + storage_type: file + path: 'examples/agentscope_websearch/webwalker_rl_dataset' + split: test + format: + prompt_key: 'problem' + response_key: 'answer' + enable_progress_bar: false + workflow_args: + max_turns: 10 + rollout_args: + temperature: 0.6 + max_tokens: 4096 + default_workflow_type: 'agentscope_v1_react_search_workflow' + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + use_priority_queue: true +explorer: + eval_interval: 10 + max_repeat_times_per_runner: 1 + max_timeout: 3600 + rollout_model: + enable_thinking: true + enable_history: true + enable_openai_api: true + enable_auto_tool_choice: true + tool_call_parser: hermes + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + gpu_memory_utilization: 0.7 + enable_chunked_prefill: true + auxiliary_models: + - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507} + engine_num: 1 + tensor_parallel_size: 2 + enable_thinking: false + max_prompt_tokens: 20480 + max_response_tokens: 10240 + max_model_len: 32000 +synchronizer: + sync_style: dynamic_by_explorer + sync_method: 'nccl' + sync_interval: 5 + sync_timeout: 3600 +trainer: + save_interval: 20 + trainer_config: + actor_rollout_ref: + model: + use_remove_padding: true + actor: + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 2 + optim: + lr: 1e-6 + ref: + log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
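For reference, an end-to-end run of this example, combining the setup steps from the README with the config above, might look like the following sketch. The `trinity run --config ...` launch command is an assumption based on the standard Trinity-RFT CLI; adjust it if your installation exposes a different entry point, and set only the environment variable that matches your `search_client_type`.

```bash
# 1. Install the AgentScope framework used by this workflow.
pip install agentscope

# 2. Provide credentials for the search backend selected in the config
#    (search_client_type: tavily in the YAML above).
export TAVILY_API_KEY="your_tavily_api_key"
# export SEARXNG_URL="http://your-searxng-instance.com"   # if using searxng instead

# 3. Build the train/test jsonl files under examples/agentscope_websearch/webwalker_rl_dataset.
python examples/agentscope_websearch/get_webwalkerQA_data.py

# 4. Launch training with the config shown above (assumed Trinity-RFT CLI entry point).
trinity run --config examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
```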
diff --git a/examples/agentscope_websearch/get_webwalkerQA_data.py b/examples/agentscope_websearch/get_webwalkerQA_data.py new file mode 100644 index 0000000000..b94e41a812 --- /dev/null +++ b/examples/agentscope_websearch/get_webwalkerQA_data.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +""" +This script creates Hugging Face dataset-formatted files (train.jsonl, test.jsonl) +from the 'callanwu/WebWalkerQA' dataset for Reinforcement Learning (RL) training. + +The script performs the following actions: +1. Loads the 'callanwu/WebWalkerQA' dataset from the Hugging Face Hub. +2. Splits the data into training and test set candidates based on 'difficulty_level' and 'lang'. + - 'easy' and 'en' entries are candidates for the test set. + - All other entries are candidates for the training set. +3. Randomly samples a specified number of items from the candidate pools. +4. Saves the processed data into 'train.jsonl' and 'test.jsonl' files. +5. Generates a 'dataset_dict.json' file describing the dataset's structure. +""" +import json +import os +import random + +from datasets import load_dataset + + +def create_dataset_files(output_dir, train_sample_size=160, test_sample_size=16): + """ + Loads, processes, and saves the WebWalkerQA dataset. + + Args: + output_dir (str): The directory where the output files will be saved. + train_sample_size (int): The maximum number of samples for the training set. + test_sample_size (int): The maximum number of samples for the test set. + """ + print("Starting dataset file creation...") + + # 1. Create the output directory if it doesn't exist. + os.makedirs(output_dir, exist_ok=True) + print(f"Output directory '{output_dir}' is ready.") + + # 2. Load the dataset from the Hugging Face Hub. + print("Loading 'callanwu/WebWalkerQA' dataset from the Hugging Face Hub...") + try: + # Use trust_remote_code=True if the dataset requires custom loading logic. + ds = load_dataset("callanwu/WebWalkerQA", split="main", trust_remote_code=True) + print("Dataset loaded successfully!") + except Exception as e: + print(f"Failed to load dataset: {e}") + return + + train_candidates = [] + test_candidates = [] + + print("Processing and filtering the data...") + # 3. Iterate through and process each item in the dataset. + for item in ds: + # You may want to apply your own filtering logic here. + # For example, we filtered out examples that cannot be answered by GPT-4.1. + + info = item.get("info", {}) + difficulty = info.get("difficulty_level", "") + lang = info.get("lang", "") + + # Construct the 'problem' field. + problem = ( + item.get("original_question", "") + + " You should navigate the website to find the answer. " + + "The root url is " + + item.get("root_url", "") + + ". The answer should be based on the information on the website." + ) + answer = item.get("expected_answer", "") + + simple_item = {"problem": problem, "answer": answer} + + # Split the data into test and train candidates. + if difficulty == "easy" and lang == "en": + test_candidates.append(simple_item) + else: + train_candidates.append(simple_item) + + print( + f"Processing complete. Found {len(train_candidates)} training candidates and {len(test_candidates)} test candidates." + ) + + # 4. Randomly sample from the candidate lists. + # Or you can filter based on other criteria. + random.seed(42) + final_test_list = random.sample(test_candidates, min(test_sample_size, len(test_candidates))) + final_train_list = random.sample( + train_candidates, min(train_sample_size, len(train_candidates)) + ) + + print( + f"Sampling complete. 
Final train set size: {len(final_train_list)}, Final test set size: {len(final_test_list)}" + ) + + # 5. Save the data to .jsonl files. + dataset_splits = {"train": final_train_list, "test": final_test_list} + + for split_name, data_list in dataset_splits.items(): + output_file = os.path.join(output_dir, f"{split_name}.jsonl") + with open(output_file, "w", encoding="utf-8") as f: + for record in data_list: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + print(f"Successfully wrote to {output_file}") + + # 6. Create and save the dataset_dict.json file. + dataset_info = { + "citation": "", + "description": "A custom dataset created from callanwu/WebWalkerQA for RL training.", + "splits": { + "train": {"name": "train", "num_examples": len(final_train_list)}, + "test": {"name": "test", "num_examples": len(final_test_list)}, + }, + } + + dict_path = os.path.join(output_dir, "dataset_dict.json") + with open(dict_path, "w", encoding="utf-8") as f: + json.dump(dataset_info, f, ensure_ascii=False, indent=2) + print(f"Successfully wrote to {dict_path}") + + print("\nAll files created successfully!") + + +if __name__ == "__main__": + # Define the output directory. You can change this to any path you prefer. + # This example uses a folder named "webwalker_rl_dataset" in the same directory as the script. + current_file_dir = os.path.dirname(os.path.abspath(__file__)) + output_directory = os.path.join(current_file_dir, "webwalker_rl_dataset") + + # Alternatively, you can use an absolute path like in your original script: + # output_directory = "/mnt/data/zhangwenhao.zwh/webwalker_rl_dataset" + + # Call the main function to create the dataset files. + create_dataset_files(output_dir=output_directory, train_sample_size=160, test_sample_size=16) + + # --- Verification Step: Test loading the generated dataset --- + print("\n--- Verifying the created dataset ---") + try: + load_ds = load_dataset(output_directory) + print("Dataset loaded successfully for verification!") + print(f"Train set size: {len(load_ds['train'])}") + print(f"Test set size: {len(load_ds['test'])}") + except Exception as e: + print(f"Failed to load the created dataset: {e}") diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 54342fea24..230497d1b1 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -11,6 +11,9 @@ from trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow import ( AgentScopeReactMathWorkflow, ) +from trinity.common.workflows.envs.agentscope.agentscopev1_search_workflow import ( + AgentScopeV1ReactSearchWorkflow, +) from trinity.common.workflows.envs.alfworld.alfworld_workflow import ( AlfworldWorkflow, StepWiseAlfworldWorkflow, @@ -76,6 +79,7 @@ "MathEvalWorkflow", "AgentScopeV0ReactMathWorkflow", # will be deprecated soon "AgentScopeReactMathWorkflow", + "AgentScopeV1ReactSearchWorkflow", "EmailSearchWorkflow", "AsyncMathRULERWorkflow", "MathRULERWorkflow", diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index 9640fe9ef7..e2c1fe8f7b 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""We include the customized math workflows in this file.""" +"""We include the agentscope react workflows in this file.""" from typing import List, Optional diff --git 
a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py new file mode 100644 index 0000000000..b32eeefe9a --- /dev/null +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +"""We include simple react deep search workflows in this file. We use AgentScope V1 framework.""" + +import os +import re +from typing import List, Optional + +import openai + +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow + + +@WORKFLOWS.register_module("agentscope_v1_react_search_workflow") +class AgentScopeV1ReactSearchWorkflow(Workflow): + """ + This workflow serves as an example of how to use the agentscope framework within the trinity workflow. + """ + + def __init__( + self, + *, + task: Task, + model: ModelWrapper, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): # get openai client from model + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + # make sure that we have the correct import + try: + from agentscope.formatter import OpenAIChatFormatter + from agentscope.model import OpenAIChatModel + except ImportError as e: + error_message = f"AgentScope is not installed. Please install the agentscope framework first before running the workflow. Error: {str(e)}" + self.logger.error(error_message) + raise ImportError(error_message) + + # get openai client from model + self.openai_async_client = model.get_openai_async_client() + self.model_name = self.openai_async_client.model_path + + temperature = self.rollout_args.get("temperature", 1.0) + max_tokens = self.rollout_args.get("max_tokens", 4096) + self.agent_model = OpenAIChatModel( + api_key="EMPTY", + model_name=self.model_name, + stream=False, + generate_kwargs={ + "temperature": temperature, + "max_tokens": max_tokens, + }, + ) + self.agent_model.client = self.openai_async_client + self.agent_model_formatter = OpenAIChatFormatter() + + self.reset(task) + + @property + def resettable(self): + return True + + @property + def asynchronous(self): + """Whether the workflow runs in async mode.""" + return True + + @property + def repeatable(self): + return False + + def reset(self, task: Task): + self.workflow_args = task.workflow_args + self.max_turns = int(self.workflow_args.get("max_turns", 10)) + self.search_client_type = self.workflow_args.get("search_client_type", "searxng") + self.max_model_tokens = int(self.workflow_args.get("max_model_tokens", 24000)) + if self.search_client_type not in ["searxng", "tavily"]: + raise ValueError( + f"search_client_type must be one of ['searxng', 'tavily'], but got {self.search_client_type}" + ) + self.system_prompt = "You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions." 
+ + self.raw_task = task.raw_task + self.task_desc = task.task_desc + self.truth = task.truth + + def judge_result(self, result, question, correct_answer, judge_model=None) -> bool: + """Use LLM to judge whether the answer is correct or not.""" + if result is None: + return False + + def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if ( + hasattr(result, "metadata") + and isinstance(result.metadata, dict) + and "result" in result.metadata + ): + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + self.logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + final_answer = extract_final_answer(result) + + judge_prompt = f"""Judge whether the following [response] to [question] is correct or not based on the [correct_answer] below. + +[question]: {question} + +[response]: {final_answer} + +[correct_answer]: {correct_answer} + +Your judgement must be in the format and criteria specified below: + +1. **Reasoning**: Explain why the [response] is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the response. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match. + +2. **Correctness**: Answer exactly "YES" if [response] matches the [correct_answer] given above, or is within a small margin of error for numerical problems and small format issues for text problems (for example, with or without a hyphen should be considered the same). Answer exactly "NO" otherwise, i.e. if there is any inconsistency, ambiguity, non-equivalency, or if the [response] is incorrect. +""" + messages = [ + {"role": "system", "content": "You evaluate correctness."}, + {"role": "user", "content": judge_prompt}, + ] + completion = judge_model.chat.completions.create( + model=judge_model.model_path, messages=messages, stream=False + ) + judge_output = completion.choices[0].message.content + + self.logger.info( + f"[judge_result] prompt:\n{judge_prompt}\n\n[judge_result] LLM output:\n{judge_output}" + ) + + # Yes if the response is correct, No otherwise + match = re.search(r"Correctness.*?YES", judge_output, re.IGNORECASE) + return match is not None + + async def run_async(self): + try: + from agentscope.agent import ReActAgent + from agentscope.mcp import StdIOStatefulClient + from agentscope.memory import InMemoryMemory + from agentscope.message import Msg + from pydantic import BaseModel, Field + except ImportError as e: + error_message = f"AgentScope V1 is not installed. Please install the agentscope framework first before running the workflow. Error: {str(e)}" + self.logger.error(error_message) + raise ImportError(error_message) + + self.agent = ReActAgent( + name="Friday", + sys_prompt=self.system_prompt, + model=self.agent_model, + formatter=self.agent_model_formatter, + memory=InMemoryMemory(), + max_iters=self.max_turns, + ) + self.agent.model.client = self.openai_async_client + + if self.search_client_type == "tavily": + tavily_api_key = os.getenv("TAVILY_API_KEY", "") + if not tavily_api_key: + raise ValueError( + "TAVILY_API_KEY environment variable is not set. Please set it to use the Tavily search tool." 
+ ) + + self.search_client = StdIOStatefulClient( + name="tavily_mcp", + command="npx", + args=["-y", "tavily-mcp@latest"], + env={"TAVILY_API_KEY": tavily_api_key}, + ) + elif self.search_client_type == "searxng": + searxng_url = os.getenv("SEARXNG_URL", "") + if not searxng_url: + raise ValueError( + "SEARXNG_URL environment variable is not set. Please set it to use the SearXNG search tool." + ) + self.search_client = StdIOStatefulClient( # refer to https://github.com/ihor-sokoliuk/mcp-searxng for more details + name="searxng_mcp", + command="npx", + args=["-y", "mcp-searxng"], + env={"SEARXNG_URL": searxng_url}, + ) + else: + raise ValueError( + f"search_client_type must be one of ['searxng', 'tavily'], but got {self.search_client_type}" + ) + + instruction = Msg("user", content=self.task_desc, role="user") + + class FinalResult(BaseModel): + result: str = Field(description="The final result to the initial user query") + + try: + await self.search_client.connect() + await self.agent.toolkit.register_mcp_client(self.search_client) + result = await self.agent.reply(instruction, structured_model=FinalResult) + except Exception as e: + self.logger.error(f"Error during agent reply: {e}") + result = None + finally: + if self.search_client and self.search_client.is_connected: + await self.search_client.close() + + # Reward calculation (judge_result can stay sync if your judge_model only has sync chat, otherwise you need to make it async) + try: + judge_model = self.auxiliary_models[0] if self.auxiliary_models else None + assert judge_model is not None, "Please provide a judge model for reward calculation." + reward = 1 if self.judge_result(result, self.task_desc, self.truth, judge_model) else 0 + except Exception as e: + self.logger.error(f"Error in judge_model judging: {e}") + reward = 0 + + self.logger.debug(f"Reward: {reward}") + experiences = self.model.extract_experience_from_history(clear_history=True) + return_experiences = [] + self.logger.debug(f"Experiences extracted len: {len(experiences)}") + for i, experience in enumerate(experiences): + experience.eid.step = i + experience.reward = reward + agent_metrics = { + "react_turns": len(self.agent.memory.content) // 2, + "max_turns": self.max_turns, + } + if experience.metrics is None: + experience.metrics = {} + experience.metrics.update(agent_metrics) + if len(experience.tokens) > self.max_model_tokens: + continue + return_experiences.append(experience) + if return_experiences: + self.logger.debug( + f"return experience len: {len(return_experiences)}, run_id: {str(return_experiences[-1].eid.run)}, final step reward: {return_experiences[-1].reward}" + ) + else: + self.logger.info("No valid experiences to return (all filtered out).") + return return_experiences
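The README above notes that other search backends require modifying `agentscopev1_search_workflow.py`. A minimal way to do that is to add another branch next to the existing `tavily`/`searxng` branches in `run_async`, reusing the same `StdIOStatefulClient` pattern. The fragment below is illustrative only: the `"my_search"` value, the `my-search-mcp` npm package, and the `MY_SEARCH_API_KEY` variable are placeholders for whatever stdio MCP search server you actually use.

```python
# Hypothetical extra branch inside AgentScopeV1ReactSearchWorkflow.run_async();
# "my_search", "my-search-mcp", and MY_SEARCH_API_KEY are placeholders.
elif self.search_client_type == "my_search":
    my_api_key = os.getenv("MY_SEARCH_API_KEY", "")
    if not my_api_key:
        raise ValueError(
            "MY_SEARCH_API_KEY environment variable is not set. Please set it to use this search tool."
        )
    self.search_client = StdIOStatefulClient(
        name="my_search_mcp",
        command="npx",
        args=["-y", "my-search-mcp"],  # any stdio MCP search server works here
        env={"MY_SEARCH_API_KEY": my_api_key},
    )
```

Remember to also add the new value to the `search_client_type` validation in `reset()` and to `workflow_args.search_client_type` in the YAML config.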