diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py index f75372110095..0492c79ad631 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py @@ -1,3 +1,4 @@ +from ._actor import McpSessionActor from ._config import McpServerParams, SseServerParams, StdioServerParams from ._factory import mcp_server_tools from ._session import create_mcp_server_session @@ -7,6 +8,7 @@ __all__ = [ "create_mcp_server_session", + "McpSessionActor", "StdioMcpToolAdapter", "StdioServerParams", "SseMcpToolAdapter", diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py new file mode 100644 index 000000000000..7e84e24b5d36 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py @@ -0,0 +1,147 @@ +import asyncio +import atexit +from typing import Any, Coroutine, Dict, Mapping, TypedDict + +from autogen_core import Component, ComponentBase +from mcp.types import CallToolResult, ListToolsResult +from pydantic import BaseModel +from typing_extensions import Self + +from ._config import McpServerParams +from ._session import create_mcp_server_session + +McpResult = Coroutine[Any, Any, ListToolsResult] | Coroutine[Any, Any, CallToolResult] +McpFuture = asyncio.Future[McpResult] + + +class McpActorArgs(TypedDict): + name: str | None + kargs: Mapping[str, Any] + + +class McpSessionActorConfig(BaseModel): + server_params: McpServerParams + + +class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]): + component_type = "mcp_session_actor" + component_config_schema = McpSessionActorConfig + component_provider_override = "autogen_ext.tools.mcp.McpSessionActor" + + server_params: McpServerParams + + # model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, server_params: McpServerParams) -> None: + self.server_params: McpServerParams = server_params + self.name = "mcp_session_actor" + self.description = "MCP session actor" + self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + self._actor_task: asyncio.Task[Any] | None = None + self._shutdown_future: asyncio.Future[Any] | None = None + self._active = False + atexit.register(self._sync_shutdown) + + async def initialize(self) -> None: + if not self._active: + self._active = True + self._actor_task = asyncio.create_task(self._run_actor()) + + async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture: + if not self._active: + raise RuntimeError("MCP Actor not running, call initialize() first") + if self._actor_task and self._actor_task.done(): + raise RuntimeError("MCP actor task crashed", self._actor_task.exception()) + fut: asyncio.Future[McpFuture] = asyncio.Future() + if type in {"list_tools", "shutdown"}: + await self._command_queue.put({"type": type, "future": fut}) + res = await fut + elif type == "call_tool": + if args is None: + raise ValueError("args is required for call_tool") + name = args.get("name", None) + kwargs = args.get("kargs", {}) + if name is None: + raise ValueError("name is required for call_tool") + await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut}) + res = await fut + else: + raise ValueError(f"Unknown command type: {type}") + return res + + async def close(self) -> None: + if not self._active or self._actor_task is None: + return + self._shutdown_future = asyncio.Future() + await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future}) + await self._shutdown_future + await self._actor_task + self._active = False + + async def _run_actor(self) -> None: + result: McpResult + try: + async with create_mcp_server_session(self.server_params) as session: + await session.initialize() + while True: + cmd = await self._command_queue.get() + if cmd["type"] == "shutdown": + cmd["future"].set_result("ok") + break + elif cmd["type"] == "call_tool": + try: + result = session.call_tool(name=cmd["name"], arguments=cmd["args"]) + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_tools": + try: + result = session.list_tools() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + except Exception as e: + if self._shutdown_future and not self._shutdown_future.done(): + self._shutdown_future.set_exception(e) + finally: + self._active = False + self._actor_task = None + + def _sync_shutdown(self) -> None: + if not self._active or self._actor_task is None: + return + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No loop available — interpreter is likely shutting down + return + + if loop.is_closed(): + return + + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + + def _to_config(self) -> McpSessionActorConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + McpSessionConfig: The configuration of the adapter. + """ + return McpSessionActorConfig(server_params=self.server_params) + + @classmethod + def _from_config(cls, config: McpSessionActorConfig) -> Self: + """ + Create an instance of McpSessionActor from its configuration. + + Args: + config (McpSessionConfig): The configuration of the adapter. + + Returns: + McpSessionActor: An instance of SseMcpToolAdapter. + """ + return cls(server_params=config.server_params) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py index a587b175340d..5187c59785b8 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py @@ -1,9 +1,7 @@ -import asyncio import builtins -from datetime import timedelta -from typing import Any, Dict, List, Literal, Mapping, Optional +from typing import Any, List, Literal, Mapping -from autogen_core import CancellationToken, Component, ComponentModel, Image +from autogen_core import CancellationToken, Component, Image from autogen_core.tools import ( ImageResultContent, ParametersSchema, @@ -12,13 +10,11 @@ ToolSchema, WorkBench, ) -from mcp import ClientSession -from mcp.client.sse import sse_client -from mcp.client.stdio import stdio_client -from mcp.types import EmbeddedResource, ImageContent, TextContent +from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent from pydantic import BaseModel -from typing_extensions import Annotated, Self +from typing_extensions import Self +from ._actor import McpSessionActor from ._config import McpServerParams, SseServerParams, StdioServerParams @@ -38,15 +34,24 @@ class McpWorkBench(WorkBench, Component[McpWorkBenchConfig]): def __init__(self, server_params: McpServerParams) -> None: self._server_params = server_params - self._session: ClientSession | None = None + # self._session: ClientSession | None = None + self._actor: McpSessionActor | None = None self._read = None self._write = None async def list_tools(self) -> List[ToolSchema]: - if not self._session: - raise RuntimeError("Session is not initialized. Call start() first.") - list_tool_result = await self._session.list_tools() - schema = [] + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + result_future = await self._actor.call("list_tools", None) + list_tool_result = await result_future + assert isinstance( + list_tool_result, ListToolsResult + ), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}" + schema: List[ToolSchema] = [] for tool in list_tool_result.tools: name = tool.name description = tool.description or "" @@ -67,16 +72,23 @@ async def list_tools(self) -> List[ToolSchema]: async def call_tool( self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None ) -> ToolResult: - if not self._session: - raise RuntimeError("Session is not initialized. Call start() first.") + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") if not cancellation_token: cancellation_token = CancellationToken() if not arguments: arguments = {} try: - result_future = asyncio.ensure_future(self._session.call_tool(name=name, arguments=dict(arguments))) + result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments}) cancellation_token.link_future(result_future) result = await result_future + assert isinstance( + result, CallToolResult + ), f"call_tool must return a CallToolResult, instead of : {str(type(result))}" result_parts: List[TextResultContent | ImageResultContent] = [] is_error = result.isError for content in result.content: @@ -110,50 +122,24 @@ def _format_errors(self, error: Exception) -> str: return error_message async def start(self) -> None: - if self._session: - raise RuntimeError("Session is already initialized. Call stop() first.") - - if isinstance(self._server_params, StdioServerParams): - read, write = await stdio_client(self._server_params).__aenter__() - self._read = read - self._write = write - session = await ClientSession( - read_stream=read, - write_stream=write, - read_timeout_seconds=timedelta(seconds=self._server_params.read_timeout_seconds), - ).__aenter__() - self._session = session - elif isinstance(self._server_params, SseServerParams): - read, write = await sse_client(**self._server_params.model_dump()).__aenter__() - self._read = read - self._write = write - session = await ClientSession(read_stream=read, write_stream=write).__aenter__() - self._session = session + if self._actor: + return # Already initialized, no need to start again + raise RuntimeError("Actor is already initialized. Call stop() first.") + + if isinstance(self._server_params, (StdioServerParams, SseServerParams)): + self._actor = McpSessionActor(self._server_params) + await self._actor.initialize() else: raise ValueError(f"Unsupported server params type: {type(self._server_params)}") async def stop(self) -> None: - if self._session: - # Close the session and streams in reverse order - await self._session.__aexit__(None, None, None) - self._session = None - - # If streams exist, close them - if hasattr(self, "_write") and self._write: - # Determine the context manager based on the server params type - if isinstance(self._server_params, StdioServerParams): - cm = stdio_client(self._server_params) - elif isinstance(self._server_params, SseServerParams): - cm = sse_client(**self._server_params.model_dump()) - else: - raise ValueError(f"Unsupported server params type: {type(self._server_params)}") - - # Exit the context manager to properly close streams - await cm.__aexit__(None, None, None) - self._read = None - self._write = None + if self._actor: + # Close the actor + # await self._session.__aexit__(None, None, None) + await self._actor.close() + self._actor = None else: - raise RuntimeError("Session is not initialized. Call start() first.") + raise RuntimeError("Actor is not initialized. Call start() first.") async def reset(self) -> None: pass diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py index 998d678ab98f..f641a1ae3f67 100644 --- a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -12,6 +12,7 @@ StdioServerParams, create_mcp_server_session, mcp_server_tools, + McpWorkBench, ) from mcp import ClientSession, Tool @@ -422,3 +423,43 @@ async def test_mcp_server_github() -> None: {"owner": "microsoft", "repo": "autogen", "path": "python", "branch": "main"}, CancellationToken() ) assert result is not None + + +@pytest.mark.asyncio +async def test_mcp_workbench_start_stop(): + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + workbench = McpWorkBench(params) + assert workbench is not None + assert workbench._server_params == params + await workbench.start() + assert workbench._actor is not None + await workbench.stop() + assert workbench._actor is None + + +@pytest.mark.asyncio +async def test_mcp_workbench_server_fetch(): + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + workbench = McpWorkBench(server_params=params) + await workbench.start() + + tools = await workbench.list_tools() + assert tools is not None + assert tools[0]["name"] == "fetch" + + result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}, CancellationToken()) + assert result is not None + + await workbench.stop() + + \ No newline at end of file