Skip to content

Commit 3567816

Browse files
[Refactor] move tool parsing logic from protocol.py to the tool parser (#27383)
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
1 parent e0ef8a2 commit 3567816

File tree

9 files changed

+131
-75
lines changed

9 files changed

+131
-75
lines changed

tests/tool_use/test_tool_choice_required.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from pydantic import TypeAdapter
1010

1111
from vllm.entrypoints.openai.protocol import (
12-
ChatCompletionRequest,
1312
ChatCompletionToolsParam,
1413
)
1514
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
15+
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
1616

1717
pytestmark = pytest.mark.cpu_test
1818

@@ -67,8 +67,9 @@
6767
def _compile_and_check(
6868
tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
6969
):
70-
self = MagicMock(tool_choice="required", tools=tools)
71-
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
70+
# self = MagicMock(tool_choice="required", tools=tools)
71+
# schema = ChatCompletionRequest._get_json_schema_from_tool(self)
72+
schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
7273
assert isinstance(schema, dict)
7374

7475
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide

vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -854,8 +854,7 @@ def to_sampling_params(
854854
self.structured_outputs = StructuredOutputsParams(**kwargs)
855855

856856
response_format = self.response_format
857-
json_schema_from_tool = self._get_json_schema_from_tool()
858-
if response_format is not None or json_schema_from_tool is not None:
857+
if response_format is not None:
859858
# If structured outputs wasn't already enabled,
860859
# we must enable it for these features to work
861860
if self.structured_outputs is None:
@@ -881,10 +880,6 @@ def to_sampling_params(
881880
s_tag_obj = structural_tag.model_dump(by_alias=True)
882881
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
883882

884-
# Set structured output params for tool calling
885-
if json_schema_from_tool is not None:
886-
self.structured_outputs.json = json_schema_from_tool
887-
888883
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
889884
if self.kv_transfer_params:
890885
# Pass in kv_transfer_params via extra_args
@@ -924,72 +919,6 @@ def to_sampling_params(
924919
extra_args=extra_args or None,
925920
)
926921

927-
def _get_json_schema_from_tool(self) -> str | dict | None:
928-
# user has chosen to not use any tool
929-
if self.tool_choice == "none" or self.tools is None:
930-
return None
931-
932-
# user has chosen to use a named tool
933-
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
934-
tool_name = self.tool_choice.function.name
935-
tools = {tool.function.name: tool.function for tool in self.tools}
936-
if tool_name not in tools:
937-
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
938-
tool = tools[tool_name]
939-
return tool.parameters
940-
941-
if self.tool_choice == "required":
942-
# Pydantic schema generation cannot be used since the JSON schema
943-
# has to be constructed for a specific instantiation of a tool list
944-
# so that parameters of a function are correctly generated
945-
# based on the chosen function name
946-
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
947-
return {
948-
"properties": {
949-
"name": {"type": "string", "enum": [tool.function.name]},
950-
# parameters are always generated as '{}' in the final
951-
# output if they are missing from the request
952-
# (i.e. are None or '{}') so the schema is
953-
# updated to produce an empty object in that case
954-
"parameters": tool.function.parameters
955-
if tool.function.parameters
956-
else {"type": "object", "properties": {}},
957-
},
958-
"required": ["name", "parameters"],
959-
}
960-
961-
def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict:
962-
all_defs = dict[str, dict[str, Any]]()
963-
for tool in tools:
964-
if tool.function.parameters is None:
965-
continue
966-
defs = tool.function.parameters.pop("$defs", {})
967-
for def_name, def_schema in defs.items():
968-
if def_name in all_defs and all_defs[def_name] != def_schema:
969-
raise ValueError(
970-
f"Tool definition '{def_name}' has "
971-
"multiple schemas, which is not "
972-
"supported."
973-
)
974-
else:
975-
all_defs[def_name] = def_schema
976-
return all_defs
977-
978-
json_schema = {
979-
"type": "array",
980-
"minItems": 1,
981-
"items": {
982-
"type": "object",
983-
"anyOf": [get_tool_schema(tool) for tool in self.tools],
984-
},
985-
}
986-
json_schema_defs = get_tool_schema_defs(self.tools)
987-
if json_schema_defs:
988-
json_schema["$defs"] = json_schema_defs
989-
return json_schema
990-
991-
return None
992-
993922
@model_validator(mode="before")
994923
@classmethod
995924
def validate_stream_options(cls, data):

vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
DeltaMessage,
1111
ExtractedToolCallInformation,
1212
)
13+
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
1314
from vllm.logger import init_logger
15+
from vllm.sampling_params import (
16+
StructuredOutputsParams,
17+
)
1418
from vllm.transformers_utils.tokenizer import AnyTokenizer
1519
from vllm.utils.collection_utils import is_list_of
1620
from vllm.utils.import_utils import import_from_path
@@ -44,6 +48,18 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques
4448
"""
4549
Static method that used to adjust the request parameters.
4650
"""
51+
if not request.tools:
52+
return request
53+
json_schema_from_tool = get_json_schema_from_tools(
54+
tool_choice=request.tool_choice, tools=request.tools
55+
)
56+
# Set structured output params for tool calling
57+
if json_schema_from_tool is not None:
58+
if request.structured_outputs is None:
59+
request.structured_outputs = StructuredOutputsParams()
60+
# tool_choice: "Forced Function" or "required" will override
61+
# structured output json settings to make tool calling work correctly
62+
request.structured_outputs.json = json_schema_from_tool
4763
return request
4864

4965
def extract_tool_calls(

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def tool_call_delta_buffer(self, delta_text: str):
112112
return delta_text
113113

114114
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
115+
request = super().adjust_request(request)
115116
if request.tools and request.tool_choice != "none":
116117
# do not skip special tokens because the tool_call tokens are
117118
# marked "special" in some models. Since they are skipped

vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, tokenizer: AnyTokenizer):
3535
self.position = 0
3636

3737
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
38+
request = super().adjust_request(request)
3839
if request.tools and request.tool_choice != "none":
3940
# do not skip special tokens because internlm use the special
4041
# tokens to indicate the start and end of the tool calls

vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(self, tokenizer: AnyTokenizer):
6868
)
6969

7070
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
71+
request = super().adjust_request(request)
7172
if request.tools and request.tool_choice != "none":
7273
# do not skip special tokens because jamba use the special
7374
# tokens to indicate the start and end of the tool calls

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(self, tokenizer: AnyTokenizer):
9494
)
9595

