diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py index 83d76fcad502..1b1cb4ab69e1 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py @@ -1,5 +1,6 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams from ._factory import mcp_server_tools +from ._session import McpSession, McpSessionActor from ._sse import SseMcpToolAdapter from ._stdio import StdioMcpToolAdapter @@ -10,4 +11,6 @@ "SseServerParams", "McpServerParams", "mcp_server_tools", + "McpSessionActor", + "McpSession", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py index 0901be9eda93..0d2192c46fd5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py @@ -1,4 +1,3 @@ -import asyncio import builtins from abc import ABC from typing import Any, Generic, Type, TypeVar @@ -10,7 +9,7 @@ from pydantic import BaseModel from ._config import McpServerParams -from ._session import create_mcp_server_session +from ._session import McpSession, create_mcp_server_session TServerParams = TypeVar("TServerParams", bound=McpServerParams) @@ -20,15 +19,16 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): Base adapter class for MCP tools to make them compatible with AutoGen. Args: - server_params (TServerParams): Parameters for the MCP server connection. + session (McpSession): The MCP session to use for communication. tool (Tool): The MCP tool to wrap. 
""" component_type = "tool" - def __init__(self, server_params: TServerParams, tool: Tool) -> None: + def __init__(self, session: McpSession, tool: Tool) -> None: self._tool = tool - self._server_params = server_params + self._session = session + self._session.initialize() # Extract name and description name = tool.name @@ -59,26 +59,33 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A # Convert the input model to a dictionary # Exclude unset values to avoid sending them to the MCP servers which may cause errors # for many servers. + kwargs = args.model_dump(exclude_unset=True) try: - async with create_mcp_server_session(self._server_params) as session: - await session.initialize() - - if cancellation_token.is_cancelled(): - raise Exception("Operation cancelled") + if cancellation_token.is_cancelled(): + raise Exception("Operation cancelled") - result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=kwargs)) + async with self._session.session() as session: + result_future = await session.call(name=self._tool.name, kwargs=kwargs) cancellation_token.link_future(result_future) - result = await result_future + result = await result_future - if result.isError: - raise Exception(f"MCP tool execution failed: {result.content}") - return result.content + if result.isError: + raise Exception(f"MCP tool execution failed: {result.content}") + return result.content except Exception as e: error_message = self._format_errors(e) raise Exception(error_message) from e + async def initialize(self) -> None: + """Initialize the MCP tool adapter.""" + self._session.initialize() + + async def close(self) -> None: + """Close the MCP session.""" + await self._session.close() + @classmethod async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]": """ @@ -105,7 +112,8 @@ async def from_server_params(cls, server_params: TServerParams, tool_name: str) f"Tool 
'{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}" ) - return cls(server_params=server_params, tool=matching_tool) + session = McpSession(server_params=server_params) + return cls(session=session, tool=matching_tool) def _format_errors(self, error: Exception) -> str: """Recursively format errors into a string.""" diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py index 3b8c2356b79f..861611b51f58 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py @@ -1,5 +1,5 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams -from ._session import create_mcp_server_session +from ._session import McpSession, create_mcp_server_session from ._sse import SseMcpToolAdapter from ._stdio import StdioMcpToolAdapter @@ -135,8 +135,9 @@ async def main() -> None: tools = await session.list_tools() + session = McpSession(server_params=server_params) if isinstance(server_params, StdioServerParams): - return [StdioMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools] + return [StdioMcpToolAdapter(session=session, tool=tool) for tool in tools.tools] elif isinstance(server_params, SseServerParams): - return [SseMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools] + return [SseMcpToolAdapter(session=session, tool=tool) for tool in tools.tools] raise ValueError(f"Unsupported server params type: {type(server_params)}") diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py index bc1a28fac5cb..00a2aa5d45df 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py @@ -1,10 +1,15 @@ +import asyncio +import 
atexit
 from contextlib import asynccontextmanager
 from datetime import timedelta
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator, Dict
 
