Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._factory import mcp_server_tools
from ._session import create_mcp_server_session
Expand All @@ -7,6 +8,7 @@

__all__ = [
"create_mcp_server_session",
"McpSessionActor",
"StdioMcpToolAdapter",
"StdioServerParams",
"SseMcpToolAdapter",
Expand Down
147 changes: 147 additions & 0 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import asyncio
import atexit
from typing import Any, Coroutine, Dict, Mapping, TypedDict

from autogen_core import Component, ComponentBase
from mcp.types import CallToolResult, ListToolsResult
from pydantic import BaseModel
from typing_extensions import Self

from ._config import McpServerParams
from ._session import create_mcp_server_session

McpResult = Coroutine[Any, Any, ListToolsResult] | Coroutine[Any, Any, CallToolResult]
McpFuture = asyncio.Future[McpResult]


class McpActorArgs(TypedDict):
name: str | None
kargs: Mapping[str, Any]


class McpSessionActorConfig(BaseModel):
server_params: McpServerParams


class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
component_type = "mcp_session_actor"
component_config_schema = McpSessionActorConfig
component_provider_override = "autogen_ext.tools.mcp.McpSessionActor"

server_params: McpServerParams

# model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, server_params: McpServerParams) -> None:
self.server_params: McpServerParams = server_params
self.name = "mcp_session_actor"
self.description = "MCP session actor"
self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._actor_task: asyncio.Task[Any] | None = None
self._shutdown_future: asyncio.Future[Any] | None = None
self._active = False
atexit.register(self._sync_shutdown)

async def initialize(self) -> None:
if not self._active:
self._active = True
self._actor_task = asyncio.create_task(self._run_actor())

async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture:
if not self._active:
raise RuntimeError("MCP Actor not running, call initialize() first")
if self._actor_task and self._actor_task.done():
raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
fut: asyncio.Future[McpFuture] = asyncio.Future()
if type in {"list_tools", "shutdown"}:
await self._command_queue.put({"type": type, "future": fut})
res = await fut
elif type == "call_tool":
if args is None:
raise ValueError("args is required for call_tool")
name = args.get("name", None)
kwargs = args.get("kargs", {})
if name is None:
raise ValueError("name is required for call_tool")
await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut})
res = await fut
else:
raise ValueError(f"Unknown command type: {type}")
return res

async def close(self) -> None:
if not self._active or self._actor_task is None:
return
self._shutdown_future = asyncio.Future()
await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
await self._shutdown_future
await self._actor_task
self._active = False

async def _run_actor(self) -> None:
result: McpResult
try:
async with create_mcp_server_session(self.server_params) as session:
await session.initialize()
while True:
cmd = await self._command_queue.get()
if cmd["type"] == "shutdown":
cmd["future"].set_result("ok")
break
elif cmd["type"] == "call_tool":
try:
result = session.call_tool(name=cmd["name"], arguments=cmd["args"])
cmd["future"].set_result(result)
except Exception as e:
cmd["future"].set_exception(e)
elif cmd["type"] == "list_tools":
try:
result = session.list_tools()
cmd["future"].set_result(result)
except Exception as e:
cmd["future"].set_exception(e)
except Exception as e:
if self._shutdown_future and not self._shutdown_future.done():
self._shutdown_future.set_exception(e)
finally:
self._active = False
self._actor_task = None

def _sync_shutdown(self) -> None:
if not self._active or self._actor_task is None:
return
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# No loop available — interpreter is likely shutting down
return

if loop.is_closed():
return

if loop.is_running():
loop.create_task(self.close())
else:
loop.run_until_complete(self.close())

def _to_config(self) -> McpSessionActorConfig:
"""
Convert the adapter to its configuration representation.

Returns:
McpSessionConfig: The configuration of the adapter.
"""
return McpSessionActorConfig(server_params=self.server_params)

@classmethod
def _from_config(cls, config: McpSessionActorConfig) -> Self:
"""
Create an instance of McpSessionActor from its configuration.

Args:
config (McpSessionConfig): The configuration of the adapter.

Returns:
McpSessionActor: An instance of SseMcpToolAdapter.
"""
return cls(server_params=config.server_params)
98 changes: 42 additions & 56 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
import builtins
from datetime import timedelta
from typing import Any, Dict, List, Literal, Mapping, Optional
from typing import Any, List, Literal, Mapping

from autogen_core import CancellationToken, Component, ComponentModel, Image
from autogen_core import CancellationToken, Component, Image
from autogen_core.tools import (
ImageResultContent,
ParametersSchema,
Expand All @@ -12,13 +10,11 @@
ToolSchema,
WorkBench,
)
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import EmbeddedResource, ImageContent, TextContent
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent
from pydantic import BaseModel
from typing_extensions import Annotated, Self
from typing_extensions import Self

from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams


Expand All @@ -38,15 +34,24 @@ class McpWorkBench(WorkBench, Component[McpWorkBenchConfig]):

