diff --git a/integrations/mcp/pyproject.toml b/integrations/mcp/pyproject.toml index ed7d66f509..edd35f6b32 100644 --- a/integrations/mcp/pyproject.toml +++ b/integrations/mcp/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "mcp>=1.8.0", - "haystack-ai>=2.18.0", + "haystack-ai>=2.19.0", "exceptiongroup", # Backport of ExceptionGroup for Python < 3.11 "httpx" # HTTP client library used for SSE connections ] diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index e626f118fd..425346d57f 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -814,6 +814,7 @@ def __init__( description: str | None = None, connection_timeout: int = 30, invocation_timeout: int = 30, + eager_connect: bool = False, ): """ Initialize the MCP tool. @@ -823,6 +824,9 @@ def __init__( :param description: Custom description (if None, server description will be used) :param connection_timeout: Timeout in seconds for server connection :param invocation_timeout: Default timeout in seconds for tool invocations + :param eager_connect: If True, connect to server during initialization. + If False (default), defer connection until warm_up or first tool use, + whichever comes first. :raises MCPConnectionError: If connection to the server fails :raises MCPToolNotFoundError: If no tools are available or the requested tool is not found :raises TimeoutError: If connection times out @@ -832,39 +836,27 @@ def __init__( self._server_info = server_info self._connection_timeout = connection_timeout self._invocation_timeout = invocation_timeout + self._eager_connect = eager_connect + self._client: MCPClient | None = None + self._worker: _MCPClientSessionManager | None = None + self._lock = threading.RLock() + + # don't connect now; initialize permissively + if not eager_connect: + # Permissive placeholder JSON Schema so the Tool is valid + # without discovering the remote schema during validation. + # Tool parameters/schema will be replaced with the correct schema (from the MCP server) on first use. + params = {"type": "object", "properties": {}, "additionalProperties": True} + super().__init__(name=name, description=description or "", parameters=params, function=self._invoke_tool) + return logger.debug(f"TOOL: Initializing MCPTool '{name}'") try: - # Create client and spin up a long-lived worker that keeps the - # connect/close lifecycle inside one coroutine. - self._client = server_info.create_client() - logger.debug(f"TOOL: Created client for MCPTool '{name}'") - - # The worker starts immediately and blocks here until the connection - # is established (or fails), returning the tool list. - self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout) - - tools = self._worker.tools() - # Handle no tools case - if not tools: - logger.debug(f"TOOL: No tools found for '{name}'") - message = "No tools available on server" - raise MCPToolNotFoundError(message, tool_name=name) - - # Find the specified tool - tool_dict = {t.name: t for t in tools} - logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}") - - tool_info: types.Tool | None = tool_dict.get(name) - - if not tool_info: - available = list(tool_dict.keys()) - logger.debug(f"TOOL: Tool '{name}' not found in available tools") - message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}" - raise MCPToolNotFoundError(message, tool_name=name, available_tools=available) - + logger.debug(f"TOOL: Connecting to MCP server for '{name}'") + tool_info = self._connect_and_initialize(name) logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class") + # Initialize the parent class super().__init__( name=name, @@ -897,6 +889,36 @@ def __init__( message = f"Failed to initialize MCPTool '{name}': {error_message}" raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e + def _connect_and_initialize(self, tool_name: str) -> types.Tool: + """ + Connect to the MCP server and retrieve the tool schema. + + :param tool_name: Name of the tool to look for + :returns: The tool schema for this tool + :raises MCPToolNotFoundError: If the tool is not found on the server + """ + client = self._server_info.create_client() + worker = _MCPClientSessionManager(client, timeout=self._connection_timeout) + tools = worker.tools() + + # Handle no tools case + if not tools: + message = "No tools available on server" + raise MCPToolNotFoundError(message, tool_name=tool_name) + + # Find the specified tool + tool = next((t for t in tools if t.name == tool_name), None) + if tool is None: + available = [t.name for t in tools] + msg = f"Tool '{tool_name}' not found on server. Available tools: {', '.join(available)}" + raise MCPToolNotFoundError(msg, tool_name=tool_name, available_tools=available) + + # Publish connection + self._client = client + self._worker = worker + + return tool + def _invoke_tool(self, **kwargs: Any) -> str: """ Synchronous tool invocation. @@ -906,12 +928,13 @@ def _invoke_tool(self, **kwargs: Any) -> str: """ logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}") try: + # Connect on first use if eager_connect is turned off + self.warm_up() async def invoke(): logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'") - result = await asyncio.wait_for( - self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout - ) + client = cast(MCPClient, self._client) + result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) logger.debug(f"TOOL: Invoke successful for '{self.name}'") return result @@ -939,7 +962,9 @@ async def ainvoke(self, **kwargs: Any) -> str: :raises TimeoutError: If the operation times out """ try: - return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + self.warm_up() + client = cast(MCPClient, self._client) + return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) except asyncio.TimeoutError as e: message = f"Tool invocation timed out after {self._invocation_timeout} seconds" raise TimeoutError(message) from e @@ -949,6 +974,14 @@ async def ainvoke(self, **kwargs: Any) -> str: message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}" raise MCPInvocationError(message, self.name, kwargs) from e + def warm_up(self) -> None: + """Connect and fetch the tool schema if eager_connect is turned off.""" + with self._lock: + if self._client is not None: + return + tool = self._connect_and_initialize(self.name) + self.parameters = tool.inputSchema + def to_dict(self) -> dict[str, Any]: """ Serializes the MCPTool to a dictionary. @@ -966,6 +999,7 @@ def to_dict(self) -> dict[str, Any]: "server_info": self._server_info.to_dict(), "connection_timeout": self._connection_timeout, "invocation_timeout": self._invocation_timeout, + "eager_connect": self._eager_connect, } return { "type": generate_qualified_class_name(type(self)), @@ -998,6 +1032,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool": # Handle backward compatibility for timeout parameters connection_timeout = inner_data.get("connection_timeout", 30) invocation_timeout = inner_data.get("invocation_timeout", 30) + eager_connect = inner_data.get("eager_connect", False) # because False is the default # Create a new MCPTool instance with the deserialized parameters # This will establish a new connection to the MCP server @@ -1007,6 +1042,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool": server_info=server_info, connection_timeout=connection_timeout, invocation_timeout=invocation_timeout, + eager_connect=eager_connect, ) def close(self): diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index b13f87e856..529e7e487c 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -120,6 +120,7 @@ def __init__( tool_names: list[str] | None = None, connection_timeout: float = 30.0, invocation_timeout: float = 30.0, + eager_connect: bool = False, ): """ Initialize the MCP toolset. @@ -129,6 +130,8 @@ def __init__( matching names will be added to the toolset. :param connection_timeout: Timeout in seconds for server connection :param invocation_timeout: Default timeout in seconds for tool invocations + :param eager_connect: If True, connect to server and load tools during initialization. + If False (default), defer connection to warm_up. :raises MCPToolNotFoundError: If any of the specified tool names are not found on the server """ # Store configuration @@ -136,8 +139,39 @@ def __init__( self.tool_names = tool_names self.connection_timeout = connection_timeout self.invocation_timeout = invocation_timeout + self.eager_connect = eager_connect + self._warmup_called = False + + if not eager_connect: + # Do not connect during validation; expose a toolset with one fake tool to pass validation + placeholder_tool = Tool( + name=f"mcp_not_connected_placeholder_{id(self)}", + description="Placeholder tool initialised when eager_connect is turned off", + parameters={"type": "object", "properties": {}, "additionalProperties": True}, + function=lambda: None, + ) + super().__init__(tools=[placeholder_tool]) + else: + tools = self._connect_and_load_tools() + super().__init__(tools=tools) + self._warmup_called = True + + def warm_up(self) -> None: + """Connect and load tools when eager_connect is turned off. + + This method is automatically called by ``ToolInvoker.warm_up()`` and ``Pipeline.warm_up()``. + You can also call it directly before using the toolset to ensure all tool schemas + are available without performing a real invocation. + """ + if self._warmup_called: + return + + # connect and load tools never adds duplicate tools, set the tools attribute directly + self.tools = self._connect_and_load_tools() + self._warmup_called = True - # Connect and load tools + def _connect_and_load_tools(self) -> list[Tool]: + """Connect and load tools.""" try: # Create the client and spin up a worker so open/close happen in the # same coroutine, avoiding AnyIO cancel-scope issues. @@ -195,9 +229,7 @@ def invoke_tool(**kwargs: Any) -> Any: ) haystack_tools.append(tool) - # Initialize parent class with complete tools list - super().__init__(tools=haystack_tools) - + return haystack_tools except Exception as e: # We need to close because we could connect properly, retrieve tools yet # fail because of an MCPToolNotFoundError @@ -273,6 +305,7 @@ def to_dict(self) -> dict[str, Any]: "tool_names": self.tool_names, "connection_timeout": self.connection_timeout, "invocation_timeout": self.invocation_timeout, + "eager_connect": self.eager_connect, }, } @@ -297,6 +330,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset": tool_names=inner_data.get("tool_names"), connection_timeout=inner_data.get("connection_timeout", 30.0), invocation_timeout=inner_data.get("invocation_timeout", 30.0), + eager_connect=inner_data.get("eager_connect", True), ) def close(self): diff --git a/integrations/mcp/tests/test_mcp_integration.py b/integrations/mcp/tests/test_mcp_integration.py index 04601624fd..484a34109b 100644 --- a/integrations/mcp/tests/test_mcp_integration.py +++ b/integrations/mcp/tests/test_mcp_integration.py @@ -234,7 +234,7 @@ def test_mcp_tool_error_handling_integration(self): # Use a non-existent server address to force a connection error server_info = SSEServerInfo(base_url="http://localhost:9999", timeout=1) # Short timeout with pytest.raises(MCPConnectionError) as exc_info: - MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2) + MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2, eager_connect=True) # Check for platform-agnostic error message patterns error_message = str(exc_info.value) diff --git a/integrations/mcp/tests/test_mcp_timeout_reconnection.py b/integrations/mcp/tests/test_mcp_timeout_reconnection.py index f2f64025f7..5ed3fcebd5 100644 --- a/integrations/mcp/tests/test_mcp_timeout_reconnection.py +++ b/integrations/mcp/tests/test_mcp_timeout_reconnection.py @@ -13,6 +13,7 @@ import subprocess import sys import tempfile +import textwrap import time from unittest.mock import AsyncMock, MagicMock @@ -108,40 +109,39 @@ def test_real_sse_reconnection_after_server_restart(self): try: # Create server script with cross-platform signal handling with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file: - temp_file.write( - f""" -import sys -import signal -from mcp.server.fastmcp import FastMCP - -# Handle shutdown signals gracefully (cross-platform) -def signal_handler(signum, frame): - sys.exit(0) - -# Only set up signal handlers that exist on the platform -if hasattr(signal, 'SIGTERM'): - signal.signal(signal.SIGTERM, signal_handler) -if hasattr(signal, 'SIGINT'): - signal.signal(signal.SIGINT, signal_handler) - -mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port}) - -@mcp.tool() -def test_tool(message: str) -> str: - return f"Server response: {{message}}" - -if __name__ == "__main__": - try: - print(f"Starting server on port {port}", flush=True) - mcp.run(transport="sse") - except (KeyboardInterrupt, SystemExit): - print("Server shutting down gracefully", flush=True) - sys.exit(0) - except Exception as e: - print(f"Server error: {{e}}", file=sys.stderr, flush=True) - sys.exit(1) -""".encode() - ) + script_content = textwrap.dedent(f""" + import sys + import signal + from mcp.server.fastmcp import FastMCP + + # Handle shutdown signals gracefully (cross-platform) + def signal_handler(signum, frame): + sys.exit(0) + + # Only set up signal handlers that exist on the platform + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) + if hasattr(signal, 'SIGINT'): + signal.signal(signal.SIGINT, signal_handler) + + mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port}) + + @mcp.tool() + def test_tool(message: str) -> str: + return f"Server response: {{message}}" + + if __name__ == "__main__": + try: + print(f"Starting server on port {port}", flush=True) + mcp.run(transport="sse") + except (KeyboardInterrupt, SystemExit): + print("Server shutting down gracefully", flush=True) + sys.exit(0) + except Exception as e: + print(f"Server error: {{e}}", file=sys.stderr, flush=True) + sys.exit(1) + """).strip() + temp_file.write(script_content.encode()) server_script_path = temp_file.name # Start server diff --git a/integrations/mcp/tests/test_mcp_tool.py b/integrations/mcp/tests/test_mcp_tool.py index a2e0bf5935..111a58c07a 100644 --- a/integrations/mcp/tests/test_mcp_tool.py +++ b/integrations/mcp/tests/test_mcp_tool.py @@ -1,11 +1,17 @@ import json +import os import pytest +from haystack.components.agents import Agent +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.core.pipeline import Pipeline +from haystack.dataclasses import ChatMessage from haystack.tools.errors import ToolInvocationError from haystack.tools.from_function import tool from haystack_integrations.tools.mcp import ( MCPTool, + StdioServerInfo, ) from .mcp_memory_transport import InMemoryServerInfo @@ -26,21 +32,21 @@ def mcp_add_tool(self, mcp_tool_cleanup): """Provides an MCPTool instance for the 'add' tool using the in-memory calculator server.""" server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) # The MCPTool constructor will fetch the tool's schema from the in-memory server - tool = MCPTool(name="add", server_info=server_info) + tool = MCPTool(name="add", server_info=server_info, eager_connect=True) return mcp_tool_cleanup(tool) @pytest.fixture def mcp_echo_tool(self, mcp_tool_cleanup): """Provides an MCPTool instance for the 'echo' tool using the in-memory echo server.""" server_info = InMemoryServerInfo(server=echo_mcp._mcp_server) - tool = MCPTool(name="echo", server_info=server_info) + tool = MCPTool(name="echo", server_info=server_info, eager_connect=True) return mcp_tool_cleanup(tool) @pytest.fixture def mcp_error_tool(self, mcp_tool_cleanup): """Provides an MCPTool instance for the 'divide_by_zero' tool for error testing.""" server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) - tool = MCPTool(name="divide_by_zero", server_info=server_info) + tool = MCPTool(name="divide_by_zero", server_info=server_info, eager_connect=True) return mcp_tool_cleanup(tool) # New tests using in-memory approach will be added below @@ -90,7 +96,9 @@ def test_mcp_tool_serde(self, mcp_tool_cleanup): """Test serialization and deserialization of MCPTool with in-memory server.""" server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) - tool = MCPTool(name="add", server_info=server_info, description="Addition tool for serde testing") + tool = MCPTool( + name="add", server_info=server_info, description="Addition tool for serde testing", eager_connect=True + ) # Register tool for cleanup mcp_tool_cleanup(tool) @@ -124,3 +132,26 @@ def test_mcp_tool_serde(self, mcp_tool_cleanup): } assert isinstance(new_tool._server_info, InMemoryServerInfo) + + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_pipeline_warmup_with_mcp_tool(self): + """Test lazy connection with Pipeline.warm_up() - replicates time_pipeline.py.""" + + # Replicate time_pipeline.py using MCPTool instead of MCPToolset + server_info = StdioServerInfo(command="uvx", args=["mcp-server-time", "--local-timezone=Europe/Berlin"]) + + # Create tool with lazy connection (default behavior) + tool = MCPTool(name="get_current_time", server_info=server_info) + try: + # Build pipeline with Agent, Pipeline will warm up the tool in the agent automatically + agent = Agent(chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), tools=[tool]) + pipeline = Pipeline() + pipeline.add_component("agent", agent) + + user_input_msg = ChatMessage.from_user(text="What is the time in New York?") + result = pipeline.run({"agent": {"messages": [user_input_msg]}}) + assert "New York" in result["agent"]["messages"][3].text + finally: + if tool: + tool.close() diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index 68df52d28d..625df58a71 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -11,10 +11,13 @@ import pytest import pytest_asyncio from haystack import logging +from haystack.components.agents import Agent +from haystack.components.generators.chat import OpenAIChatGenerator from haystack.core.pipeline import Pipeline +from haystack.dataclasses import ChatMessage from haystack.tools import Tool -from haystack_integrations.tools.mcp import MCPToolset +from haystack_integrations.tools.mcp import MCPToolset, StdioServerInfo from haystack_integrations.tools.mcp.mcp_tool import ( MCPConnectionError, MCPToolNotFoundError, @@ -38,6 +41,7 @@ async def calculator_toolset(mcp_tool_cleanup): server_info=server_info, connection_timeout=45, invocation_timeout=60, + eager_connect=True, ) return mcp_tool_cleanup(toolset) @@ -52,6 +56,7 @@ async def echo_toolset(mcp_tool_cleanup): server_info=server_info, connection_timeout=45, invocation_timeout=60, + eager_connect=True, ) return mcp_tool_cleanup(toolset) @@ -67,6 +72,7 @@ async def calculator_toolset_with_tool_filter(mcp_tool_cleanup): tool_names=["add"], # Only include the 'add' tool connection_timeout=45, invocation_timeout=60, + eager_connect=True, ) return mcp_tool_cleanup(toolset) @@ -211,6 +217,7 @@ async def test_toolset_error_handling(self, mock_create_client): server_info=server_info, connection_timeout=1.0, invocation_timeout=1.0, + eager_connect=True, ) async def test_toolset_tool_not_found(self): @@ -223,8 +230,32 @@ async def test_toolset_tool_not_found(self): tool_names=["non_existent_tool"], connection_timeout=10, invocation_timeout=10, + eager_connect=True, ) + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + async def test_pipeline_warmup_with_mcp_toolset(self): + """Test lazy connection with Pipeline.warm_up() - replicates time_pipeline.py.""" + + # Replicate time_pipeline.py using calculator instead of time server + server_info = StdioServerInfo(command="uvx", args=["mcp-server-time", "--local-timezone=Europe/Berlin"]) + + # Create toolset with lazy connection (default behavior) + toolset = MCPToolset(server_info=server_info) + try: + # Build pipeline exactly like time_pipeline.py + agent = Agent(chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), tools=toolset) + pipeline = Pipeline() + pipeline.add_component("agent", agent) + + user_input_msg = ChatMessage.from_user(text="What is the time in New York?") + result = pipeline.run({"agent": {"messages": [user_input_msg]}}) + assert "New York" in result["agent"]["messages"][3].text + finally: + if toolset: + toolset.close() + @pytest.mark.integration class TestMCPToolsetIntegration: @@ -280,7 +311,7 @@ def subtract(a: int, b: int) -> int: # Create the toolset server_info = SSEServerInfo(base_url=f"http://127.0.0.1:{port}") - toolset = MCPToolset(server_info=server_info) + toolset = MCPToolset(server_info=server_info, eager_connect=True) # Verify we got both tools assert len(toolset) == 2 @@ -381,7 +412,7 @@ def subtract(a: int, b: int) -> int: # Create the toolset - note the /mcp endpoint for streamable-http server_info = StreamableHttpServerInfo(url=f"http://127.0.0.1:{port}/mcp") - toolset = MCPToolset(server_info=server_info) + toolset = MCPToolset(server_info=server_info, eager_connect=True) # Verify we got both tools assert len(toolset) == 2