Skip to content

Commit 9354660

Browse files
authored
[Bugfix] Fix Qwen3 XML tool parser (#26345)
Signed-off-by: Zhikaiiii <1658973216@qq.com>
1 parent 07ca70a commit 9354660

File tree

2 files changed

+178
-25
lines changed

2 files changed

+178
-25
lines changed

tests/tool_use/test_qwen3coder_tool_parser.py

Lines changed: 86 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
4040
return Qwen3XMLToolParser(qwen3_tokenizer)
4141

4242

43-
@pytest.fixture(params=["original", "xml"])
43+
@pytest.fixture(params=["xml"])
4444
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
4545
"""Parameterized fixture that provides the selected parser type for testing (only "xml" is currently enabled)"""
4646
if request.param == "original":
@@ -664,6 +664,9 @@ def test_extract_tool_calls_streaming(
664664

665665
# Verify we got all expected tool calls
666666
assert len(tool_states) == len(expected_tool_calls)
667+
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
668+
expected_tool_calls
669+
)
667670

668671
# Verify each tool call
669672
for idx, expected_tool in enumerate(expected_tool_calls):
@@ -780,9 +783,10 @@ def test_extract_tool_calls_streaming_missing_closing_tag(
780783

781784
# Verify content was streamed
782785
assert "Let me check the weather for you:" in other_content
783-
784786
# Verify we got the tool call
785787
assert len(tool_states) == 1
788+
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
789+
786790
state = tool_states[0]
787791
assert state["id"] is not None
788792
assert state["type"] == "function"
@@ -892,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote(
892896

893897
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
894898
assert args["obj_param"] == {"key": "value"}
899+
900+
901+
def test_extract_tool_calls_streaming_missing_opening_tag(
902+
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
903+
):
904+
"""Test streaming with missing opening <tool_call> tag
905+
906+
This tests that the streaming parser correctly handles
907+
tool calls that start directly with <function=...>
908+
"""
909+
model_output = """I'll check the weather for you.
910+
911+
<function=get_current_weather>
912+
<parameter=city>
913+
Dallas
914+
</parameter>
915+
<parameter=state>
916+
TX
917+
</parameter>
918+
<parameter=unit>
919+
fahrenheit
920+
</parameter>
921+
</function>
922+
</tool_call>"""
923+
924+
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
925+
926+
other_content = ""
927+
tool_states = {}
928+
929+
for delta_message in stream_delta_message_generator(
930+
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
931+
):
932+
if delta_message.content:
933+
other_content += delta_message.content
934+
935+
if delta_message.tool_calls:
936+
for tool_call in delta_message.tool_calls:
937+
idx = tool_call.index
938+
939+
if idx not in tool_states:
940+
tool_states[idx] = {
941+
"id": None,
942+
"name": None,
943+
"arguments": "",
944+
"type": None,
945+
}
946+
947+
if tool_call.id:
948+
tool_states[idx]["id"] = tool_call.id
949+
950+
if tool_call.type:
951+
assert tool_call.type == "function"
952+
tool_states[idx]["type"] = tool_call.type
953+
954+
if tool_call.function:
955+
if tool_call.function.name:
956+
tool_states[idx]["name"] = tool_call.function.name
957+
958+
if tool_call.function.arguments is not None:
959+
tool_states[idx]["arguments"] += tool_call.function.arguments
960+
961+
# Verify content was streamed
962+
assert "I'll check the weather for you." in other_content
963+
964+
# Verify we got the tool call
965+
assert len(tool_states) == 1
966+
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
967+
968+
state = tool_states[0]
969+
assert state["id"] is not None
970+
assert state["type"] == "function"
971+
assert state["name"] == "get_current_weather"
972+
973+
# Verify arguments were parsed correctly despite missing opening tag
974+
assert state["arguments"] is not None
975+
args = json.loads(state["arguments"])
976+
assert args["city"] == "Dallas"
977+
assert args["state"] == "TX"
978+
assert args["unit"] == "fahrenheit"

vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py

Lines changed: 92 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,13 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import ast
44
import json
5-
import uuid
65
from collections.abc import Sequence
76
from typing import Any
87
from xml.parsers.expat import ParserCreate
98

109
import regex as re
1110

11+
from vllm.entrypoints.chat_utils import make_tool_call_id
1212
from vllm.entrypoints.openai.protocol import (
1313
ChatCompletionRequest,
1414
ChatCompletionToolsParam,
@@ -375,14 +375,21 @@ def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
375375
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
376376
else:
377377
# If currently not parsing tool calls (entering a tool_call),
378-
# check if starts with <tool_call>
378+
# check if starts with <tool_call> or <function=
379379
if self.current_call_id is None:
380380
# Check if might be start of <tool_call>
381381
if buffer == "<tool_call>"[: len(buffer)]:
382382
# Might be start of <tool_call>, wait for more data
383383
return None, start_pos
384+
elif (
385+
buffer.startswith("<function=")
386+
or buffer == "<function="[: len(buffer)]
387+
):
388+
# Might be start of <function=, wait for more data
389+
# to get the complete function tag
390+
return None, start_pos
384391
else:
385-
# Not start of <tool_call>, treat as text
392+
# Not start of <tool_call> or <function=, treat as text
386393
return buffer, start_pos + len(buffer)
387394
else:
388395
# When parsing tool calls,
@@ -621,7 +628,7 @@ def _start_element(self, name: str, attrs: dict[str, str]):
621628
self._auto_close_open_parameter_if_needed("tool_call")
622629

623630
self.parameters = {}
624-
self.current_call_id = self._get_next_call_id()
631+
self.current_call_id = make_tool_call_id()
625632
self.current_param_is_first = True
626633
self.tool_call_index += 1
627634
elif name.startswith("function") or (name == "function"):
@@ -957,10 +964,6 @@ def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
957964
"""Set tool configuration information"""
958965
self.tools = tools
959966

960-
def _get_next_call_id(self):
961-
"""Generate unique call ID"""
962-
return f"call_{uuid.uuid4().hex[:24]}"
963-
964967
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
965968
"""Extract function name from various formats"""
966969
if attrs and "name" in attrs:
@@ -1168,6 +1171,10 @@ def __init__(self, tokenizer: AnyTokenizer):
11681171
super().__init__(tokenizer)
11691172
self.parser = StreamingXMLToolCallParser()
11701173

1174+
# Add missing attributes for compatibility with serving_chat.py
1175+
self.prev_tool_call_arr: list[dict] = []
1176+
self.streamed_args_for_tool: list[str] = []
1177+
11711178
logger.info(
11721179
"vLLM Successfully import tool parser %s !", self.__class__.__name__
11731180
)
@@ -1178,6 +1185,9 @@ def extract_tool_calls(
11781185
request: ChatCompletionRequest,
11791186
) -> ExtractedToolCallInformation:
11801187
self.parser.reset_streaming_state()
1188+
# Reset tool call tracking arrays for new extraction
1189+
self.prev_tool_call_arr = []
1190+
self.streamed_args_for_tool = []
11811191
if request:
11821192
self.parser.set_tools(request.tools)
11831193
result = self.parser.parse_single_streaming_chunks(model_output)
@@ -1201,6 +1211,34 @@ def extract_tool_calls(
12011211
),
12021212
)
12031213
)
1214+
1215+
# Update tool call tracking arrays for compatibility
1216+
tool_index = (
1217+
tool_call.index
1218+
if tool_call.index is not None
1219+
else len(self.prev_tool_call_arr) - 1
1220+
)
1221+
1222+
# Ensure we have enough entries in our tracking arrays
1223+
while len(self.prev_tool_call_arr) <= tool_index:
1224+
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
1225+
while len(self.streamed_args_for_tool) <= tool_index:
1226+
self.streamed_args_for_tool.append("")
1227+
1228+
# Update tool call information
1229+
self.prev_tool_call_arr[tool_index]["name"] = (
1230+
tool_call.function.name
1231+
)
1232+
self.prev_tool_call_arr[tool_index]["arguments"] = (
1233+
tool_call.function.arguments
1234+
)
1235+
1236+
# Update streamed arguments
1237+
if tool_call.function.arguments:
1238+
self.streamed_args_for_tool[tool_index] = (
1239+
tool_call.function.arguments
1240+
)
1241+
12041242
return ExtractedToolCallInformation(
12051243
tool_calls=tool_calls,
12061244
tools_called=len(tool_calls) > 0,
@@ -1219,6 +1257,9 @@ def extract_tool_calls_streaming(
12191257
) -> DeltaMessage | None:
12201258
if not previous_text:
12211259
self.parser.reset_streaming_state()
1260+
# Reset tool call tracking arrays for new streaming session
1261+
self.prev_tool_call_arr = []
1262+
self.streamed_args_for_tool = []
12221263
if request:
12231264
self.parser.set_tools(request.tools)
12241265

@@ -1230,20 +1271,48 @@ def extract_tool_calls_streaming(
12301271
open_calls = current_text.count(
12311272
self.parser.tool_call_start_token
12321273
) - current_text.count(self.parser.tool_call_end_token)
1233-
if open_calls == 0 and self.parser.tool_call_index > 0:
1234-
# If current_call_id is None, use last_completed_call_id
1235-
call_id = (
1236-
self.parser.current_call_id or self.parser.last_completed_call_id
1237-
)
1238-
return DeltaMessage(
1239-
tool_calls=[
1240-
DeltaToolCall(
1241-
index=self.parser.tool_call_index - 1,
1242-
id=call_id,
1243-
function=DeltaFunctionCall(arguments=""),
1244-
type="function",
1274+
if (
1275+
open_calls == 0
1276+
and self.parser.tool_call_index > 0
1277+
or not self.parser.tool_call_index
1278+
and current_text
1279+
):
1280+
return DeltaMessage(content="")
1281+
return None
1282+
1283+
# Parse the delta text and get the result
1284+
result = self.parser.parse_single_streaming_chunks(delta_text)
1285+
1286+
# Update tool call tracking arrays based on incremental parsing results
1287+
if result and result.tool_calls:
1288+
for tool_call in result.tool_calls:
1289+
if tool_call.function:
1290+
tool_index = (
1291+
tool_call.index
1292+
if tool_call.index is not None
1293+
else len(self.prev_tool_call_arr) - 1
1294+
)
1295+
1296+
# Ensure we have enough entries in our tracking arrays
1297+
while len(self.prev_tool_call_arr) <= tool_index:
1298+
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
1299+
while len(self.streamed_args_for_tool) <= tool_index:
1300+
self.streamed_args_for_tool.append("")
1301+
1302+
# Update tool name if provided
1303+
if tool_call.function.name:
1304+
self.prev_tool_call_arr[tool_index]["name"] = (
1305+
tool_call.function.name
12451306
)
1246-
]
1247-
)
12481307

1249-
return self.parser.parse_single_streaming_chunks(delta_text)
1308+
# Update arguments incrementally
1309+
if tool_call.function.arguments is not None:
1310+
# Concatenate the incremental arguments
1311+
# to the existing streamed arguments
1312+
self.prev_tool_call_arr[tool_index]["arguments"] += (
1313+
tool_call.function.arguments
1314+
)
1315+
self.streamed_args_for_tool[tool_index] += (
1316+
tool_call.function.arguments
1317+
)
1318+
return result

0 commit comments

Comments
 (0)