+from autogen_core import Component, ComponentBase
 from mcp import ClientSession
 from mcp.client.sse import sse_client
 from mcp.client.stdio import stdio_client
+from pydantic import BaseModel, ConfigDict, PrivateAttr
+from typing_extensions import Self
 
 from ._config import McpServerParams, SseServerParams, StdioServerParams
 
@@ -26,3 +31,265 @@ async def create_mcp_server_session(
         async with sse_client(**server_params.model_dump()) as (read, write):
             async with ClientSession(read_stream=read, write_stream=write) as session:
                 yield session
+
+
+class McpSessionActorConfig(BaseModel):
+    """Serializable configuration for :class:`McpSessionActor`."""
+
+    server_params: McpServerParams
+
+
+class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
+    """Run one MCP ``ClientSession`` inside a single background task.
+
+    The session is opened, used and closed by the task started in :meth:`initialize`.
+    Other tasks reach it only through an in-memory command queue (:meth:`call`, :meth:`close`).
+
+    Args:
+        server_params (McpServerParams): Parameters for the MCP server connection.
+    """
+
+    component_type = "mcp_session_actor"
+    component_config_schema = McpSessionActorConfig
+    component_provider_override = "autogen_ext.tools.mcp.McpSessionActor"
+
+    server_params: McpServerParams
+    # NOTE(review): ``_actor`` and ``model_config`` are pydantic-style declarations and are not
+    # read anywhere in this class; confirm nothing else relies on them before removing.
+    _actor: Any = PrivateAttr(default=None)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def __init__(self, server_params: McpServerParams) -> None:
+        self.server_params: McpServerParams = server_params
+        self.name = "mcp_session_actor"
+        self.description = "MCP session actor"
+        self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
+        self._actor_task: asyncio.Task[Any] | None = None
+        self._shutdown_future: asyncio.Future[Any] | None = None
+        self._active = False
+        # Best-effort cleanup at interpreter exit in case close() was never awaited.
+        atexit.register(self._sync_shutdown)
+
+    async def initialize(self) -> None:
+        """Start the background actor task if it is not already running."""
+        if not self._active:
+            self._active = True
+            self._actor_task = asyncio.create_task(self._run_actor())
+
+    async def call(self, name: str, kwargs: Dict[str, Any]) -> Any:
+        """Ask the actor to invoke the MCP tool ``name`` with arguments ``kwargs``.
+
+        Returns:
+            A task wrapping ``ClientSession.call_tool``; await it to obtain the tool result.
+
+        Raises:
+            RuntimeError: If the actor is not running, or stops before handling the request.
+        """
+        if not self._active:
+            raise RuntimeError("MCP Actor not running, call initialize() first")
+        if self._actor_task and self._actor_task.done():
+            raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
+        fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
+        await self._command_queue.put({"type": "call", "name": name, "args": kwargs, "future": fut})
+        return await fut
+
+    async def close(self) -> None:
+        """Ask the actor to shut down and wait until its task has finished."""
+        actor_task = self._actor_task
+        if not self._active or actor_task is None:
+            return
+        self._shutdown_future = asyncio.get_running_loop().create_future()
+        await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
+        await self._shutdown_future
+        # _run_actor() resets self._actor_task when it exits, so wait on the reference
+        # captured above instead of re-reading an attribute that may already be None.
+        await actor_task
+        self._active = False
+
+    async def _run_actor(self) -> None:
+        """Open the MCP session and serve queued commands until a shutdown command arrives."""
+        failure: BaseException | None = None
+        try:
+            async with create_mcp_server_session(self.server_params) as session:
+                await session.initialize()
+                while True:
+                    cmd = await self._command_queue.get()
+                    reply: asyncio.Future[Any] = cmd["future"]
+                    if cmd["type"] == "shutdown":
+                        if not reply.done():
+                            reply.set_result("ok")
+                        break
+                    elif cmd["type"] == "call":
+                        if reply.done():
+                            # The caller was cancelled while waiting; resolving its future
+                            # would raise InvalidStateError and take the whole actor down.
+                            continue
+                        try:
+                            # Hand back a task rather than a bare coroutine so the caller can
+                            # both await it and link it to a CancellationToken.
+                            pending_call = session.call_tool(name=cmd["name"], arguments=cmd["args"])
+                            reply.set_result(asyncio.ensure_future(pending_call))
+                        except Exception as e:
+                            reply.set_exception(e)
+        except Exception as e:
+            failure = e
+            if self._shutdown_future and not self._shutdown_future.done():
+                self._shutdown_future.set_exception(e)
+        finally:
+            self._active = False
+            self._actor_task = None
+            self._fail_pending(failure)
+
+    def _fail_pending(self, cause: BaseException | None) -> None:
+        """Fail every request still queued so that no caller waits forever."""
+        while not self._command_queue.empty():
+            pending: asyncio.Future[Any] = self._command_queue.get_nowait()["future"]
+            if not pending.done():
+                error = RuntimeError("MCP actor stopped before handling the request")
+                error.__cause__ = cause
+                pending.set_exception(error)
+
+    def _sync_shutdown(self) -> None:
+        """``atexit`` hook: close the actor if it is still running and an event loop is usable."""
+        if not self._active or self._actor_task is None:
+            return
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            # No loop available — interpreter is likely shutting down
+            return
+
+        if loop.is_closed():
+            return
+
+        if loop.is_running():
+            loop.create_task(self.close())
+        else:
+            loop.run_until_complete(self.close())
+
+    def _to_config(self) -> McpSessionActorConfig:
+        """
+        Convert the actor to its configuration representation.
+
+        Returns:
+            McpSessionActorConfig: The configuration of the actor.
+        """
+        return McpSessionActorConfig(server_params=self.server_params)
+
+    @classmethod
+    def _from_config(cls, config: McpSessionActorConfig) -> Self:
+        """
+        Create an instance of McpSessionActor from its configuration.
+
+        Args:
+            config (McpSessionActorConfig): The configuration of the actor.
+
+        Returns:
+            McpSessionActor: An instance of McpSessionActor.
+        """
+        return cls(server_params=config.server_params)
+
+
+class McpSessionConfig(BaseModel):
+    """Configuration for the MCP session."""
+
+    session_id: int = 0
+    server_params: McpServerParams
+
+
+class McpSession(ComponentBase[BaseModel], Component[McpSessionConfig]):
+    """MCP session component.
+
+    This component is used to manage the MCP session and provide access to the MCP server.
+    It is used internally by the MCP tool adapters.
+
+    Args:
+        server_params (McpServerParams): Parameters for the MCP server connection.
+        session_id (int, optional): Session ID. If 0 or omitted, a new session ID is allocated.
+    """
+
+    component_type = "mcp_session"
+    component_config_schema = McpSessionConfig
+    component_provider_override = "autogen_ext.tools.mcp.McpSession"
+
+    __sessions: Dict[int, McpSessionActor] = {}  # session ID -> actor shared by every McpSession with that ID
+    __session_ref_count: Dict[int, int] = {}  # reference count for each session
+    __last_session_id: int = 0  # highest session ID allocated or seen so far
+
+    def __init__(self, server_params: McpServerParams, session_id: int = 0) -> None:
+        """Initialize the MCP session.
+
+        Args:
+            server_params (McpServerParams): Parameters for the MCP server connection.
+            session_id (int): Session ID. If 0, a new session ID is allocated.
+        """
+        self._server_params: McpServerParams = server_params
+        if session_id == 0:
+            # Allocate from a counter instead of the actor registry: actors are only
+            # registered in initialize(), so two sessions created back to back would
+            # otherwise get the same ID and end up sharing one actor.
+            session_id = McpSession.__last_session_id + 1
+        McpSession.__last_session_id = max(McpSession.__last_session_id, session_id)
+        self._session_id = session_id
+
+    @property
+    def id(self) -> int:
+        """Get the session ID."""
+        return self._session_id
+
+    def initialize(self) -> None:
+        """Register one more user of this session, creating its actor on first use.
+
+        Raises:
+            ValueError: If the session ID is 0.
+        """
+        if self._session_id == 0:
+            raise ValueError("Session ID cannot be 0")
+        if self._session_id not in self.__sessions:
+            self.__sessions[self._session_id] = McpSessionActor(self._server_params)
+            self.__session_ref_count[self._session_id] = 0
+        self.__session_ref_count[self._session_id] += 1
+
+    async def close(self) -> None:
+        """Release one reference and close the shared actor once nobody uses it.
+
+        Raises:
+            ValueError: If the session ID is 0 or the session is not registered.
+        """
+        if self._session_id == 0:
+            raise ValueError("Session ID cannot be 0")
+        if self._session_id not in self.__sessions:
+            raise ValueError(f"Session ID {self._session_id} not found")
+        self.__session_ref_count[self._session_id] -= 1
+        if self.__session_ref_count[self._session_id] == 0:
+            # Unregister before awaiting so the registry is cleaned up even if
+            # closing the actor raises.
+            actor = self.__sessions.pop(self._session_id)
+            del self.__session_ref_count[self._session_id]
+            await actor.close()
+
+    @asynccontextmanager
+    async def session(self) -> AsyncGenerator[McpSessionActor, None]:
+        """Yield the shared actor for this session, starting it if necessary.
+
+        Raises:
+            ValueError: If the session ID is 0 or initialize() has not registered the session.
+        """
+        if self._session_id == 0:
+            raise ValueError("Session ID cannot be 0")
+        actor = self.__sessions.get(self._session_id)
+        if actor is None:
+            raise ValueError(f"Session ID {self._session_id} not found")
+        await actor.initialize()
+        yield actor
+        # do not close the session here, cause all of MCP tools share the same session
+
+    def _to_config(self) -> McpSessionConfig:
+        """Convert the session to its configuration representation."""
+        return McpSessionConfig(session_id=self._session_id, server_params=self._server_params)
+
+    @classmethod
+    def _from_config(cls, config: McpSessionConfig) -> Self:
+        """Create a session from its configuration, reusing the stored session ID."""
+        return cls(session_id=config.session_id, server_params=config.server_params)
diff
--git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py index 252af7ce50da..7b4319286398 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py @@ -1,16 +1,17 @@ -from autogen_core import Component +from autogen_core import Component, ComponentModel from mcp import Tool from pydantic import BaseModel from typing_extensions import Self from ._base import McpToolAdapter from ._config import SseServerParams +from ._session import McpSession class SseMcpToolAdapterConfig(BaseModel): """Configuration for the MCP tool adapter.""" - server_params: SseServerParams + session: ComponentModel tool: Tool @@ -34,8 +35,7 @@ class SseMcpToolAdapter( pip install -U "autogen-ext[mcp]" Args: - server_params (SseServerParameters): Parameters for the MCP server connection, - including URL, headers, and timeouts + session (McpSession): The MCP session to use for communication with the server. tool (Tool): The MCP tool to wrap Examples: @@ -86,8 +86,8 @@ async def main() -> None: component_config_schema = SseMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter" - def __init__(self, server_params: SseServerParams, tool: Tool) -> None: - super().__init__(server_params=server_params, tool=tool) + def __init__(self, session: McpSession, tool: Tool) -> None: + super().__init__(session=session, tool=tool) def _to_config(self) -> SseMcpToolAdapterConfig: """ @@ -96,7 +96,7 @@ def _to_config(self) -> SseMcpToolAdapterConfig: Returns: SseMcpToolAdapterConfig: The configuration of the adapter. 
""" - return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + return SseMcpToolAdapterConfig(session=self._session.dump_component(), tool=self._tool) @classmethod def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: @@ -109,4 +109,4 @@ def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: Returns: SseMcpToolAdapter: An instance of SseMcpToolAdapter. """ - return cls(server_params=config.server_params, tool=config.tool) + return cls(session=McpSession.load_component(config.session), tool=config.tool) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py index 4f827785e903..e41a1b1c983e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py @@ -1,16 +1,17 @@ -from autogen_core import Component +from autogen_core import Component, ComponentModel from mcp import Tool from pydantic import BaseModel from typing_extensions import Self from ._base import McpToolAdapter from ._config import StdioServerParams +from ._session import McpSession class StdioMcpToolAdapterConfig(BaseModel): """Configuration for the MCP tool adapter.""" - server_params: StdioServerParams + session: ComponentModel tool: Tool @@ -34,8 +35,7 @@ class StdioMcpToolAdapter( Args: - server_params (StdioServerParams): Parameters for the MCP server connection, - including command to run and its arguments + session (McpSession): The MCP session to use for communication with the server. tool (Tool): The MCP tool to wrap See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples. 
@@ -44,8 +44,8 @@ class StdioMcpToolAdapter( component_config_schema = StdioMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter" - def __init__(self, server_params: StdioServerParams, tool: Tool) -> None: - super().__init__(server_params=server_params, tool=tool) + def __init__(self, session: McpSession, tool: Tool) -> None: + super().__init__(session=session, tool=tool) def _to_config(self) -> StdioMcpToolAdapterConfig: """ @@ -54,7 +54,7 @@ def _to_config(self) -> StdioMcpToolAdapterConfig: Returns: StdioMcpToolAdapterConfig: The configuration of the adapter. """ - return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + return StdioMcpToolAdapterConfig(session=self._session.dump_component(), tool=self._tool) @classmethod def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: @@ -67,4 +67,4 @@ def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: Returns: StdioMcpToolAdapter: An instance of StdioMcpToolAdapter. 
""" - return cls(server_params=config.server_params, tool=config.tool) + return cls(session=McpSession.load_component(config.session), tool=config.tool) diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py index 14ba9c89ca30..0cec8ecf957c 100644 --- a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -1,10 +1,13 @@ +import asyncio import logging import os +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock import pytest from autogen_core import CancellationToken from autogen_ext.tools.mcp import ( + McpSession, SseMcpToolAdapter, SseServerParams, StdioMcpToolAdapter, @@ -79,10 +82,12 @@ def cancellation_token() -> CancellationToken: def test_adapter_config_serialization(sample_tool: Tool, sample_server_params: StdioServerParams) -> None: """Test that adapter can be saved to and loaded from config.""" - original_adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool) + # Create an instance of the adapter + session = McpSession(server_params=sample_server_params) + original_adapter = StdioMcpToolAdapter(session=session, tool=sample_tool) config = original_adapter.dump_component() loaded_adapter = StdioMcpToolAdapter.load_component(config) - + asyncio.run(session.close()) # Test that the loaded adapter has the same properties assert loaded_adapter.name == "test_tool" assert loaded_adapter.description == "A test tool" @@ -115,17 +120,22 @@ async def test_mcp_tool_execution( caplog: pytest.LogCaptureFixture, ) -> None: """Test that adapter properly executes tools through ClientSession.""" - mock_context = AsyncMock() - mock_context.__aenter__.return_value = mock_session + + @asynccontextmanager + async def fake_create_session(*args, **kwargs): # type: ignore + yield mock_session + monkeypatch.setattr( - 
"autogen_ext.tools.mcp._base.create_mcp_server_session", - lambda *args, **kwargs: mock_context, # type: ignore + "autogen_ext.tools.mcp._session.create_mcp_server_session", + fake_create_session, # type: ignore ) mock_session.call_tool.return_value = mock_tool_response with caplog.at_level(logging.INFO): - adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool) + session = McpSession(server_params=sample_server_params) + session.initialize() + adapter = StdioMcpToolAdapter(session=session, tool=sample_tool) result = await adapter.run_json( args=create_model(sample_tool.inputSchema)(**{"test_param": "test"}).model_dump(), cancellation_token=cancellation_token, @@ -137,6 +147,7 @@ async def test_mcp_tool_execution( # Check log. assert "test_output" in caplog.text + await session.close() @pytest.mark.asyncio @@ -154,10 +165,23 @@ async def test_adapter_from_server_params( lambda *args, **kwargs: mock_context, # type: ignore ) + @asynccontextmanager + async def fake_create_session(*args, **kwargs): # type: ignore + try: + yield mock_session + finally: + # graceful shutdown + pass + + monkeypatch.setattr( + "autogen_ext.tools.mcp._session.create_mcp_server_session", + fake_create_session, # type: ignore + ) + mock_session.list_tools.return_value.tools = [sample_tool] adapter = await StdioMcpToolAdapter.from_server_params(sample_server_params, "test_tool") - + await adapter.close() assert isinstance(adapter, StdioMcpToolAdapter) assert adapter.name == "test_tool" assert adapter.description == "A test tool" @@ -183,7 +207,9 @@ async def test_adapter_from_server_params( async def test_sse_adapter_config_serialization(sample_sse_tool: Tool) -> None: """Test that SSE adapter can be saved to and loaded from config.""" params = SseServerParams(url="http://test-url") - original_adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool) + session = McpSession(server_params=params) + session.initialize() + original_adapter = 
SseMcpToolAdapter(session=session, tool=sample_sse_tool) config = original_adapter.dump_component() loaded_adapter = SseMcpToolAdapter.load_component(config) @@ -207,6 +233,7 @@ async def test_sse_adapter_config_serialization(sample_sse_tool: Tool) -> None: params_schema["properties"]["test_param"]["type"] == sample_sse_tool.inputSchema["properties"]["test_param"]["type"] ) + await session.close() @pytest.mark.asyncio @@ -218,18 +245,23 @@ async def test_sse_tool_execution( ) -> None: """Test that SSE adapter properly executes tools through ClientSession.""" params = SseServerParams(url="http://test-url") - mock_context = AsyncMock() - mock_context.__aenter__.return_value = mock_sse_session - mock_sse_session.call_tool.return_value = MagicMock(isError=False, content={"result": "test_output"}) + mock_result = MagicMock(isError=False, content={"result": "test_output"}) + mock_sse_session.call_tool.return_value = mock_result + + @asynccontextmanager + async def fake_create_session(*args, **kwargs): # type: ignore + yield mock_sse_session monkeypatch.setattr( - "autogen_ext.tools.mcp._base.create_mcp_server_session", - lambda *args, **kwargs: mock_context, # type: ignore + "autogen_ext.tools.mcp._session.create_mcp_server_session", + fake_create_session, # type: ignore ) with caplog.at_level(logging.INFO): - adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool) + session = McpSession(server_params=params) + session.initialize() + adapter = SseMcpToolAdapter(session=session, tool=sample_sse_tool) result = await adapter.run_json( args=create_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}).model_dump(), cancellation_token=CancellationToken(), @@ -242,6 +274,8 @@ async def test_sse_tool_execution( # Check log. 
assert "test_output" in caplog.text + await session.close() + @pytest.mark.asyncio async def test_sse_adapter_from_server_params( @@ -251,6 +285,8 @@ async def test_sse_adapter_from_server_params( ) -> None: """Test that SSE adapter can be created from server parameters.""" params = SseServerParams(url="http://test-url") + mock_sse_session.list_tools.return_value.tools = [sample_sse_tool] + mock_context = AsyncMock() mock_context.__aenter__.return_value = mock_sse_session monkeypatch.setattr( @@ -258,9 +294,17 @@ async def test_sse_adapter_from_server_params( lambda *args, **kwargs: mock_context, # type: ignore ) - mock_sse_session.list_tools.return_value.tools = [sample_sse_tool] + @asynccontextmanager + async def fake_create_session(*args, **kwargs): # type: ignore + yield mock_sse_session + + monkeypatch.setattr( + "autogen_ext.tools.mcp._session.create_mcp_server_session", + fake_create_session, # type: ignore + ) adapter = await SseMcpToolAdapter.from_server_params(params, "test_sse_tool") + await adapter.close() assert isinstance(adapter, SseMcpToolAdapter) assert adapter.name == "test_sse_tool"