Skip to content

Commit 5d1d8ee

Browse files
committed
call storage.search in user context search instead of memory.search
1 parent f8a8e7b commit 5d1d8ee

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

src/crewai/mcp/client.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import asyncio
2+
import threading
3+
from mcp import ClientSession, StdioServerParameters
4+
from mcp.client.stdio import stdio_client
5+
from crewai.tools.base_tool import BaseTool
6+
from pydantic import BaseModel, Field, create_model, ConfigDict
7+
from typing import Type, Dict, Any, Union
8+
from contextlib import AsyncExitStack
9+
10+
def create_pydantic_model_from_dict(model_name: str, schema_dict: Dict[str, Any]) -> Type[BaseModel]:
11+
fields = {}
12+
type_mapping = {
13+
'string': str,
14+
'number': float,
15+
'integer': int,
16+
'boolean': bool,
17+
'object': dict,
18+
'array': list,
19+
}
20+
properties = schema_dict.get('properties', {})
21+
required_fields = schema_dict.get('required', [])
22+
for field_name, field_info in properties.items():
23+
json_type = field_info.get('type', 'string')
24+
python_type = type_mapping.get(json_type, Any)
25+
description = field_info.get('description', '')
26+
default = field_info.get('default', ...)
27+
is_required = field_name in required_fields
28+
if not is_required:
29+
python_type = Union[python_type, None]
30+
default = None if default is ... else default
31+
field = (python_type, Field(default, description=description))
32+
fields[field_name] = field
33+
model = create_model(model_name, **fields)
34+
return model
35+
36+
class AsyncioEventLoopThread(threading.Thread):
37+
def __init__(self):
38+
super().__init__()
39+
self.loop = asyncio.new_event_loop()
40+
self._stop_event = threading.Event()
41+
def run(self):
42+
asyncio.set_event_loop(self.loop)
43+
self.loop.run_forever()
44+
def stop(self):
45+
self.loop.call_soon_threadsafe(self.loop.stop)
46+
self._stop_event.set()
47+
def schedule_coroutine(self, coro):
48+
return asyncio.run_coroutine_threadsafe(coro, self.loop)
49+
50+
class MCPClient:
51+
def __init__(self, server_params: StdioServerParameters, loop_thread: AsyncioEventLoopThread):
52+
self.server_params = server_params
53+
self.loop_thread = loop_thread
54+
self.initialized = False
55+
self.client_session = None
56+
self.read = None
57+
self.write = None
58+
self._init_future = None
59+
self._exit_stack = AsyncExitStack()
60+
self._init_future = self.loop_thread.schedule_coroutine(self._async_init())
61+
async def _async_init(self):
62+
await self._exit_stack.__aenter__()
63+
self.stdio_client = stdio_client(self.server_params)
64+
self.read, self.write = await self._exit_stack.enter_async_context(self.stdio_client)
65+
self.client_session = ClientSession(self.read, self.write)
66+
await self._exit_stack.enter_async_context(self.client_session)
67+
await self.client_session.initialize()
68+
self.initialized = True
69+
def call_tool(self, tool_name: str, tool_input: dict = None):
70+
future = self.loop_thread.schedule_coroutine(self._call_tool_async(tool_name, tool_input))
71+
return future.result()
72+
async def _call_tool_async(self, tool_name: str, tool_input: dict = None):
73+
if not self.initialized:
74+
await asyncio.wrap_future(self._init_future)
75+
return await self.client_session.call_tool(tool_name, tool_input)
76+
def close(self):
77+
future = self.loop_thread.schedule_coroutine(self._async_close())
78+
future.result()
79+
async def _async_close(self):
80+
await self._exit_stack.aclose()
81+
self.initialized = False
82+
83+
class MCPTool(BaseTool):
84+
name: str
85+
description: str
86+
args_schema: Type[BaseModel]
87+
client: 'MCPClient'
88+
def __init__(self, name: str, description: str, args_schema: Type[BaseModel], client: 'MCPClient'):
89+
self.name = name
90+
self.description = description
91+
self.args_schema = args_schema
92+
self.client = client
93+
def _run(self, **kwargs):
94+
validated_inputs = self.args_schema(**kwargs)
95+
result = self.client.call_tool(self.name, validated_inputs.dict())
96+
return result
97+
model_config = ConfigDict(arbitrary_types_allowed=True)
98+
99+
def initialise_tools_sync(client: MCPClient):
100+
future = client.loop_thread.schedule_coroutine(initialise_tools(client))
101+
return future.result()
102+
103+
async def initialise_tools(client: MCPClient):
104+
if not client.initialized:
105+
await asyncio.wrap_future(client._init_future)
106+
tools_list = await client.client_session.list_tools()
107+
available_tools = [tool.model_dump() for tool in tools_list]
108+
mcp_tools = []
109+
for tool in available_tools:
110+
mcp_tools.append(MCPTool(
111+
name=tool['name'],
112+
description=tool['description'],
113+
args_schema=create_pydantic_model_from_dict(f"{tool['name']}Input", tool['inputSchema']),
114+
client=client
115+
))
116+
return mcp_tools
117+
118+
class MCPStdioServerParams(StdioServerParameters):
119+
command: str
120+
args: list[str] = []
121+
env: dict[str, str] = None
122+
123+
def get_persistent_mcp_client(params: StdioServerParameters):
124+
loop_thread = AsyncioEventLoopThread()
125+
loop_thread.start()
126+
client = MCPClient(params, loop_thread)
127+
return client, loop_thread
128+
129+
if __name__ == "__main__":
130+
params = MCPStdioServerParams(
131+
command="/opt/homebrew/bin/npx",
132+
args=["-y", "@modelcontextprotocol/server-filesystem", "/Users/burnerlee/Projects/dashwave/nucleon"]
133+
)
134+
client, loop_thread = get_persistent_mcp_client(params)
135+
try:
136+
tools = initialise_tools_sync(client)
137+
selected_tool = tools[0]
138+
result = selected_tool._run(
139+
path="/Users/burnerlee/Projects/dashwave/nucleon/README.md"
140+
)
141+
print(result)
142+
finally:
143+
client.close()
144+
loop_thread.stop()

src/crewai/memory/user/user_memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def search(
3737
limit: int = 3,
3838
score_threshold: float = 0.35,
3939
):
40-
results = super().search(
40+
results = self.storage.search(
4141
query=query,
4242
limit=limit,
4343
score_threshold=score_threshold,

0 commit comments

Comments
 (0)