@@ -89,6 +89,14 @@ async def init_tool_sessions(
8989 ) -> None :
9090 pass
9191
92+ @abstractmethod
93+ async def __aenter__ (self ):
94+ pass
95+
96+ @abstractmethod
97+ async def __aexit__ (self , exc_type , exc , tb ):
98+ pass
99+
92100 @abstractmethod
93101 async def cleanup_session (self ) -> None :
94102 raise NotImplementedError ("Should not be called." )
@@ -129,6 +137,12 @@ async def init_tool_sessions(
129137 ) -> None :
130138 pass
131139
140+ async def __aenter__ (self ):
141+ return self
142+
143+ async def __aexit__ (self , exc_type , exc , tb ):
144+ pass
145+
132146 async def cleanup_session (self ) -> None :
133147 raise NotImplementedError ("Should not be called." )
134148
@@ -138,12 +152,17 @@ def __init__(
138152 self ,
139153 messages : list ,
140154 available_tools : list [str ],
155+ tool_server : Optional [ToolServer ],
141156 ):
142157 self ._messages = messages
143158 self .finish_reason : Optional [str ] = None
144159 self .available_tools = available_tools
145160 self ._tool_sessions : dict [str , Union [ClientSession , Tool ]] = {}
146161 self .called_tools : set [str ] = set ()
162+ self ._tool_server = tool_server
163+ self ._async_exit_stack : Optional [AsyncExitStack ] = None
164+ self ._reference_count = 0
165+ self ._reference_count_lock = asyncio .Lock ()
147166
148167 self .parser = get_streamable_parser_for_assistant ()
149168 self .num_init_messages = len (messages )
@@ -288,6 +307,18 @@ def need_builtin_tool_call(self) -> bool:
288307 or recipient .startswith ("container." )
289308 )
290309
310+ async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
311+ if tool_name not in self ._tool_sessions and self ._tool_server is not None :
312+ assert self ._async_exit_stack is not None , (
313+ "Async exit stack not set. Please report this issue."
314+ )
315+ self ._tool_sessions [
316+ tool_name
317+ ] = await self ._async_exit_stack .enter_async_context (
318+ self ._tool_server .new_session (tool_name )
319+ )
320+ return self ._tool_sessions [tool_name ]
321+
291322 async def call_tool (self ) -> list [Message ]:
292323 if not self .messages :
293324 return []
@@ -296,15 +327,15 @@ async def call_tool(self) -> list[Message]:
296327 if recipient is not None :
297328 if recipient .startswith ("browser." ):
298329 return await self .call_search_tool (
299- self ._tool_sessions [ "browser" ] , last_msg
330+ await self ._get_tool_session ( "browser" ) , last_msg
300331 )
301332 elif recipient .startswith ("python" ):
302333 return await self .call_python_tool (
303- self ._tool_sessions [ "python" ] , last_msg
334+ await self ._get_tool_session ( "python" ) , last_msg
304335 )
305336 elif recipient .startswith ("container." ):
306337 return await self .call_container_tool (
307- self ._tool_sessions [ "container" ] , last_msg
338+ await self ._get_tool_session ( "container" ) , last_msg
308339 )
309340 raise ValueError ("No tool call found" )
310341
@@ -431,6 +462,25 @@ async def cleanup_tool_session(tool_session):
431462 )
432463 )
433464
465+ async def __aenter__ (self ):
466+ async with self ._reference_count_lock :
467+ self ._reference_count += 1
468+ if self ._async_exit_stack is None :
469+ assert self ._reference_count == 1 , (
470+ "Reference count of exit stack should be "
471+ )
472+ "1 when initializing exit stack."
473+ self ._async_exit_stack = AsyncExitStack ()
474+ await self ._async_exit_stack .__aenter__ ()
475+ return self
476+
477+ async def __aexit__ (self , exc_type , exc , tb ):
478+ async with self ._reference_count_lock :
479+ self ._reference_count -= 1
480+ if self ._reference_count == 0 and self ._async_exit_stack is not None :
481+ await self ._async_exit_stack .__aexit__ (exc_type , exc , tb )
482+ self ._async_exit_stack = None
483+
434484
435485class StreamingHarmonyContext (HarmonyContext ):
436486 def __init__ (self , * args , ** kwargs ):
0 commit comments