@@ -79,16 +79,6 @@ def need_builtin_tool_call(self) -> bool:
7979 def render_for_completion (self ) -> list [int ]:
8080 pass
8181
82- @abstractmethod
83- async def init_tool_sessions (
84- self ,
85- tool_server : Optional [ToolServer ],
86- exit_stack : AsyncExitStack ,
87- request_id : str ,
88- mcp_tools : dict [str , Mcp ],
89- ) -> None :
90- pass
91-
9282 @abstractmethod
9383 async def __aenter__ (self ):
9484 pass
@@ -128,15 +118,6 @@ async def call_tool(self) -> list[Message]:
128118 def render_for_completion (self ) -> list [int ]:
129119 raise NotImplementedError ("Should not be called." )
130120
131- async def init_tool_sessions (
132- self ,
133- tool_server : Optional [ToolServer ],
134- exit_stack : AsyncExitStack ,
135- request_id : str ,
136- mcp_tools : dict [str , Mcp ],
137- ) -> None :
138- pass
139-
140121 async def __aenter__ (self ):
141122 return self
142123
@@ -153,13 +134,17 @@ def __init__(
153134 messages : list ,
154135 available_tools : list [str ],
155136 tool_server : Optional [ToolServer ],
137+ request_id : str ,
138+ mcp_tools : dict [str , Mcp ],
156139 ):
157140 self ._messages = messages
158141 self .finish_reason : Optional [str ] = None
159142 self .available_tools = available_tools
160143 self ._tool_sessions : dict [str , Union [ClientSession , Tool ]] = {}
161144 self .called_tools : set [str ] = set ()
162145 self ._tool_server = tool_server
146+ self .request_id = request_id
147+ self .mcp_tools = mcp_tools
163148 self ._async_exit_stack : Optional [AsyncExitStack ] = None
164149 self ._reference_count = 0
165150 self ._reference_count_lock = asyncio .Lock ()
@@ -307,18 +292,6 @@ def need_builtin_tool_call(self) -> bool:
307292 or recipient .startswith ("container." )
308293 )
309294
310- async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
311- if tool_name not in self ._tool_sessions and self ._tool_server is not None :
312- assert self ._async_exit_stack is not None , (
313- "Async exit stack not set. Please report this issue."
314- )
315- self ._tool_sessions [
316- tool_name
317- ] = await self ._async_exit_stack .enter_async_context (
318- self ._tool_server .new_session (tool_name )
319- )
320- return self ._tool_sessions [tool_name ]
321-
322295 async def call_tool (self ) -> list [Message ]:
323296 if not self .messages :
324297 return []
@@ -342,6 +315,24 @@ async def call_tool(self) -> list[Message]:
342315 def render_for_completion (self ) -> list [int ]:
343316 return render_for_completion (self .messages )
344317
318+ async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
319+ if tool_name not in self ._tool_sessions and self ._tool_server is not None :
320+ assert self ._async_exit_stack is not None , (
321+ "Async exit stack not set. Please report this issue."
322+ )
323+ tool_type = _map_tool_name_to_tool_type (tool_name )
324+ headers = (
325+ self .mcp_tools [tool_type ].headers
326+ if tool_type in self .mcp_tools
327+ else None
328+ )
329+ self ._tool_sessions [
330+ tool_name
331+ ] = await self ._async_exit_stack .enter_async_context (
332+ self ._tool_server .new_session (tool_name , self .request_id , headers )
333+ )
334+ return self ._tool_sessions [tool_name ]
335+
345336 async def call_search_tool (
346337 self , tool_session : Union ["ClientSession" , Tool ], last_msg : Message
347338 ) -> list [Message ]:
@@ -387,26 +378,6 @@ async def call_python_tool(
387378 )
388379 ]
389380
390- async def init_tool_sessions (
391- self ,
392- tool_server : Optional [ToolServer ],
393- exit_stack : AsyncExitStack ,
394- request_id : str ,
395- mcp_tools : dict [str , Mcp ],
396- ):
397- if tool_server :
398- for tool_name in self .available_tools :
399- if tool_name not in self ._tool_sessions :
400- tool_type = _map_tool_name_to_tool_type (tool_name )
401- headers = (
402- mcp_tools [tool_type ].headers if tool_type in mcp_tools else None
403- )
404- tool_session = await exit_stack .enter_async_context (
405- tool_server .new_session (tool_name , request_id , headers )
406- )
407- self ._tool_sessions [tool_name ] = tool_session
408- exit_stack .push_async_exit (self .cleanup_session )
409-
410381 async def call_container_tool (
411382 self , tool_session : Union ["ClientSession" , Tool ], last_msg : Message
412383 ) -> list [Message ]:
@@ -468,10 +439,11 @@ async def __aenter__(self):
468439 if self ._async_exit_stack is None :
469440 assert self ._reference_count == 1 , (
470441 "Reference count of exit stack should be "
442+ "1 when initializing exit stack."
471443 )
472- "1 when initializing exit stack."
473444 self ._async_exit_stack = AsyncExitStack ()
474445 await self ._async_exit_stack .__aenter__ ()
446+ self ._async_exit_stack .push_async_callback (self .cleanup_session )
475447 return self
476448
477449 async def __aexit__ (self , exc_type , exc , tb ):
0 commit comments