diff --git a/docs/mcp.md b/docs/mcp.md
new file mode 100644
index 00000000..7ec11c16
--- /dev/null
+++ b/docs/mcp.md
@@ -0,0 +1,51 @@
+# Model context protocol
+
+The [Model context protocol](https://modelcontextprotocol.io/introduction) (aka MCP) is a way to provide tools and context to the LLM. From the MCP docs:
+
+> MCP is an open protocol that standardizes how applications provide context to LLMs. Think of MCP like a USB-C port for AI applications. Just as USB-C provides a standardized way to connect your devices to various peripherals and accessories, MCP provides a standardized way to connect AI models to different data sources and tools.
+
+The Agents SDK has support for MCP. This enables you to use a wide range of MCP servers to provide tools to your Agents.
+
+## MCP servers
+
+Currently, the MCP spec defines two kinds of servers, based on the transport mechanism they use:
+
+1. **stdio** servers run as a subprocess of your application. You can think of them as running "locally".
+2. **HTTP over SSE** servers run remotely. You connect to them via a URL.
+
+You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse] classes to connect to these servers.
+
+For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem).
+
+```python
+async with MCPServerStdio(
+    params={
+        "command": "npx",
+        "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
+    }
+) as server:
+    tools = await server.list_tools()
+```
+
+## Using MCP servers
+
+MCP servers can be added to Agents. The Agents SDK will call `list_tools()` on the MCP servers each time the Agent is run. This makes the LLM aware of the MCP server's tools. When the LLM calls a tool from an MCP server, the SDK calls `call_tool()` on that server.
+
+```python
+agent = Agent(
+    name="Assistant",
+    instructions="Use the tools to achieve the task",
+    mcp_servers=[mcp_server_1, mcp_server_2]
+)
+```
+
+## Caching
+
+Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is remote. To automatically cache the list of tools, you can pass `cache_tools_list=True` to both [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse]. You should only do this if you're certain the tool list will not change.
+
+If you want to invalidate the cache, you can call `invalidate_tools_cache()` on the servers.
+
+## End-to-end example
+
+View complete working examples at [examples/mcp](https://github.com/openai/openai-agents-python/tree/main/examples/mcp).
diff --git a/docs/ref/mcp/server.md b/docs/ref/mcp/server.md
new file mode 100644
index 00000000..e58efab2
--- /dev/null
+++ b/docs/ref/mcp/server.md
@@ -0,0 +1,3 @@
+# `MCP Servers`
+
+::: agents.mcp.server
diff --git a/docs/ref/mcp/util.md b/docs/ref/mcp/util.md
new file mode 100644
index 00000000..b3f7db25
--- /dev/null
+++ b/docs/ref/mcp/util.md
@@ -0,0 +1,3 @@
+# `MCP Util`
+
+::: agents.mcp.util
diff --git a/examples/mcp/filesystem_example/README.md b/examples/mcp/filesystem_example/README.md
new file mode 100644
index 00000000..682afc83
--- /dev/null
+++ b/examples/mcp/filesystem_example/README.md
@@ -0,0 +1,26 @@
+# MCP Filesystem Example
+
+This example uses the [filesystem MCP server](https://github.com/modelcontextprotocol/servers/tree/main/src/filesystem), running locally via `npx`.
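+
+At its core, the example wires the filesystem server into an agent roughly like this. This is a trimmed sketch of `main.py` (shown in full below), not a standalone program; `samples_dir` is the path to the `sample_files` directory:
+
+```python
+async with MCPServerStdio(
+    params={
+        "command": "npx",
+        "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
+    }
+) as server:
+    agent = Agent(
+        name="Assistant",
+        instructions="Use the tools to read the filesystem and answer questions based on those files.",
+        mcp_servers=[server],
+    )
+    result = await Runner.run(starting_agent=agent, input="Read the files and list them.")
+    print(result.final_output)
+```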
+
+Run it via:
+
+```
+uv run python examples/mcp/filesystem_example/main.py
+```
+
+## Details
+
+The example uses the `MCPServerStdio` class from `agents`, with the command:
+
+```bash
+npx -y "@modelcontextprotocol/server-filesystem"
+```
+
+It's only given access to the `sample_files` directory adjacent to the example, which contains some sample data.
+
+Under the hood:
+
+1. The server is spun up in a subprocess, and exposes a bunch of tools like `list_directory()`, `read_file()`, etc.
+2. We add the server instance to the Agent via `mcp_servers`.
+3. Each time the agent runs, we call out to the MCP server to fetch the list of tools via `server.list_tools()`.
+4. If the LLM chooses to use an MCP tool, we call the MCP server to run the tool via `server.call_tool()`.
diff --git a/examples/mcp/filesystem_example/main.py b/examples/mcp/filesystem_example/main.py
new file mode 100644
index 00000000..0ba2b675
--- /dev/null
+++ b/examples/mcp/filesystem_example/main.py
@@ -0,0 +1,54 @@
+import asyncio
+import os
+import shutil
+
+from agents import Agent, Runner, trace
+from agents.mcp import MCPServer, MCPServerStdio
+
+
+async def run(mcp_server: MCPServer):
+    agent = Agent(
+        name="Assistant",
+        instructions="Use the tools to read the filesystem and answer questions based on those files.",
+        mcp_servers=[mcp_server],
+    )
+
+    # List the files it can read
+    message = "Read the files and list them."
+    print(f"Running: {message}")
+    result = await Runner.run(starting_agent=agent, input=message)
+    print(result.final_output)
+
+    # Ask about books
+    message = "What is my #1 favorite book?"
+    print(f"\n\nRunning: {message}")
+    result = await Runner.run(starting_agent=agent, input=message)
+    print(result.final_output)
+
+    # Ask a question that reads then reasons.
+    message = "Look at my favorite songs. Suggest one new song that I might like."
+    print(f"\n\nRunning: {message}")
+    result = await Runner.run(starting_agent=agent, input=message)
+    print(result.final_output)
+
+
+async def main():
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    samples_dir = os.path.join(current_dir, "sample_files")
+
+    async with MCPServerStdio(
+        params={
+            "command": "npx",
+            "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
+        }
+    ) as server:
+        with trace(workflow_name="MCP Filesystem Example"):
+            await run(server)
+
+
+if __name__ == "__main__":
+    # Let's make sure the user has npx installed
+    if not shutil.which("npx"):
+        raise RuntimeError("npx is not installed. Please install it with `npm install -g npx`.")
+
+    asyncio.run(main())
diff --git a/examples/mcp/filesystem_example/sample_files/favorite_books.txt b/examples/mcp/filesystem_example/sample_files/favorite_books.txt
new file mode 100644
index 00000000..c55f457e
--- /dev/null
+++ b/examples/mcp/filesystem_example/sample_files/favorite_books.txt
@@ -0,0 +1,20 @@
+1. To Kill a Mockingbird – Harper Lee
+2. Pride and Prejudice – Jane Austen
+3. 1984 – George Orwell
+4. The Hobbit – J.R.R. Tolkien
+5. Harry Potter and the Sorcerer’s Stone – J.K. Rowling
+6. The Great Gatsby – F. Scott Fitzgerald
+7. Charlotte’s Web – E.B. White
+8. Anne of Green Gables – Lucy Maud Montgomery
+9. The Alchemist – Paulo Coelho
+10. Little Women – Louisa May Alcott
+11. The Catcher in the Rye – J.D. Salinger
+12. Animal Farm – George Orwell
+13. The Chronicles of Narnia: The Lion, the Witch, and the Wardrobe – C.S. Lewis
+14. The Book Thief – Markus Zusak
+15. A Wrinkle in Time – Madeleine L’Engle
+16. 
The Secret Garden – Frances Hodgson Burnett +17. Moby-Dick – Herman Melville +18. Fahrenheit 451 – Ray Bradbury +19. Jane Eyre – Charlotte Brontë +20. The Little Prince – Antoine de Saint-Exupéry \ No newline at end of file diff --git a/examples/mcp/filesystem_example/sample_files/favorite_cities.txt b/examples/mcp/filesystem_example/sample_files/favorite_cities.txt new file mode 100644 index 00000000..1d3354f2 --- /dev/null +++ b/examples/mcp/filesystem_example/sample_files/favorite_cities.txt @@ -0,0 +1,4 @@ +- In the summer, I love visiting London. +- In the winter, Tokyo is great. +- In the spring, San Francisco. +- In the fall, New York is the best. \ No newline at end of file diff --git a/examples/mcp/filesystem_example/sample_files/favorite_songs.txt b/examples/mcp/filesystem_example/sample_files/favorite_songs.txt new file mode 100644 index 00000000..d659bb58 --- /dev/null +++ b/examples/mcp/filesystem_example/sample_files/favorite_songs.txt @@ -0,0 +1,10 @@ +1. "Here Comes the Sun" – The Beatles +2. "Imagine" – John Lennon +3. "Bohemian Rhapsody" – Queen +4. "Shake It Off" – Taylor Swift +5. "Billie Jean" – Michael Jackson +6. "Uptown Funk" – Mark Ronson ft. Bruno Mars +7. "Don’t Stop Believin’" – Journey +8. "Dancing Queen" – ABBA +9. "Happy" – Pharrell Williams +10. "Wonderwall" – Oasis diff --git a/mkdocs.yml b/mkdocs.yml index 941f29ed..454d881a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - results.md - streaming.md - tools.md + - mcp.md - handoffs.md - tracing.md - context.md @@ -60,6 +61,8 @@ nav: - ref/models/interface.md - ref/models/openai_chatcompletions.md - ref/models/openai_responses.md + - ref/mcp/server.md + - ref/mcp/util.md - Tracing: - ref/tracing/index.md - ref/tracing/create.md @@ -107,6 +110,8 @@ plugins: show_signature_annotations: true # Makes the font sizes nicer heading_level: 3 + # Show inherited members + inherited_members: true extra: # Remove material generation message in footer diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index df3db42f..80b9cc6d 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -344,6 +344,7 @@ def process_model_response( cls, *, agent: Agent[Any], + all_tools: list[Tool], response: ModelResponse, output_schema: AgentOutputSchema | None, handoffs: list[Handoff], @@ -355,8 +356,8 @@ def process_model_response( computer_actions = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} - function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)} - computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: if isinstance(output, ResponseOutputMessage): diff --git a/src/agents/agent.py b/src/agents/agent.py index 2723e678..3258e15a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -12,6 +12,7 @@ from .handoffs import Handoff from .items import ItemHelpers from .logger import logger +from .mcp import MCPUtil from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ -21,6 +22,7 @@ if TYPE_CHECKING: from .lifecycle import AgentHooks + from .mcp import MCPServer from .result import RunResult @@ -107,6 +109,16 @@ class Agent(Generic[TContext]): tools: list[Tool] = field(default_factory=list) """A list of tools 
that the agent can use.""" + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None + + async def get_mcp_tools(self) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + + async def get_all_tools(self) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools diff --git a/src/agents/mcp/mcp_util.py b/src/agents/mcp/mcp_util.py deleted file mode 100644 index 41b4c521..00000000 --- a/src/agents/mcp/mcp_util.py +++ /dev/null @@ -1,94 +0,0 @@ -import functools -import json -from typing import Any - -from mcp.types import Tool as MCPTool - -from .. import _debug -from ..exceptions import AgentsException, ModelBehaviorError, UserError -from ..logger import logger -from ..run_context import RunContextWrapper -from ..tool import FunctionTool, Tool -from .server import MCPServer - - -class MCPUtil: - """Set of utilities for interop between MCP and Agents SDK tools.""" - - @classmethod - async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]: - """Get all function tools from a list of MCP servers.""" - tools = [] - tool_names: set[str] = set() - for server in servers: - server_tools = await cls.get_function_tools(server) - server_tool_names = {tool.name for tool in server_tools} - if len(server_tool_names & tool_names) > 0: - raise UserError( - f"Duplicate tool names found across MCP servers: " - f"{server_tool_names & tool_names}" - ) - tool_names.update(server_tool_names) - tools.extend(server_tools) - - return tools - - @classmethod - async def get_function_tools(cls, server: MCPServer) -> list[Tool]: - """Get all function tools from a single MCP server.""" - tools = await server.list_tools() - return [cls.to_function_tool(tool, server) for tool in tools] - - @classmethod - def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool: - """Convert an MCP tool to an Agents SDK function tool.""" - invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) - return FunctionTool( - name=tool.name, - description=tool.description or "", - params_json_schema=tool.inputSchema, - on_invoke_tool=invoke_func, - strict_json_schema=False, - ) - - @classmethod - async def invoke_mcp_tool( - cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str - ) -> str: - """Invoke an MCP tool and return the result as a string.""" - try: - json_data: dict[str, Any] = json.loads(input_json) if input_json else {} - except Exception as e: - if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"Invalid JSON input for tool 
{tool.name}") - else: - logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") - raise ModelBehaviorError( - f"Invalid JSON input for tool {tool.name}: {input_json}" - ) from e - - if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"Invoking MCP tool {tool.name}") - else: - logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") - - try: - result = await server.call_tool(tool.name, json_data) - except Exception as e: - logger.error(f"Error invoking MCP tool {tool.name}: {e}") - raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e - - if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"MCP tool {tool.name} completed.") - else: - logger.debug(f"MCP tool {tool.name} returned {result}") - - # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single - # string. We'll try to convert. - if len(result.content) == 1: - return result.content[0].model_dump_json() - elif len(result.content) > 1: - return json.dumps([item.model_dump() for item in result.content]) - else: - logger.error(f"Errored MCP tool result: {result}") - return "Error running tool." diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index e19e686a..91af31db 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -175,10 +175,10 @@ def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False) """Create a new MCP server based on the stdio transport. Args: - params: The params that configure the server. This includes: - - The command (e.g. `python` or `node`) that starts the server. - - The args to pass to the server command (e.g. `foo.py` or `server.js`). - - The environment variables to set for the server. + params: The params that configure the server. This includes the command to run to + start the server, the args to pass to the command, the environment variables to + set for the server, the working directory to use when spawning the process, and + the text encoding used when sending/receiving messages to the server. cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be cached and only fetched from the server once. If `False`, the tools list will be fetched from the server on each call to `list_tools()`. The cache can be @@ -235,11 +235,9 @@ def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False): """Create a new MCP server based on the HTTP with SSE transport. Args: - params: The params that configure the server. This includes: - - The URL of the server. - - The headers to send to the server. - - The timeout for the HTTP request. - - The timeout for the SSE connection. + params: The params that configure the server. This includes the URL of the server, + the headers to send to the server, the timeout for the HTTP request, and the + timeout for the SSE connection. cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be cached and only fetched from the server once. If `False`, the tools list will be diff --git a/src/agents/run.py b/src/agents/run.py index 934400fe..b7ac85f9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,6 +7,8 @@ from openai.types.responses import ResponseCompletedEvent +from agents.tool import Tool + from ._run_impl import ( NextStepFinalOutput, NextStepHandoff, @@ -177,7 +179,8 @@ async def run( # agent changes, or if the agent loop ends. 
if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - tool_names = [t.name for t in current_agent.tools] + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.output_type_name() else: @@ -217,6 +220,7 @@ async def run( ), cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -228,6 +232,7 @@ async def run( else: turn_result = await cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -627,7 +632,7 @@ async def _run_single_turn_streamed( system_prompt = await agent.get_system_prompt(context_wrapper) handoffs = cls._get_handoffs(agent) - + all_tools = await cls._get_all_tools(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) final_response: ModelResponse | None = None @@ -640,7 +645,7 @@ async def _run_single_turn_streamed( system_prompt, input, model_settings, - agent.tools, + all_tools, output_schema, handoffs, get_model_tracing_impl( @@ -677,6 +682,7 @@ async def _run_single_turn_streamed( pre_step_items=streamed_result.new_items, new_response=final_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -691,6 +697,7 @@ async def _run_single_turn( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], generated_items: list[RunItem], hooks: RunHooks[TContext], @@ -721,6 +728,7 @@ async def _run_single_turn( system_prompt, input, output_schema, + all_tools, handoffs, context_wrapper, run_config, @@ -732,6 +740,7 @@ async def _run_single_turn( pre_step_items=generated_items, new_response=new_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -743,6 +752,7 @@ async def _get_single_step_result_from_response( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, @@ -754,6 +764,7 @@ async def _get_single_step_result_from_response( ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, + all_tools=all_tools, response=new_response, output_schema=output_schema, handoffs=handoffs, @@ -853,6 +864,7 @@ async def _get_new_response( system_prompt: str | None, input: list[TResponseInputItem], output_schema: AgentOutputSchema | None, + all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -863,7 +875,7 @@ async def _get_new_response( system_instructions=system_prompt, input=input, model_settings=model_settings, - tools=agent.tools, + tools=all_tools, output_schema=output_schema, handoffs=handoffs, tracing=get_model_tracing_impl( @@ -892,6 +904,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: handoffs.append(handoff(handoff_item)) return handoffs + @classmethod + async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: + return await agent.get_all_tools() + @classmethod def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: if isinstance(run_config.model, Model): diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file 
mode 100644 index 00000000..e69de29b diff --git a/tests/mcp/conftest.py b/tests/mcp/conftest.py new file mode 100644 index 00000000..80fd15ec --- /dev/null +++ b/tests/mcp/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + + +# Skip MCP tests on Python 3.9 +def pytest_ignore_collect(collection_path, config): + if sys.version_info[:2] == (3, 9): + this_dir = os.path.dirname(__file__) + + if str(collection_path).startswith(this_dir): + return True diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py new file mode 100644 index 00000000..952b3ea7 --- /dev/null +++ b/tests/mcp/helpers.py @@ -0,0 +1,54 @@ +import json +import shutil +from typing import Any + +from mcp import Tool as MCPTool +from mcp.types import CallToolResult, TextContent + +from agents.mcp import MCPServer + +tee = shutil.which("tee") or "" +assert tee, "tee not found" + + +# Added dummy stream classes for patching stdio_client to avoid real I/O during tests +class DummyStream: + async def send(self, msg): + pass + + async def receive(self): + raise Exception("Dummy receive not implemented") + + +class DummyStreamsContextManager: + async def __aenter__(self): + return (DummyStream(), DummyStream()) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class FakeMCPServer(MCPServer): + def __init__(self, tools: list[MCPTool] | None = None): + self.tools: list[MCPTool] = tools or [] + self.tool_calls: list[str] = [] + self.tool_results: list[str] = [] + + def add_tool(self, name: str, input_schema: dict[str, Any]): + self.tools.append(MCPTool(name=name, inputSchema=input_schema)) + + async def connect(self): + pass + + async def cleanup(self): + pass + + async def list_tools(self): + return self.tools + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + self.tool_calls.append(tool_name) + self.tool_results.append(f"result_{tool_name}_{json.dumps(arguments)}") + return CallToolResult( + content=[TextContent(text=self.tool_results[-1], type="text")], + ) diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py new file mode 100644 index 00000000..cac409e6 --- /dev/null +++ b/tests/mcp/test_caching.py @@ -0,0 +1,57 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_server_caching_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that if we turn caching on, the list of tools is cached and not fetched from the server + on each call to `list_tools()`. 
+ """ + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + async with server: + # Call list_tools() multiple times + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should have been called once" + + # Call list_tools() again, should return the cached value + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" + + # Invalidate the cache and call list_tools() again + server.invalidate_tools_cache() + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 2, "list_tools() should be called again" + + # Without invalidating the cache, calling list_tools() again should return the cached value + tools = await server.list_tools() + assert tools == tools diff --git a/tests/mcp/test_connect_disconnect.py b/tests/mcp/test_connect_disconnect.py new file mode 100644 index 00000000..b0013039 --- /dev/null +++ b/tests/mcp/test_connect_disconnect.py @@ -0,0 +1,69 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_async_ctx_manager_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + async with server: + assert server.session is not None, "Server should be connected" + + assert server.session is None, "Server should be disconnected" + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_manual_connect_disconnect_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + await server.connect() + assert server.session is not None, "Server should be connected" + + await server.cleanup() + assert server.session is None, "Server should be disconnected" diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py new file mode 100644 index 00000000..345df996 --- /dev/null +++ b/tests/mcp/test_mcp_util.py @@ -0,0 +1,109 @@ +import logging 
+from typing import Any + +import pytest +from mcp.types import Tool as MCPTool +from pydantic import BaseModel + +from agents import FunctionTool, RunContextWrapper +from agents.exceptions import AgentsException, ModelBehaviorError +from agents.mcp import MCPServer, MCPUtil + +from .helpers import FakeMCPServer + + +class Foo(BaseModel): + bar: str + baz: int + + +class Bar(BaseModel): + qux: str + + +@pytest.mark.asyncio +async def test_get_all_function_tools(): + """Test that the get_all_function_tools function returns all function tools from a list of MCP + servers. + """ + names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"] + schemas = [ + {}, + {}, + {}, + Foo.model_json_schema(), + Bar.model_json_schema(), + ] + + server1 = FakeMCPServer() + server1.add_tool(names[0], schemas[0]) + server1.add_tool(names[1], schemas[1]) + + server2 = FakeMCPServer() + server2.add_tool(names[2], schemas[2]) + server2.add_tool(names[3], schemas[3]) + + server3 = FakeMCPServer() + server3.add_tool(names[4], schemas[4]) + + servers: list[MCPServer] = [server1, server2, server3] + tools = await MCPUtil.get_all_function_tools(servers) + assert len(tools) == 5 + assert all(tool.name in names for tool in tools) + + for idx, tool in enumerate(tools): + assert isinstance(tool, FunctionTool) + assert tool.params_json_schema == schemas[idx] + assert tool.name == names[idx] + + +@pytest.mark.asyncio +async def test_invoke_mcp_tool(): + """Test that the invoke_mcp_tool function invokes an MCP tool and returns the result.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + # Just making sure it doesn't crash + + +@pytest.mark.asyncio +async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(ModelBehaviorError): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json") + + assert "Invalid JSON input for tool test_tool_1" in caplog.text + + +class CrashingFakeMCPServer(FakeMCPServer): + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None): + raise Exception("Crash!") + + +@pytest.mark.asyncio +async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = CrashingFakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(AgentsException): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + + assert "Error invoking MCP tool test_tool_1" in caplog.text diff --git a/tests/mcp/test_runner_calls_mcp.py b/tests/mcp/test_runner_calls_mcp.py new file mode 100644 index 00000000..3319c097 --- /dev/null +++ b/tests/mcp/test_runner_calls_mcp.py @@ -0,0 +1,197 @@ +import json + +import pytest +from pydantic import BaseModel + +from agents import Agent, ModelBehaviorError, Runner, UserError + +from ..fake_model import FakeModel +from ..test_responses import get_function_tool_call, get_text_message +from .helpers import FakeMCPServer + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool): + """Test that the runner asserts when an MCP tool is not found.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(ModelBehaviorError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_works_with_multiple_mcp_servers(streaming: bool): + """Test that the runner works with multiple MCP servers.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server1.tool_calls == [] + assert server2.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_errors_when_mcp_tools_clash(streaming: bool): + """Test that the runner errors when multiple servers have the same tool name.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + server1.add_tool("test_tool_2", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")], + # Second turn: text message + 
[get_text_message("done")], + ] + ) + + with pytest.raises(UserError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +class Foo(BaseModel): + bar: str + baz: int + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool_with_args(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + await server.connect() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", Foo.model_json_schema()) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + json_args = json.dumps(Foo(bar="baz", baz=1).model_dump()) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + assert server.tool_results == [f"result_test_tool_2_{json_args}"] + + await server.cleanup() diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py new file mode 100644 index 00000000..5c6432bc --- /dev/null +++ b/tests/mcp/test_server_errors.py @@ -0,0 +1,38 @@ +import pytest + +from agents.exceptions import UserError +from agents.mcp.server import _MCPServerWithClientSession + + +class CrashingClientSessionServer(_MCPServerWithClientSession): + def __init__(self): + super().__init__(cache_tools_list=False) + self.cleanup_called = False + + def create_streams(self): + raise ValueError("Crash!") + + async def cleanup(self): + self.cleanup_called = True + await super().cleanup() + + +@pytest.mark.asyncio +async def test_server_errors_cause_error_and_cleanup_called(): + server = CrashingClientSessionServer() + + with pytest.raises(ValueError): + await server.connect() + + assert server.cleanup_called + + +@pytest.mark.asyncio +async def test_not_calling_connect_causes_error(): + server = CrashingClientSessionServer() + + with pytest.raises(UserError): + await server.list_tools() + + with pytest.raises(UserError): + await server.call_tool("foo", {}) diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2d581bf6..16c62c84 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -290,6 +290,7 @@ async def get_execute_result( processed_response = RunImpl.process_model_response( agent=agent, + all_tools=await agent.get_all_tools(), response=response, output_schema=output_schema, handoffs=handoffs, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 24f9e8e3..2a6634ac 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -43,7 +43,11 @@ def test_empty_response(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=[], ) assert not result.handoffs assert not result.functions @@ -57,13 +61,14 @@ def test_no_tool_calls(): referenceable_id=None, ) result = RunImpl.process_model_response( - 
agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[] ) assert not result.handoffs assert not result.functions -def test_single_tool_call(): +@pytest.mark.asyncio +async def test_single_tool_call(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -74,7 +79,11 @@ def test_single_tool_call(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 1 @@ -84,7 +93,8 @@ def test_single_tool_call(): assert func.tool_call.arguments == "" -def test_missing_tool_call_raises_error(): +@pytest.mark.asyncio +async def test_missing_tool_call_raises_error(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error(): with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_multiple_tool_calls(): +@pytest.mark.asyncio +async def test_multiple_tool_calls(): agent = Agent( name="test", tools=[ @@ -121,7 +136,11 @@ def test_multiple_tool_calls(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 2 @@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent_3, response=response, output_schema=None, handoffs=[] + agent=agent_3, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent_3.get_all_tools(), ) assert not result.handoffs, "Shouldn't have a handoff here" @@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 1, "Should have a handoff here" handoff = result.handoffs[0] @@ -189,10 +213,12 @@ async def test_missing_handoff_fails(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) -def test_multiple_handoffs_doesnt_error(): +@pytest.mark.asyncio +async def test_multiple_handoffs_doesnt_error(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -218,7 +245,8 @@ class Foo(BaseModel): bar: str -def test_final_output_parsed_correctly(): +@pytest.mark.asyncio +async def test_final_output_parsed_correctly(): agent = Agent(name="test", output_type=Foo) response = ModelResponse( output=[ @@ -234,10 +262,12 @@ def test_final_output_parsed_correctly(): response=response, 
output_schema=Runner._get_output_schema(agent), handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_file_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_file_search_tool_call_parsed_correctly(): # Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool # runs are scheduled. @@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) # The final item should be a ToolCallItem for the file search call assert any( @@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_function_web_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_function_web_search_tool_call_parsed_correctly(): agent = Agent(name="test") web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") response = ModelResponse( @@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is web_search_call @@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_reasoning_item_parsed_correctly(): +@pytest.mark.asyncio +async def test_reasoning_item_parsed_correctly(): # Verify that a Reasoning output item is converted into a ReasoningItem. reasoning = ResponseReasoningItem( @@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) assert any( isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items @@ -342,7 +386,8 @@ def drag(self, path: list[tuple[int, int]]) -> None: return None # pragma: no cover -def test_computer_tool_call_without_computer_tool_raises_error(): +@pytest.mark.asyncio +async def test_computer_tool_call_without_computer_tool_raises_error(): # If the agent has no ComputerTool in its tools, process_model_response should raise a # ModelBehaviorError when encountering a ResponseComputerToolCall. computer_call = ResponseComputerToolCall( @@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error(): ) with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) -def test_computer_tool_call_with_computer_tool_parsed_correctly(): +@pytest.mark.asyncio +async def test_computer_tool_call_with_computer_tool_parsed_correctly(): # If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a # ToolCallItem and scheduled to run in computer_actions. 
dummy_computer = DummyComputer() @@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is computer_call @@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): assert result.computer_actions and result.computer_actions[0].tool_call == computer_call -def test_tool_and_handoff_parsed_correctly(): +@pytest.mark.asyncio +async def test_tool_and_handoff_parsed_correctly(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert result.functions and len(result.functions) == 1 assert len(result.handoffs) == 1, "Should have a handoff here" diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 75559232..89b5cca7 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -269,7 +269,7 @@ def fake_time_func(): async for _ in turns: pass - assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) + assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) await session.close() @@ -302,13 +302,11 @@ async def test_session_error_event(): trace_include_sensitive_audio_data=False, ) - with pytest.raises(STTWebsocketConnectionError) as exc_info: + with pytest.raises(STTWebsocketConnectionError): turns = session.transcribe_turns() async for _ in turns: pass - assert "Simulated server error!" in str(exc_info.value) - await session.close() @@ -362,8 +360,8 @@ async def test_inactivity_timeout(): async for turn in session.transcribe_turns(): collected_turns.append(turn) - assert "Timeout waiting for transcription_session" in str(exc_info.value) + assert "Timeout waiting for transcription_session" in str(exc_info.value) - assert len(collected_turns) == 0, "No transcripts expected, but we got something?" + assert len(collected_turns) == 0, "No transcripts expected, but we got something?" await session.close()
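The docs and example in this diff only show the stdio transport end to end. Connecting to a remote **HTTP over SSE** server follows the same pattern; the following is a minimal sketch, assuming `MCPServerSse` supports the same async context-manager usage as `MCPServerStdio` and that its `params` mapping is keyed by `url` (the exact key names, and the server URL below, are illustrative rather than taken from this diff):

```python
import asyncio

from agents import Agent, Runner
from agents.mcp import MCPServerSse


async def main():
    # Illustrative URL; headers and timeouts can also be configured via params.
    async with MCPServerSse(
        params={"url": "https://example.com/sse"},
        cache_tools_list=True,  # cache list_tools() across runs; call invalidate_tools_cache() to refresh
    ) as server:
        agent = Agent(
            name="Assistant",
            instructions="Use the tools to achieve the task",
            mcp_servers=[server],
        )
        result = await Runner.run(starting_agent=agent, input="user_message")
        print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```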