Skip to content
2 changes: 1 addition & 1 deletion integrations/mcp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"mcp>=1.8.0",
"haystack-ai>=2.18.0",
"haystack-ai>=2.19.0",
"exceptiongroup", # Backport of ExceptionGroup for Python < 3.11
"httpx" # HTTP client library used for SSE connections
]
Expand Down
100 changes: 68 additions & 32 deletions integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ def __init__(
description: str | None = None,
connection_timeout: int = 30,
invocation_timeout: int = 30,
eager_connect: bool = False,
):
"""
Initialize the MCP tool.
Expand All @@ -823,6 +824,9 @@ def __init__(
:param description: Custom description (if None, server description will be used)
:param connection_timeout: Timeout in seconds for server connection
:param invocation_timeout: Default timeout in seconds for tool invocations
:param eager_connect: If True, connect to server during initialization.
If False (default), defer connection until warm_up or first tool use,
whichever comes first.
:raises MCPConnectionError: If connection to the server fails
:raises MCPToolNotFoundError: If no tools are available or the requested tool is not found
:raises TimeoutError: If connection times out
Expand All @@ -832,39 +836,27 @@ def __init__(
self._server_info = server_info
self._connection_timeout = connection_timeout
self._invocation_timeout = invocation_timeout
self._eager_connect = eager_connect
self._client: MCPClient | None = None
self._worker: _MCPClientSessionManager | None = None
self._lock = threading.RLock()

# don't connect now; initialize permissively
if not eager_connect:
# Permissive placeholder JSON Schema so the Tool is valid
# without discovering the remote schema during validation.
# Tool parameters/schema will be replaced with the correct schema (from the MCP server) on first use.
params = {"type": "object", "properties": {}, "additionalProperties": True}
super().__init__(name=name, description=description or "", parameters=params, function=self._invoke_tool)
return

logger.debug(f"TOOL: Initializing MCPTool '{name}'")

try:
# Create client and spin up a long-lived worker that keeps the
# connect/close lifecycle inside one coroutine.
self._client = server_info.create_client()
logger.debug(f"TOOL: Created client for MCPTool '{name}'")

# The worker starts immediately and blocks here until the connection
# is established (or fails), returning the tool list.
self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout)

tools = self._worker.tools()
# Handle no tools case
if not tools:
logger.debug(f"TOOL: No tools found for '{name}'")
message = "No tools available on server"
raise MCPToolNotFoundError(message, tool_name=name)

# Find the specified tool
tool_dict = {t.name: t for t in tools}
logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}")

tool_info: types.Tool | None = tool_dict.get(name)

if not tool_info:
available = list(tool_dict.keys())
logger.debug(f"TOOL: Tool '{name}' not found in available tools")
message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}"
raise MCPToolNotFoundError(message, tool_name=name, available_tools=available)

logger.debug(f"TOOL: Connecting to MCP server for '{name}'")
tool_info = self._connect_and_initialize(name)
logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")

# Initialize the parent class
super().__init__(
name=name,
Expand Down Expand Up @@ -897,6 +889,36 @@ def __init__(
message = f"Failed to initialize MCPTool '{name}': {error_message}"
raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e

def _connect_and_initialize(self, tool_name: str) -> types.Tool:
"""
Connect to the MCP server and retrieve the tool schema.

:param tool_name: Name of the tool to look for
:returns: The tool schema for this tool
:raises MCPToolNotFoundError: If the tool is not found on the server
"""
client = self._server_info.create_client()
worker = _MCPClientSessionManager(client, timeout=self._connection_timeout)
tools = worker.tools()

# Handle no tools case
if not tools:
message = "No tools available on server"
raise MCPToolNotFoundError(message, tool_name=tool_name)

# Find the specified tool
tool = next((t for t in tools if t.name == tool_name), None)
if tool is None:
available = [t.name for t in tools]
msg = f"Tool '{tool_name}' not found on server. Available tools: {', '.join(available)}"
raise MCPToolNotFoundError(msg, tool_name=tool_name, available_tools=available)

# Publish connection
self._client = client
self._worker = worker

return tool

def _invoke_tool(self, **kwargs: Any) -> str:
"""
Synchronous tool invocation.
Expand All @@ -906,12 +928,13 @@ def _invoke_tool(self, **kwargs: Any) -> str:
"""
logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}")
try:
# Connect on first use if eager_connect is turned off
self.warm_up()

async def invoke():
logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'")
result = await asyncio.wait_for(
self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout
)
client = cast(MCPClient, self._client)
result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
logger.debug(f"TOOL: Invoke successful for '{self.name}'")
return result

Expand Down Expand Up @@ -939,7 +962,9 @@ async def ainvoke(self, **kwargs: Any) -> str:
:raises TimeoutError: If the operation times out
"""
try:
return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
self.warm_up()
client = cast(MCPClient, self._client)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you are using a different approach here to get around the mypy error than here. Let's make it consistent.