def __init__(self, server_params: McpServerParams) -> None:
self._server_params = server_params
self._session: ClientSession | None = None
# self._session: ClientSession | None = None
self._actor: McpSessionActor | None = None
self._read = None
self._write = None

async def list_tools(self) -> List[ToolSchema]:
if not self._session:
raise RuntimeError("Session is not initialized. Call start() first.")
list_tool_result = await self._session.list_tools()
schema = []
if not self._actor:
await self.start() # fallback to start the actor if not initialized instead of raising an error
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
# raise RuntimeError("Actor is not initialized. Call start() first.")
if self._actor is None:
raise RuntimeError("Actor is not initialized. Please check the server connection.")
result_future = await self._actor.call("list_tools", None)
list_tool_result = await result_future
assert isinstance(
list_tool_result, ListToolsResult
), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}"
schema: List[ToolSchema] = []
for tool in list_tool_result.tools:
name = tool.name
description = tool.description or ""
Expand All @@ -67,16 +72,23 @@ async def list_tools(self) -> List[ToolSchema]:
async def call_tool(
self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
) -> ToolResult:
if not self._session:
raise RuntimeError("Session is not initialized. Call start() first.")
if not self._actor:
await self.start() # fallback to start the actor if not initialized instead of raising an error
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
# raise RuntimeError("Actor is not initialized. Call start() first.")
if self._actor is None:
raise RuntimeError("Actor is not initialized. Please check the server connection.")
if not cancellation_token:
cancellation_token = CancellationToken()
if not arguments:
arguments = {}
try:
result_future = asyncio.ensure_future(self._session.call_tool(name=name, arguments=dict(arguments)))
result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments})
cancellation_token.link_future(result_future)
result = await result_future
assert isinstance(
result, CallToolResult
), f"call_tool must return a CallToolResult, instead of : {str(type(result))}"
result_parts: List[TextResultContent | ImageResultContent] = []
is_error = result.isError
for content in result.content:
Expand Down Expand Up @@ -110,50 +122,24 @@ def _format_errors(self, error: Exception) -> str:
return error_message

async def start(self) -> None:
if self._session:
raise RuntimeError("Session is already initialized. Call stop() first.")

if isinstance(self._server_params, StdioServerParams):
read, write = await stdio_client(self._server_params).__aenter__()
self._read = read
self._write = write
session = await ClientSession(
read_stream=read,
write_stream=write,
read_timeout_seconds=timedelta(seconds=self._server_params.read_timeout_seconds),
).__aenter__()
self._session = session
elif isinstance(self._server_params, SseServerParams):
read, write = await sse_client(**self._server_params.model_dump()).__aenter__()
self._read = read
self._write = write
session = await ClientSession(read_stream=read, write_stream=write).__aenter__()
self._session = session
if self._actor:
return # Already initialized, no need to start again
raise RuntimeError("Actor is already initialized. Call stop() first.")

if isinstance(self._server_params, (StdioServerParams, SseServerParams)):
self._actor = McpSessionActor(self._server_params)
await self._actor.initialize()
else:
raise ValueError(f"Unsupported server params type: {type(self._server_params)}")

async def stop(self) -> None:
if self._session:
# Close the session and streams in reverse order
await self._session.__aexit__(None, None, None)
self._session = None

# If streams exist, close them
if hasattr(self, "_write") and self._write:
# Determine the context manager based on the server params type
if isinstance(self._server_params, StdioServerParams):
cm = stdio_client(self._server_params)
elif isinstance(self._server_params, SseServerParams):
cm = sse_client(**self._server_params.model_dump())
else:
raise ValueError(f"Unsupported server params type: {type(self._server_params)}")

# Exit the context manager to properly close streams
await cm.__aexit__(None, None, None)
self._read = None
self._write = None
if self._actor:
# Close the actor
# await self._session.__aexit__(None, None, None)
await self._actor.close()
self._actor = None
else:
raise RuntimeError("Session is not initialized. Call start() first.")
raise RuntimeError("Actor is not initialized. Call start() first.")

async def reset(self) -> None:
pass
Expand Down
41 changes: 41 additions & 0 deletions python/packages/autogen-ext/tests/tools/test_mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StdioServerParams,
create_mcp_server_session,
mcp_server_tools,
McpWorkBench,
)
from mcp import ClientSession, Tool

Expand Down Expand Up @@ -422,3 +423,43 @@ async def test_mcp_server_github() -> None:
{"owner": "microsoft", "repo": "autogen", "path": "python", "branch": "main"}, CancellationToken()
)
assert result is not None


@pytest.mark.asyncio
async def test_mcp_workbench_start_stop():
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)

workbench = McpWorkBench(params)
assert workbench is not None
assert workbench._server_params == params
await workbench.start()
assert workbench._actor is not None
await workbench.stop()
assert workbench._actor is None


@pytest.mark.asyncio
async def test_mcp_workbench_server_fetch():
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)

workbench = McpWorkBench(server_params=params)
await workbench.start()

tools = await workbench.list_tools()
assert tools is not None
assert tools[0]["name"] == "fetch"

result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}, CancellationToken())
assert result is not None

await workbench.stop()