11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ import asyncio
4+ import contextlib
35import json
46import logging
57from abc import ABC , abstractmethod
@@ -57,9 +59,14 @@ def render_for_completion(self) -> list[int]:
5759
5860 @abstractmethod
5961 async def init_tool_sessions (self , tool_server : Optional [ToolServer ],
60- exit_stack : AsyncExitStack ) -> None :
62+ exit_stack : AsyncExitStack ,
63+ request_id : str ) -> None :
6164 pass
6265
66+ @abstractmethod
67+ async def cleanup_session (self ) -> None :
68+ raise NotImplementedError ("Should not be called." )
69+
6370
6471class SimpleContext (ConversationContext ):
6572
@@ -89,9 +96,13 @@ def render_for_completion(self) -> list[int]:
8996 raise NotImplementedError ("Should not be called." )
9097
9198 async def init_tool_sessions (self , tool_server : Optional [ToolServer ],
92- exit_stack : AsyncExitStack ) -> None :
99+ exit_stack : AsyncExitStack ,
100+ request_id : str ) -> None :
93101 pass
94102
103+ async def cleanup_session (self ) -> None :
104+ raise NotImplementedError ("Should not be called." )
105+
95106
96107class HarmonyContext (ConversationContext ):
97108
@@ -103,6 +114,7 @@ def __init__(
103114 self ._messages = messages
104115 self .available_tools = available_tools
105116 self ._tool_sessions : dict [str , Union [ClientSession , Tool ]] = {}
117+ self .called_tools : set [str ] = set ()
106118
107119 self .parser = get_streamable_parser_for_assistant ()
108120 self .num_init_messages = len (messages )
@@ -234,7 +246,8 @@ def need_builtin_tool_call(self) -> bool:
234246 last_msg = self .messages [- 1 ]
235247 recipient = last_msg .recipient
236248 return recipient is not None and (recipient .startswith ("browser." )
237- or recipient .startswith ("python" ))
249+ or recipient .startswith ("python" ) or
250+ recipient .startswith ("container." ))
238251
239252 async def call_tool (self ) -> list [Message ]:
240253 if not self .messages :
@@ -248,6 +261,9 @@ async def call_tool(self) -> list[Message]:
248261 elif recipient .startswith ("python" ):
249262 return await self .call_python_tool (
250263 self ._tool_sessions ["python" ], last_msg )
264+ elif recipient .startswith ("container." ):
265+ return await self .call_container_tool (
266+ self ._tool_sessions ["container" ], last_msg )
251267 raise ValueError ("No tool call found" )
252268
253269 def render_for_completion (self ) -> list [int ]:
@@ -256,6 +272,7 @@ def render_for_completion(self) -> list[int]:
256272 async def call_search_tool (self , tool_session : Union ["ClientSession" ,
257273 Tool ],
258274 last_msg : Message ) -> list [Message ]:
275+ self .called_tools .add ("browser" )
259276 if isinstance (tool_session , Tool ):
260277 return await tool_session .get_result (self )
261278 tool_name = last_msg .recipient .split ("." )[1 ]
@@ -265,12 +282,16 @@ async def call_search_tool(self, tool_session: Union["ClientSession",
265282 content = TextContent (text = result_str )
266283 author = Author (role = Role .TOOL , name = last_msg .recipient )
267284 return [
268- Message (author = author , content = [content ], recipient = Role .ASSISTANT )
285+ Message (author = author ,
286+ content = [content ],
287+ recipient = Role .ASSISTANT ,
288+ channel = last_msg .channel )
269289 ]
270290
271291 async def call_python_tool (self , tool_session : Union ["ClientSession" ,
272292 Tool ],
273293 last_msg : Message ) -> list [Message ]:
294+ self .called_tools .add ("python" )
274295 if isinstance (tool_session , Tool ):
275296 return await tool_session .get_result (self )
276297 param = {
@@ -290,13 +311,63 @@ async def call_python_tool(self, tool_session: Union["ClientSession",
290311 ]
291312
292313 async def init_tool_sessions (self , tool_server : Optional [ToolServer ],
293- exit_stack : AsyncExitStack ) -> None :
314+ exit_stack : AsyncExitStack ,
315+ request_id : str ) -> None :
294316 if tool_server :
295317 for tool_name in self .available_tools :
296318 if tool_name not in self ._tool_sessions :
297- self ._tool_sessions [
298- tool_name ] = await exit_stack .enter_async_context (
299- tool_server .new_session (tool_name ))
319+ tool_session = await exit_stack .enter_async_context (
320+ tool_server .new_session (tool_name , request_id ))
321+ self ._tool_sessions [tool_name ] = tool_session
322+ exit_stack .push_async_exit (self .cleanup_session )
323+
324+ async def call_container_tool (self , tool_session : Union ["ClientSession" ,
325+ Tool ],
326+ last_msg : Message ) -> list [Message ]:
327+ """
328+ Call container tool. Expect this to be run in a stateful docker
329+ with command line terminal.
330+ The official container tool would at least
331+ expect the following format:
332+ - for tool name: exec
333+ - args:
334+ {
335+ "cmd":List[str] "command to execute",
336+ "workdir":optional[str] "current working directory",
337+ "env":optional[object/dict] "environment variables",
338+ "session_name":optional[str] "session name",
339+ "timeout":optional[int] "timeout in seconds",
340+ "user":optional[str] "user name",
341+ }
342+ """
343+ self .called_tools .add ("container" )
344+ if isinstance (tool_session , Tool ):
345+ return await tool_session .get_result (self )
346+ tool_name = last_msg .recipient .split ("." )[1 ].split (" " )[0 ]
347+ args = json .loads (last_msg .content [0 ].text )
348+ result = await tool_session .call_tool (tool_name , args )
349+ result_str = result .content [0 ].text
350+ content = TextContent (text = result_str )
351+ author = Author (role = Role .TOOL , name = last_msg .recipient )
352+ return [
353+ Message (author = author ,
354+ content = [content ],
355+ recipient = Role .ASSISTANT ,
356+ channel = last_msg .channel )
357+ ]
358+
359+ async def cleanup_session (self , * args , ** kwargs ) -> None :
360+ """Can be used as coro to used in __aexit__"""
361+
362+ async def cleanup_tool_session (tool_session ):
363+ if not isinstance (tool_session , Tool ):
364+ logger .info ("Cleaning up tool session for %s" ,
365+ tool_session ._client_info )
366+ with contextlib .suppress (Exception ):
367+ await tool_session .call_tool ("cleanup_session" , {})
368+
369+ await asyncio .gather (* (cleanup_tool_session (self ._tool_sessions [tool ])
370+ for tool in self .called_tools ))
300371
301372
302373class StreamingHarmonyContext (HarmonyContext ):
0 commit comments