|
| 1 | +import asyncio |
| 2 | +import threading |
| 3 | +from mcp import ClientSession, StdioServerParameters |
| 4 | +from mcp.client.stdio import stdio_client |
| 5 | +from crewai.tools.base_tool import BaseTool |
| 6 | +from pydantic import BaseModel, Field, create_model, ConfigDict |
| 7 | +from typing import Type, Dict, Any, Union |
| 8 | +from contextlib import AsyncExitStack |
| 9 | + |
| 10 | +def create_pydantic_model_from_dict(model_name: str, schema_dict: Dict[str, Any]) -> Type[BaseModel]: |
| 11 | + fields = {} |
| 12 | + type_mapping = { |
| 13 | + 'string': str, |
| 14 | + 'number': float, |
| 15 | + 'integer': int, |
| 16 | + 'boolean': bool, |
| 17 | + 'object': dict, |
| 18 | + 'array': list, |
| 19 | + } |
| 20 | + properties = schema_dict.get('properties', {}) |
| 21 | + required_fields = schema_dict.get('required', []) |
| 22 | + for field_name, field_info in properties.items(): |
| 23 | + json_type = field_info.get('type', 'string') |
| 24 | + python_type = type_mapping.get(json_type, Any) |
| 25 | + description = field_info.get('description', '') |
| 26 | + default = field_info.get('default', ...) |
| 27 | + is_required = field_name in required_fields |
| 28 | + if not is_required: |
| 29 | + python_type = Union[python_type, None] |
| 30 | + default = None if default is ... else default |
| 31 | + field = (python_type, Field(default, description=description)) |
| 32 | + fields[field_name] = field |
| 33 | + model = create_model(model_name, **fields) |
| 34 | + return model |
| 35 | + |
| 36 | +class AsyncioEventLoopThread(threading.Thread): |
| 37 | + def __init__(self): |
| 38 | + super().__init__() |
| 39 | + self.loop = asyncio.new_event_loop() |
| 40 | + self._stop_event = threading.Event() |
| 41 | + def run(self): |
| 42 | + asyncio.set_event_loop(self.loop) |
| 43 | + self.loop.run_forever() |
| 44 | + def stop(self): |
| 45 | + self.loop.call_soon_threadsafe(self.loop.stop) |
| 46 | + self._stop_event.set() |
| 47 | + def schedule_coroutine(self, coro): |
| 48 | + return asyncio.run_coroutine_threadsafe(coro, self.loop) |
| 49 | + |
| 50 | +class MCPClient: |
| 51 | + def __init__(self, server_params: StdioServerParameters, loop_thread: AsyncioEventLoopThread): |
| 52 | + self.server_params = server_params |
| 53 | + self.loop_thread = loop_thread |
| 54 | + self.initialized = False |
| 55 | + self.client_session = None |
| 56 | + self.read = None |
| 57 | + self.write = None |
| 58 | + self._init_future = None |
| 59 | + self._exit_stack = AsyncExitStack() |
| 60 | + self._init_future = self.loop_thread.schedule_coroutine(self._async_init()) |
| 61 | + async def _async_init(self): |
| 62 | + await self._exit_stack.__aenter__() |
| 63 | + self.stdio_client = stdio_client(self.server_params) |
| 64 | + self.read, self.write = await self._exit_stack.enter_async_context(self.stdio_client) |
| 65 | + self.client_session = ClientSession(self.read, self.write) |
| 66 | + await self._exit_stack.enter_async_context(self.client_session) |
| 67 | + await self.client_session.initialize() |
| 68 | + self.initialized = True |
| 69 | + def call_tool(self, tool_name: str, tool_input: dict = None): |
| 70 | + future = self.loop_thread.schedule_coroutine(self._call_tool_async(tool_name, tool_input)) |
| 71 | + return future.result() |
| 72 | + async def _call_tool_async(self, tool_name: str, tool_input: dict = None): |
| 73 | + if not self.initialized: |
| 74 | + await asyncio.wrap_future(self._init_future) |
| 75 | + return await self.client_session.call_tool(tool_name, tool_input) |
| 76 | + def close(self): |
| 77 | + future = self.loop_thread.schedule_coroutine(self._async_close()) |
| 78 | + future.result() |
| 79 | + async def _async_close(self): |
| 80 | + await self._exit_stack.aclose() |
| 81 | + self.initialized = False |
| 82 | + |
| 83 | +class MCPTool(BaseTool): |
| 84 | + name: str |
| 85 | + description: str |
| 86 | + args_schema: Type[BaseModel] |
| 87 | + client: 'MCPClient' |
| 88 | + def __init__(self, name: str, description: str, args_schema: Type[BaseModel], client: 'MCPClient'): |
| 89 | + self.name = name |
| 90 | + self.description = description |
| 91 | + self.args_schema = args_schema |
| 92 | + self.client = client |
| 93 | + def _run(self, **kwargs): |
| 94 | + validated_inputs = self.args_schema(**kwargs) |
| 95 | + result = self.client.call_tool(self.name, validated_inputs.dict()) |
| 96 | + return result |
| 97 | + model_config = ConfigDict(arbitrary_types_allowed=True) |
| 98 | + |
| 99 | +def initialise_tools_sync(client: MCPClient): |
| 100 | + future = client.loop_thread.schedule_coroutine(initialise_tools(client)) |
| 101 | + return future.result() |
| 102 | + |
| 103 | +async def initialise_tools(client: MCPClient): |
| 104 | + if not client.initialized: |
| 105 | + await asyncio.wrap_future(client._init_future) |
| 106 | + tools_list = await client.client_session.list_tools() |
| 107 | + available_tools = [tool.model_dump() for tool in tools_list] |
| 108 | + mcp_tools = [] |
| 109 | + for tool in available_tools: |
| 110 | + mcp_tools.append(MCPTool( |
| 111 | + name=tool['name'], |
| 112 | + description=tool['description'], |
| 113 | + args_schema=create_pydantic_model_from_dict(f"{tool['name']}Input", tool['inputSchema']), |
| 114 | + client=client |
| 115 | + )) |
| 116 | + return mcp_tools |
| 117 | + |
| 118 | +class MCPStdioServerParams(StdioServerParameters): |
| 119 | + command: str |
| 120 | + args: list[str] = [] |
| 121 | + env: dict[str, str] = None |
| 122 | + |
| 123 | +def get_persistent_mcp_client(params: StdioServerParameters): |
| 124 | + loop_thread = AsyncioEventLoopThread() |
| 125 | + loop_thread.start() |
| 126 | + client = MCPClient(params, loop_thread) |
| 127 | + return client, loop_thread |
| 128 | + |
| 129 | +if __name__ == "__main__": |
| 130 | + params = MCPStdioServerParams( |
| 131 | + command="/opt/homebrew/bin/npx", |
| 132 | + args=["-y", "@modelcontextprotocol/server-filesystem", "/Users/burnerlee/Projects/dashwave/nucleon"] |
| 133 | + ) |
| 134 | + client, loop_thread = get_persistent_mcp_client(params) |
| 135 | + try: |
| 136 | + tools = initialise_tools_sync(client) |
| 137 | + selected_tool = tools[0] |
| 138 | + result = selected_tool._run( |
| 139 | + path="/Users/burnerlee/Projects/dashwave/nucleon/README.md" |
| 140 | + ) |
| 141 | + print(result) |
| 142 | + finally: |
| 143 | + client.close() |
| 144 | + loop_thread.stop() |
0 commit comments