22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import ast
44import json
5- import uuid
65from collections .abc import Sequence
76from typing import Any
87from xml .parsers .expat import ParserCreate
98
109import regex as re
1110
11+ from vllm .entrypoints .chat_utils import make_tool_call_id
1212from vllm .entrypoints .openai .protocol import (
1313 ChatCompletionRequest ,
1414 ChatCompletionToolsParam ,
@@ -375,14 +375,21 @@ def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
375375 return buffer [: tag_end2 + 1 ], start_pos + tag_end2 + 1
376376 else :
377377 # If currently not parsing tool calls (entering a tool_call),
378- # check if starts with <tool_call>
378+ # check if starts with <tool_call> or <function=
379379 if self .current_call_id is None :
380380 # Check if might be start of <tool_call>
381381 if buffer == "<tool_call>" [: len (buffer )]:
382382 # Might be start of <tool_call>, wait for more data
383383 return None , start_pos
384+ elif (
385+ buffer .startswith ("<function=" )
386+ or buffer == "<function=" [: len (buffer )]
387+ ):
388+ # Might be start of <function=, wait for more data
389+ # to get the complete function tag
390+ return None , start_pos
384391 else :
385- # Not start of <tool_call>, treat as text
392+ # Not start of <tool_call> or <function= , treat as text
386393 return buffer , start_pos + len (buffer )
387394 else :
388395 # When parsing tool calls,
@@ -621,7 +628,7 @@ def _start_element(self, name: str, attrs: dict[str, str]):
621628 self ._auto_close_open_parameter_if_needed ("tool_call" )
622629
623630 self .parameters = {}
624- self .current_call_id = self . _get_next_call_id ()
631+ self .current_call_id = make_tool_call_id ()
625632 self .current_param_is_first = True
626633 self .tool_call_index += 1
627634 elif name .startswith ("function" ) or (name == "function" ):
@@ -957,10 +964,6 @@ def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
957964 """Set tool configuration information"""
958965 self .tools = tools
959966
960- def _get_next_call_id (self ):
961- """Generate unique call ID"""
962- return f"call_{ uuid .uuid4 ().hex [:24 ]} "
963-
964967 def _extract_function_name (self , name : str , attrs : dict [str , str ]) -> str | None :
965968 """Extract function name from various formats"""
966969 if attrs and "name" in attrs :
@@ -1168,6 +1171,10 @@ def __init__(self, tokenizer: AnyTokenizer):
11681171 super ().__init__ (tokenizer )
11691172 self .parser = StreamingXMLToolCallParser ()
11701173
1174+ # Add missing attributes for compatibility with serving_chat.py
1175+ self .prev_tool_call_arr : list [dict ] = []
1176+ self .streamed_args_for_tool : list [str ] = []
1177+
11711178 logger .info (
11721179 "vLLM Successfully import tool parser %s !" , self .__class__ .__name__
11731180 )
@@ -1178,6 +1185,9 @@ def extract_tool_calls(
11781185 request : ChatCompletionRequest ,
11791186 ) -> ExtractedToolCallInformation :
11801187 self .parser .reset_streaming_state ()
1188+ # Reset tool call tracking arrays for new extraction
1189+ self .prev_tool_call_arr = []
1190+ self .streamed_args_for_tool = []
11811191 if request :
11821192 self .parser .set_tools (request .tools )
11831193 result = self .parser .parse_single_streaming_chunks (model_output )
@@ -1201,6 +1211,34 @@ def extract_tool_calls(
12011211 ),
12021212 )
12031213 )
1214+
1215+ # Update tool call tracking arrays for compatibility
1216+ tool_index = (
1217+ tool_call .index
1218+ if tool_call .index is not None
1219+ else len (self .prev_tool_call_arr ) - 1
1220+ )
1221+
1222+ # Ensure we have enough entries in our tracking arrays
1223+ while len (self .prev_tool_call_arr ) <= tool_index :
1224+ self .prev_tool_call_arr .append ({"name" : "" , "arguments" : "" })
1225+ while len (self .streamed_args_for_tool ) <= tool_index :
1226+ self .streamed_args_for_tool .append ("" )
1227+
1228+ # Update tool call information
1229+ self .prev_tool_call_arr [tool_index ]["name" ] = (
1230+ tool_call .function .name
1231+ )
1232+ self .prev_tool_call_arr [tool_index ]["arguments" ] = (
1233+ tool_call .function .arguments
1234+ )
1235+
1236+ # Update streamed arguments
1237+ if tool_call .function .arguments :
1238+ self .streamed_args_for_tool [tool_index ] = (
1239+ tool_call .function .arguments
1240+ )
1241+
12041242 return ExtractedToolCallInformation (
12051243 tool_calls = tool_calls ,
12061244 tools_called = len (tool_calls ) > 0 ,
@@ -1219,6 +1257,9 @@ def extract_tool_calls_streaming(
12191257 ) -> DeltaMessage | None :
12201258 if not previous_text :
12211259 self .parser .reset_streaming_state ()
1260+ # Reset tool call tracking arrays for new streaming session
1261+ self .prev_tool_call_arr = []
1262+ self .streamed_args_for_tool = []
12221263 if request :
12231264 self .parser .set_tools (request .tools )
12241265
@@ -1230,20 +1271,48 @@ def extract_tool_calls_streaming(
12301271 open_calls = current_text .count (
12311272 self .parser .tool_call_start_token
12321273 ) - current_text .count (self .parser .tool_call_end_token )
1233- if open_calls == 0 and self .parser .tool_call_index > 0 :
1234- # If current_call_id is None, use last_completed_call_id
1235- call_id = (
1236- self .parser .current_call_id or self .parser .last_completed_call_id
1237- )
1238- return DeltaMessage (
1239- tool_calls = [
1240- DeltaToolCall (
1241- index = self .parser .tool_call_index - 1 ,
1242- id = call_id ,
1243- function = DeltaFunctionCall (arguments = "" ),
1244- type = "function" ,
1274+ if (
1275+ open_calls == 0
1276+ and self .parser .tool_call_index > 0
1277+ or not self .parser .tool_call_index
1278+ and current_text
1279+ ):
1280+ return DeltaMessage (content = "" )
1281+ return None
1282+
1283+ # Parse the delta text and get the result
1284+ result = self .parser .parse_single_streaming_chunks (delta_text )
1285+
1286+ # Update tool call tracking arrays based on incremental parsing results
1287+ if result and result .tool_calls :
1288+ for tool_call in result .tool_calls :
1289+ if tool_call .function :
1290+ tool_index = (
1291+ tool_call .index
1292+ if tool_call .index is not None
1293+ else len (self .prev_tool_call_arr ) - 1
1294+ )
1295+
1296+ # Ensure we have enough entries in our tracking arrays
1297+ while len (self .prev_tool_call_arr ) <= tool_index :
1298+ self .prev_tool_call_arr .append ({"name" : "" , "arguments" : "" })
1299+ while len (self .streamed_args_for_tool ) <= tool_index :
1300+ self .streamed_args_for_tool .append ("" )
1301+
1302+ # Update tool name if provided
1303+ if tool_call .function .name :
1304+ self .prev_tool_call_arr [tool_index ]["name" ] = (
1305+ tool_call .function .name
12451306 )
1246- ]
1247- )
12481307
1249- return self .parser .parse_single_streaming_chunks (delta_text )
1308+ # Update arguments incrementally
1309+ if tool_call .function .arguments is not None :
1310+ # Concatenate the incremental arguments
1311+ # to the existing streamed arguments
1312+ self .prev_tool_call_arr [tool_index ]["arguments" ] += (
1313+ tool_call .function .arguments
1314+ )
1315+ self .streamed_args_for_tool [tool_index ] += (
1316+ tool_call .function .arguments
1317+ )
1318+ return result
0 commit comments