Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._factory import mcp_server_tools
from ._session import McpSession, McpSessionActor
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter

Expand All @@ -10,4 +11,6 @@
"SseServerParams",
"McpServerParams",
"mcp_server_tools",
"McpSessionActor",
"McpSession",
]
40 changes: 24 additions & 16 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import builtins
from abc import ABC
from typing import Any, Generic, Type, TypeVar
Expand All @@ -10,7 +9,7 @@
from pydantic import BaseModel

from ._config import McpServerParams
from ._session import create_mcp_server_session
from ._session import McpSession, create_mcp_server_session

TServerParams = TypeVar("TServerParams", bound=McpServerParams)

Expand All @@ -20,15 +19,16 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
Base adapter class for MCP tools to make them compatible with AutoGen.

Args:
server_params (TServerParams): Parameters for the MCP server connection.
session (McpSession): The MCP session to use for communication.
tool (Tool): The MCP tool to wrap.
"""

component_type = "tool"

def __init__(self, server_params: TServerParams, tool: Tool) -> None:
def __init__(self, session: McpSession, tool: Tool) -> None:
self._tool = tool
self._server_params = server_params
self._session = session
self._session.initialize()

# Extract name and description
name = tool.name
Expand Down Expand Up @@ -59,26 +59,33 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A
# Convert the input model to a dictionary
# Exclude unset values to avoid sending them to the MCP servers which may cause errors
# for many servers.

kwargs = args.model_dump(exclude_unset=True)

try:
async with create_mcp_server_session(self._server_params) as session:
await session.initialize()

if cancellation_token.is_cancelled():
raise Exception("Operation cancelled")
if cancellation_token.is_cancelled():
raise Exception("Operation cancelled")

result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=kwargs))
async with self._session.session() as session:
result_future = await session.call(name=self._tool.name, kwargs=kwargs)
cancellation_token.link_future(result_future)
result = await result_future
result = await result_future

if result.isError:
raise Exception(f"MCP tool execution failed: {result.content}")
return result.content
if result.isError:
raise Exception(f"MCP tool execution failed: {result.content}")
return result.content
except Exception as e:
error_message = self._format_errors(e)
raise Exception(error_message) from e

async def initialize(self) -> None:
"""Initialize the MCP tool adapter."""
self._session.initialize()

async def close(self) -> None:
"""Close the MCP session."""
await self._session.close()

@classmethod
async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
"""
Expand All @@ -105,7 +112,8 @@ async def from_server_params(cls, server_params: TServerParams, tool_name: str)
f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}"
)

return cls(server_params=server_params, tool=matching_tool)
session = McpSession(server_params=server_params)
return cls(session=session, tool=matching_tool)

def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._session import create_mcp_server_session
from ._session import McpSession, create_mcp_server_session
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter

Expand Down Expand Up @@ -135,8 +135,9 @@ async def main() -> None:

tools = await session.list_tools()