9696
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
97+
request = super().adjust_request(request)
9798
if (
9899
not isinstance(self.model_tokenizer, MistralTokenizer)
99100
and request.tools

vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, tokenizer: AnyTokenizer):
5151
self.tool_block_finished = False
5252

5353
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
54+
request = super().adjust_request(request)
5455
if request.tools and request.tool_choice != "none":
5556
request.skip_special_tokens = False
5657
return request

vllm/entrypoints/openai/tool_parsers/utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66
from typing import Any
77

88
import partial_json_parser
9+
from openai.types.responses import (
10+
FunctionTool,
11+
ToolChoiceFunction,
12+
)
13+
from openai.types.responses.tool import Tool
914
from partial_json_parser.core.options import Allow
1015

16+
from vllm.entrypoints.openai.protocol import (
17+
ChatCompletionNamedToolChoiceParam,
18+
ChatCompletionToolsParam,
19+
)
20+
1121

1222
def find_common_prefix(s1: str, s2: str) -> str:
1323
"""
@@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int:
122132
while i < len(s) and s[i].isspace():
123133
i += 1
124134
return i
135+
136+
137+
def _extract_tool_info(
138+
tool: Tool | ChatCompletionToolsParam,
139+
) -> tuple[str, dict[str, Any] | None]:
140+
if isinstance(tool, FunctionTool):
141+
return tool.name, tool.parameters
142+
elif isinstance(tool, ChatCompletionToolsParam):
143+
return tool.function.name, tool.function.parameters
144+
else:
145+
raise TypeError(f"Unsupported tool type: {type(tool)}")
146+
147+
148+
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
149+
name, params = _extract_tool_info(tool)
150+
params = params if params else {"type": "object", "properties": {}}
151+
return {
152+
"properties": {
153+
"name": {"type": "string", "enum": [name]},
154+
"parameters": params,
155+
},
156+
"required": ["name", "parameters"],
157+
}
158+
159+
160+
def _get_tool_schema_defs(
161+
tools: list[Tool | ChatCompletionToolsParam],
162+
) -> dict:
163+
all_defs: dict[str, dict[str, Any]] = {}
164+
for tool in tools:
165+
_, params = _extract_tool_info(tool)
166+
if params is None:
167+
continue
168+
defs = params.pop("$defs", {})
169+
for def_name, def_schema in defs.items():
170+
if def_name in all_defs and all_defs[def_name] != def_schema:
171+
raise ValueError(
172+
f"Tool definition '{def_name}' has multiple schemas, "
173+
"which is not supported."
174+
)
175+
all_defs[def_name] = def_schema
176+
return all_defs
177+
178+
179+
def _get_json_schema_from_tools(
180+
tools: list[Tool | ChatCompletionToolsParam],
181+
) -> dict:
182+
json_schema = {
183+
"type": "array",
184+
"minItems": 1,
185+
"items": {
186+
"type": "object",
187+
"anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
188+
},
189+
}
190+
json_schema_defs = _get_tool_schema_defs(tools)
191+
if json_schema_defs:
192+
json_schema["$defs"] = json_schema_defs
193+
return json_schema
194+
195+
196+
def get_json_schema_from_tools(
197+
tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
198+
tools: list[FunctionTool | ChatCompletionToolsParam] | None,
199+
) -> str | dict | None:
200+
# tool_choice: "none"
201+
if tool_choice in ("none", None) or tools is None:
202+
return None
203+
# tool_choice: Forced Function (Responses)
204+
if (not isinstance(tool_choice, str)) and isinstance(
205+
tool_choice, ToolChoiceFunction
206+
):
207+
tool_name = tool_choice.name
208+
tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
209+
if tool_name not in tool_map:
210+
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
211+
return tool_map[tool_name].parameters
212+
# tool_choice: Forced Function (ChatCompletion)
213+
if (not isinstance(tool_choice, str)) and isinstance(
214+
tool_choice, ChatCompletionNamedToolChoiceParam
215+
):
216+
tool_name = tool_choice.function.name
217+
tool_map = {
218+
tool.function.name: tool
219+
for tool in tools
220+
if isinstance(tool, ChatCompletionToolsParam)
221+
}
222+
if tool_name not in tool_map:
223+
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
224+
return tool_map[tool_name].function.parameters
225+
# tool_choice: "required"
226+
if tool_choice == "required":
227+
return _get_json_schema_from_tools(tools)
228+
# tool_choice: "auto"
229+
return None

0 commit comments

Comments (0)