return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
except asyncio.TimeoutError as e:
message = f"Tool invocation timed out after {self._invocation_timeout} seconds"
raise TimeoutError(message) from e
Expand All @@ -949,6 +974,14 @@ async def ainvoke(self, **kwargs: Any) -> str:
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
raise MCPInvocationError(message, self.name, kwargs) from e

def warm_up(self) -> None:
"""Connect and fetch the tool schema if eager_connect is turned off."""
with self._lock:
if self._client is not None:
return
tool = self._connect_and_initialize(self.name)
self.parameters = tool.inputSchema

def to_dict(self) -> dict[str, Any]:
"""
Serializes the MCPTool to a dictionary.
Expand All @@ -966,6 +999,7 @@ def to_dict(self) -> dict[str, Any]:
"server_info": self._server_info.to_dict(),
"connection_timeout": self._connection_timeout,
"invocation_timeout": self._invocation_timeout,
"eager_connect": self._eager_connect,
}
return {
"type": generate_qualified_class_name(type(self)),
Expand Down Expand Up @@ -998,6 +1032,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
# Handle backward compatibility for timeout parameters
connection_timeout = inner_data.get("connection_timeout", 30)
invocation_timeout = inner_data.get("invocation_timeout", 30)
eager_connect = inner_data.get("eager_connect", False) # because False is the default

# Create a new MCPTool instance with the deserialized parameters
# This will establish a new connection to the MCP server
Expand All @@ -1007,6 +1042,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
server_info=server_info,
connection_timeout=connection_timeout,
invocation_timeout=invocation_timeout,
eager_connect=eager_connect,
)

def close(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
tool_names: list[str] | None = None,
connection_timeout: float = 30.0,
invocation_timeout: float = 30.0,
eager_connect: bool = False,
):
"""
Initialize the MCP toolset.
Expand All @@ -129,15 +130,48 @@ def __init__(
matching names will be added to the toolset.
:param connection_timeout: Timeout in seconds for server connection
:param invocation_timeout: Default timeout in seconds for tool invocations
:param eager_connect: If True, connect to server and load tools during initialization.
If False (default), defer connection to warm_up.
:raises MCPToolNotFoundError: If any of the specified tool names are not found on the server
"""
# Store configuration
self.server_info = server_info
self.tool_names = tool_names
self.connection_timeout = connection_timeout
self.invocation_timeout = invocation_timeout
self.eager_connect = eager_connect
self._warmup_called = False

if not eager_connect:
# Do not connect during validation; expose a toolset with one fake tool to pass validation
placeholder_tool = Tool(
name=f"mcp_not_connected_placeholder_{id(self)}",
description="Placeholder tool initialised when eager_connect is turned off",
parameters={"type": "object", "properties": {}, "additionalProperties": True},
function=lambda: None,
)
super().__init__(tools=[placeholder_tool])
else:
tools = self._connect_and_load_tools()
super().__init__(tools=tools)
self._warmup_called = True

def warm_up(self) -> None:
"""Connect and load tools when eager_connect is turned off.

This method is automatically called by ``ToolInvoker.warm_up()`` and ``Pipeline.warm_up()``.
You can also call it directly before using the toolset to ensure all tool schemas
are available without performing a real invocation.
"""
if self._warmup_called:
return

# connect and load tools never adds duplicate tools, set the tools attribute directly
self.tools = self._connect_and_load_tools()
self._warmup_called = True

# Connect and load tools
def _connect_and_load_tools(self) -> list[Tool]:
"""Connect and load tools."""
try:
# Create the client and spin up a worker so open/close happen in the
# same coroutine, avoiding AnyIO cancel-scope issues.
Expand Down Expand Up @@ -195,9 +229,7 @@ def invoke_tool(**kwargs: Any) -> Any:
)
haystack_tools.append(tool)

# Initialize parent class with complete tools list
super().__init__(tools=haystack_tools)

return haystack_tools
except Exception as e:
# We need to close because we could connect properly, retrieve tools yet
# fail because of an MCPToolNotFoundError
Expand Down Expand Up @@ -273,6 +305,7 @@ def to_dict(self) -> dict[str, Any]:
"tool_names": self.tool_names,
"connection_timeout": self.connection_timeout,
"invocation_timeout": self.invocation_timeout,
"eager_connect": self.eager_connect,
},
}

Expand All @@ -297,6 +330,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset":
tool_names=inner_data.get("tool_names"),
connection_timeout=inner_data.get("connection_timeout", 30.0),
invocation_timeout=inner_data.get("invocation_timeout", 30.0),
eager_connect=inner_data.get("eager_connect", True),
)

def close(self):
Expand Down
2 changes: 1 addition & 1 deletion integrations/mcp/tests/test_mcp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_mcp_tool_error_handling_integration(self):
# Use a non-existent server address to force a connection error
server_info = SSEServerInfo(base_url="http://localhost:9999", timeout=1) # Short timeout
with pytest.raises(MCPConnectionError) as exc_info:
MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2)
MCPTool(name="non_existent_tool", server_info=server_info, connection_timeout=2, eager_connect=True)

# Check for platform-agnostic error message patterns
error_message = str(exc_info.value)
Expand Down
68 changes: 34 additions & 34 deletions integrations/mcp/tests/test_mcp_timeout_reconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import subprocess
import sys
import tempfile
import textwrap
import time
from unittest.mock import AsyncMock, MagicMock

Expand Down Expand Up @@ -108,40 +109,39 @@ def test_real_sse_reconnection_after_server_restart(self):
try:
# Create server script with cross-platform signal handling
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
temp_file.write(
f"""
import sys
import signal
from mcp.server.fastmcp import FastMCP

# Handle shutdown signals gracefully (cross-platform)
def signal_handler(signum, frame):
sys.exit(0)

# Only set up signal handlers that exist on the platform
if hasattr(signal, 'SIGTERM'):
signal.signal(signal.SIGTERM, signal_handler)
if hasattr(signal, 'SIGINT'):
signal.signal(signal.SIGINT, signal_handler)

mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port})