session = McpSession(server_params=server_params)
if isinstance(server_params, StdioServerParams):
return [StdioMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
return [StdioMcpToolAdapter(session=session, tool=tool) for tool in tools.tools]
elif isinstance(server_params, SseServerParams):
return [SseMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
return [SseMcpToolAdapter(session=session, tool=tool) for tool in tools.tools]
raise ValueError(f"Unsupported server params type: {type(server_params)}")
205 changes: 204 additions & 1 deletion python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import asyncio
import atexit
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
from typing import Any, AsyncGenerator, Dict

from autogen_core import Component, ComponentBase
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from pydantic import BaseModel, ConfigDict, PrivateAttr
from typing_extensions import Self

from ._config import McpServerParams, SseServerParams, StdioServerParams

Expand All @@ -26,3 +31,201 @@ async def create_mcp_server_session(
async with sse_client(**server_params.model_dump()) as (read, write):
async with ClientSession(read_stream=read, write_stream=write) as session:
yield session


class McpSessionActorConfig(BaseModel):
server_params: McpServerParams


class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
component_type = "mcp_session_actor"
component_config_schema = McpSessionActorConfig
component_provider_override = "autogen_ext.tools.mcp.McpSessionActor"

server_params: McpServerParams
_actor: Any = PrivateAttr(default=None)
# actor: Any = PrivateAttr(default=None)

model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, server_params: McpServerParams) -> None:
self.server_params: McpServerParams = server_params
self.name = "mcp_session_actor"
self.description = "MCP session actor"
self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._actor_task: asyncio.Task[Any] | None = None
self._shutdown_future: asyncio.Future[Any] | None = None
self._active = False
atexit.register(self._sync_shutdown)

async def initialize(self) -> None:
if not self._active:
self._active = True
self._actor_task = asyncio.create_task(self._run_actor())

async def call(self, name: str, kwargs: Dict[str, Any]) -> Any:
if not self._active:
raise RuntimeError("MCP Actor not running, call initialize() first")
if self._actor_task and self._actor_task.done():
raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
fut: asyncio.Future[Any] = asyncio.Future()
await self._command_queue.put({"type": "call", "name": name, "args": kwargs, "future": fut})
res = await fut
return res

async def close(self) -> None:
if not self._active or self._actor_task is None:
return
self._shutdown_future = asyncio.Future()
await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
await self._shutdown_future
await self._actor_task
self._active = False

async def _run_actor(self) -> None:
try:
async with create_mcp_server_session(self.server_params) as session:
await session.initialize()
while True:
cmd = await self._command_queue.get()
if cmd["type"] == "shutdown":
cmd["future"].set_result("ok")
break
elif cmd["type"] == "call":
try:
result = session.call_tool(name=cmd["name"], arguments=cmd["args"])
cmd["future"].set_result(result)
except Exception as e:
cmd["future"].set_exception(e)
except Exception as e:
if self._shutdown_future and not self._shutdown_future.done():
self._shutdown_future.set_exception(e)
finally:
self._active = False
self._actor_task = None

def _sync_shutdown(self) -> None:
if not self._active or self._actor_task is None:
return
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# No loop available — interpreter is likely shutting down
return

if loop.is_closed():
return

if loop.is_running():
loop.create_task(self.close())
else:
loop.run_until_complete(self.close())

def _to_config(self) -> McpSessionActorConfig:
"""
Convert the adapter to its configuration representation.

Returns:
McpSessionConfig: The configuration of the adapter.
"""
return McpSessionActorConfig(server_params=self.server_params)

@classmethod
def _from_config(cls, config: McpSessionActorConfig) -> Self:
"""
Create an instance of McpSessionActor from its configuration.

Args:
config (McpSessionConfig): The configuration of the adapter.

Returns:
McpSessionActor: An instance of SseMcpToolAdapter.
"""
return cls(server_params=config.server_params)


class McpSessionConfig(BaseModel):
"""Configuration for the MCP session actor."""

session_id: int = 0
server_params: McpServerParams


class McpSession(ComponentBase[BaseModel], Component[McpSessionConfig]):
"""MCP session component.

This component is used to manage the MCP session and provide access to the MCP server.
It is used internally by the MCP tool adapters.

Args:
server_params (McpServerParams): Parameters for the MCP server connection.
session_id (int, optional): Session ID. If 0 or do not insert, a new session will be created.
"""

component_type = "mcp_session"
component_config_schema = McpSessionConfig
component_provider_override = "autogen_ext.tools.mcp.McpSession"

__sessions: Dict[int, McpSessionActor] = {} # singleton instance
__session_ref_count: Dict[int, int] = {} # reference count for each session

def __init__(self, server_params: McpServerParams, session_id: int = 0) -> None:
"""Initialize the MCP session.
Args:
session_id (int): Session ID. If 0, a new session will be created.
server_params (McpServerParams): Parameters for the MCP server connection.
"""
self._server_params: McpServerParams = server_params
if session_id == 0:
self._session_id = max(self.__sessions.keys(), default=0) + 1
if session_id != 0:
if session_id not in self.__sessions:
self._session_id = session_id
else:
self._session_id = session_id

@property
def id(self) -> int:
"""Get the session ID."""
return self._session_id

def initialize(self) -> None:
"""Initialize the MCP session."""
if self._session_id == 0:
raise ValueError("Session ID cannot be 0")
if self._session_id not in self.__sessions:
self.__sessions[self._session_id] = McpSessionActor(self._server_params)
self.__session_ref_count[self._session_id] = 0
self.__session_ref_count[self._session_id] += 1

async def close(self) -> None:
"""Close the MCP session."""
if self._session_id == 0:
raise ValueError("Session ID cannot be 0")
if self._session_id not in self.__sessions:
raise ValueError(f"Session ID {self._session_id} not found")
self.__session_ref_count[self._session_id] -= 1
if self.__session_ref_count[self._session_id] == 0:
await self.__sessions[self._session_id].close()
del self.__sessions[self._session_id]
del self.__session_ref_count[self._session_id]

@asynccontextmanager
async def session(self) -> AsyncGenerator[McpSessionActor, None]:
"""Create a new MCP session."""
if self._session_id == 0:
raise ValueError("Session ID cannot be 0")
"""
if session_id not in self.__sessions:
self.__sessions[session_id] = McpSessionActor(server_params)
"""
await self.__sessions[self._session_id].initialize()
yield self.__sessions[self._session_id]
# do not close the session here, cause all of MCP tools share the same session

def _to_config(self):
return McpSessionConfig(session_id=self._session_id, server_params=self._server_params)

@classmethod
def _from_config(cls, config: McpSessionConfig) -> Self:
return cls(session_id=config.session_id, server_params=config.server_params)
16 changes: 8 additions & 8 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from autogen_core import Component
from autogen_core import Component, ComponentModel
from mcp import Tool
from pydantic import BaseModel
from typing_extensions import Self

from ._base import McpToolAdapter
from ._config import SseServerParams
from ._session import McpSession


class SseMcpToolAdapterConfig(BaseModel):
"""Configuration for the MCP tool adapter."""

server_params: SseServerParams
session: ComponentModel
tool: Tool


Expand All @@ -34,8 +35,7 @@ class SseMcpToolAdapter(
pip install -U "autogen-ext[mcp]"

Args:
server_params (SseServerParameters): Parameters for the MCP server connection,
including URL, headers, and timeouts
session (McpSession): The MCP session to use for communication with the server.
tool (Tool): The MCP tool to wrap

Examples:
Expand Down Expand Up @@ -86,8 +86,8 @@ async def main() -> None:
component_config_schema = SseMcpToolAdapterConfig
component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter"

def __init__(self, server_params: SseServerParams, tool: Tool) -> None:
super().__init__(server_params=server_params, tool=tool)
def __init__(self, session: McpSession, tool: Tool) -> None:
super().__init__(session=session, tool=tool)

def _to_config(self) -> SseMcpToolAdapterConfig:
"""
Expand All @@ -96,7 +96,7 @@ def _to_config(self) -> SseMcpToolAdapterConfig:
Returns:
SseMcpToolAdapterConfig: The configuration of the adapter.
"""
return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
return SseMcpToolAdapterConfig(session=self._session.dump_component(), tool=self._tool)

@classmethod
def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
Expand All @@ -109,4 +109,4 @@ def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
Returns:
SseMcpToolAdapter: An instance of SseMcpToolAdapter.
"""
return cls(server_params=config.server_params, tool=config.tool)
return cls(session=McpSession.load_component(config.session), tool=config.tool)
Loading