|
6 | 6 | from typing import Any |
7 | 7 |
|
8 | 8 | import partial_json_parser |
| 9 | +from openai.types.responses import ( |
| 10 | + FunctionTool, |
| 11 | + ToolChoiceFunction, |
| 12 | +) |
| 13 | +from openai.types.responses.tool import Tool |
9 | 14 | from partial_json_parser.core.options import Allow |
10 | 15 |
|
| 16 | +from vllm.entrypoints.openai.protocol import ( |
| 17 | + ChatCompletionNamedToolChoiceParam, |
| 18 | + ChatCompletionToolsParam, |
| 19 | +) |
| 20 | + |
11 | 21 |
|
12 | 22 | def find_common_prefix(s1: str, s2: str) -> str: |
13 | 23 | """ |
@@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int: |
122 | 132 | while i < len(s) and s[i].isspace(): |
123 | 133 | i += 1 |
124 | 134 | return i |
| 135 | + |
| 136 | + |
| 137 | +def _extract_tool_info( |
| 138 | + tool: Tool | ChatCompletionToolsParam, |
| 139 | +) -> tuple[str, dict[str, Any] | None]: |
| 140 | + if isinstance(tool, FunctionTool): |
| 141 | + return tool.name, tool.parameters |
| 142 | + elif isinstance(tool, ChatCompletionToolsParam): |
| 143 | + return tool.function.name, tool.function.parameters |
| 144 | + else: |
| 145 | + raise TypeError(f"Unsupported tool type: {type(tool)}") |
| 146 | + |
| 147 | + |
| 148 | +def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict: |
| 149 | + name, params = _extract_tool_info(tool) |
| 150 | + params = params if params else {"type": "object", "properties": {}} |
| 151 | + return { |
| 152 | + "properties": { |
| 153 | + "name": {"type": "string", "enum": [name]}, |
| 154 | + "parameters": params, |
| 155 | + }, |
| 156 | + "required": ["name", "parameters"], |
| 157 | + } |
| 158 | + |
| 159 | + |
| 160 | +def _get_tool_schema_defs( |
| 161 | + tools: list[Tool | ChatCompletionToolsParam], |
| 162 | +) -> dict: |
| 163 | + all_defs: dict[str, dict[str, Any]] = {} |
| 164 | + for tool in tools: |
| 165 | + _, params = _extract_tool_info(tool) |
| 166 | + if params is None: |
| 167 | + continue |
| 168 | + defs = params.pop("$defs", {}) |
| 169 | + for def_name, def_schema in defs.items(): |
| 170 | + if def_name in all_defs and all_defs[def_name] != def_schema: |
| 171 | + raise ValueError( |
| 172 | + f"Tool definition '{def_name}' has multiple schemas, " |
| 173 | + "which is not supported." |
| 174 | + ) |
| 175 | + all_defs[def_name] = def_schema |
| 176 | + return all_defs |
| 177 | + |
| 178 | + |
| 179 | +def _get_json_schema_from_tools( |
| 180 | + tools: list[Tool | ChatCompletionToolsParam], |
| 181 | +) -> dict: |
| 182 | + json_schema = { |
| 183 | + "type": "array", |
| 184 | + "minItems": 1, |
| 185 | + "items": { |
| 186 | + "type": "object", |
| 187 | + "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools], |
| 188 | + }, |
| 189 | + } |
| 190 | + json_schema_defs = _get_tool_schema_defs(tools) |
| 191 | + if json_schema_defs: |
| 192 | + json_schema["$defs"] = json_schema_defs |
| 193 | + return json_schema |
| 194 | + |
| 195 | + |
| 196 | +def get_json_schema_from_tools( |
| 197 | + tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam, |
| 198 | + tools: list[FunctionTool | ChatCompletionToolsParam] | None, |
| 199 | +) -> str | dict | None: |
| 200 | + # tool_choice: "none" |
| 201 | + if tool_choice in ("none", None) or tools is None: |
| 202 | + return None |
| 203 | + # tool_choice: Forced Function (Responses) |
| 204 | + if (not isinstance(tool_choice, str)) and isinstance( |
| 205 | + tool_choice, ToolChoiceFunction |
| 206 | + ): |
| 207 | + tool_name = tool_choice.name |
| 208 | + tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} |
| 209 | + if tool_name not in tool_map: |
| 210 | + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") |
| 211 | + return tool_map[tool_name].parameters |
| 212 | + # tool_choice: Forced Function (ChatCompletion) |
| 213 | + if (not isinstance(tool_choice, str)) and isinstance( |
| 214 | + tool_choice, ChatCompletionNamedToolChoiceParam |
| 215 | + ): |
| 216 | + tool_name = tool_choice.function.name |
| 217 | + tool_map = { |
| 218 | + tool.function.name: tool |
| 219 | + for tool in tools |
| 220 | + if isinstance(tool, ChatCompletionToolsParam) |
| 221 | + } |
| 222 | + if tool_name not in tool_map: |
| 223 | + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") |
| 224 | + return tool_map[tool_name].function.parameters |
| 225 | + # tool_choice: "required" |
| 226 | + if tool_choice == "required": |
| 227 | + return _get_json_schema_from_tools(tools) |
| 228 | + # tool_choice: "auto" |
| 229 | + return None |
0 commit comments