@mcp.tool()
def test_tool(message: str) -> str:
return f"Server response: {{message}}"

if __name__ == "__main__":
try:
print(f"Starting server on port {port}", flush=True)
mcp.run(transport="sse")
except (KeyboardInterrupt, SystemExit):
print("Server shutting down gracefully", flush=True)
sys.exit(0)
except Exception as e:
print(f"Server error: {{e}}", file=sys.stderr, flush=True)
sys.exit(1)
""".encode()
)
script_content = textwrap.dedent(f"""
import sys
import signal
from mcp.server.fastmcp import FastMCP

# Handle shutdown signals gracefully (cross-platform)
def signal_handler(signum, frame):
sys.exit(0)

# Only set up signal handlers that exist on the platform
if hasattr(signal, 'SIGTERM'):
signal.signal(signal.SIGTERM, signal_handler)
if hasattr(signal, 'SIGINT'):
signal.signal(signal.SIGINT, signal_handler)

mcp = FastMCP("Reconnection Test Server", host="127.0.0.1", port={port})

@mcp.tool()
def test_tool(message: str) -> str:
return f"Server response: {{message}}"

if __name__ == "__main__":
try:
print(f"Starting server on port {port}", flush=True)
mcp.run(transport="sse")
except (KeyboardInterrupt, SystemExit):
print("Server shutting down gracefully", flush=True)
sys.exit(0)
except Exception as e:
print(f"Server error: {{e}}", file=sys.stderr, flush=True)
sys.exit(1)
""").strip()
temp_file.write(script_content.encode())
server_script_path = temp_file.name

# Start server
Expand Down
Loading