@@ -104,6 +104,14 @@ async def init_tool_sessions(
104104 ) -> None :
105105 pass
106106
107+ @abstractmethod
108+ async def __aenter__ (self ):
109+ pass
110+
111+ @abstractmethod
112+ async def __aexit__ (self , exc_type , exc , tb ):
113+ pass
114+
107115 @abstractmethod
108116 async def cleanup_session (self ) -> None :
109117 raise NotImplementedError ("Should not be called." )
@@ -146,6 +154,12 @@ async def init_tool_sessions(
146154 ) -> None :
147155 pass
148156
157+ async def __aenter__ (self ):
158+ return self
159+
160+ async def __aexit__ (self , exc_type , exc , tb ):
161+ pass
162+
149163 async def cleanup_session (self ) -> None :
150164 raise NotImplementedError ("Should not be called." )
151165
@@ -155,12 +169,17 @@ def __init__(
155169 self ,
156170 messages : list ,
157171 available_tools : list [str ],
172+ tool_server : Optional [ToolServer ],
158173 ):
159174 self ._messages = messages
160175 self .finish_reason : str | None = None
161176 self .available_tools = available_tools
162177 self ._tool_sessions : dict [str , ClientSession | Tool ] = {}
163178 self .called_tools : set [str ] = set ()
179+ self ._tool_server = tool_server
180+ self ._async_exit_stack : Optional [AsyncExitStack ] = None
181+ self ._reference_count = 0
182+ self ._reference_count_lock = asyncio .Lock ()
164183
165184 self .parser = get_streamable_parser_for_assistant ()
166185 self .num_init_messages = len (messages )
@@ -309,6 +328,18 @@ def need_builtin_tool_call(self) -> bool:
309328 or recipient .startswith ("container." )
310329 )
311330
331+ async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
332+ if tool_name not in self ._tool_sessions and self ._tool_server is not None :
333+ assert self ._async_exit_stack is not None , (
334+ "Async exit stack not set. Please report this issue."
335+ )
336+ self ._tool_sessions [
337+ tool_name
338+ ] = await self ._async_exit_stack .enter_async_context (
339+ self ._tool_server .new_session (tool_name )
340+ )
341+ return self ._tool_sessions [tool_name ]
342+
312343 async def call_tool (self ) -> list [Message ]:
313344 if not self .messages :
314345 return []
@@ -317,15 +348,15 @@ async def call_tool(self) -> list[Message]:
317348 if recipient is not None :
318349 if recipient .startswith ("browser." ):
319350 return await self .call_search_tool (
320- self ._tool_sessions [ "browser" ] , last_msg
351+ await self ._get_tool_session ( "browser" ) , last_msg
321352 )
322353 elif recipient .startswith ("python" ):
323354 return await self .call_python_tool (
324- self ._tool_sessions [ "python" ] , last_msg
355+ await self ._get_tool_session ( "python" ) , last_msg
325356 )
326357 elif recipient .startswith ("container." ):
327358 return await self .call_container_tool (
328- self ._tool_sessions [ "container" ] , last_msg
359+ await self ._get_tool_session ( "container" ) , last_msg
329360 )
330361 raise ValueError ("No tool call found" )
331362
@@ -452,6 +483,25 @@ async def cleanup_tool_session(tool_session):
452483 )
453484 )
454485
486+ async def __aenter__ (self ):
487+ async with self ._reference_count_lock :
488+ self ._reference_count += 1
489+ if self ._async_exit_stack is None :
490+ assert self ._reference_count == 1 , (
491+ "Reference count of exit stack should be "
492+ )
493+ "1 when initializing exit stack."
494+ self ._async_exit_stack = AsyncExitStack ()
495+ await self ._async_exit_stack .__aenter__ ()
496+ return self
497+
498+ async def __aexit__ (self , exc_type , exc , tb ):
499+ async with self ._reference_count_lock :
500+ self ._reference_count -= 1
501+ if self ._reference_count == 0 and self ._async_exit_stack is not None :
502+ await self ._async_exit_stack .__aexit__ (exc_type , exc , tb )
503+ self ._async_exit_stack = None
504+
455505
456506class StreamingHarmonyContext (HarmonyContext ):
457507 def __init__ (self , * args , ** kwargs ):
0 commit comments