Skip to content

Commit c29fb54

Browse files
aarnphm and simon-mo
authored
[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com>
1 parent 65e0389 commit c29fb54

File tree

8 files changed

+573
-63
lines changed

8 files changed

+573
-63
lines changed

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
from __future__ import annotations
5+
46
import asyncio
57
from contextlib import suppress
68
from dataclasses import dataclass, field
7-
from typing import Any, Optional
9+
from typing import TYPE_CHECKING, Any, Optional
810
from unittest.mock import MagicMock
911

1012
import pytest
13+
import pytest_asyncio
1114

1215
from vllm.config import MultiModalConfig
1316
from vllm.engine.multiprocessing.client import MQLLMEngineClient
@@ -17,6 +20,164 @@
1720
OpenAIServingModels)
1821
from vllm.transformers_utils.tokenizer import get_tokenizer
1922

23+
from ...utils import RemoteOpenAIServer
24+
25+
if TYPE_CHECKING:
26+
from openai import OpenAI
27+
28+
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
29+
30+
31+
@pytest.fixture(scope="module")
32+
def monkeypatch_module():
33+
from _pytest.monkeypatch import MonkeyPatch
34+
mpatch = MonkeyPatch()
35+
yield mpatch
36+
mpatch.undo()
37+
38+
39+
@pytest.fixture(scope="module")
40+
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
41+
with monkeypatch_module.context() as m:
42+
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
43+
args = [
44+
"--enforce-eager",
45+
"--max-model-len",
46+
"8192",
47+
"--tool-call-parser",
48+
"openai",
49+
"--reasoning-parser",
50+
"openai_gptoss",
51+
"--enable-auto-tool-choice",
52+
]
53+
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
54+
yield remote_server
55+
56+
57+
@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Async OpenAI client bound to the gpt-oss test server."""
    async with gptoss_server.get_async_client() as client:
        yield client
61+
62+
63+
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
    """Streamed tool-call deltas must surface a function name and arguments."""
    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {
                "type": "string"
            },
            "state": {
                "type": "string"
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": weather_parameters,
        },
    }]

    messages = [
        {
            "role": "user",
            "content": "What is the weather in Dallas, TX?"
        },
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)

    seen_name = None
    seen_arguments = ""
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if not delta.tool_calls:
            continue
        call = delta.tool_calls[0]
        if call.function and call.function.name:
            seen_name = call.function.name
        if call.function and call.function.arguments:
            # Argument fragments arrive incrementally; accumulate them.
            seen_arguments += call.function.arguments

    assert seen_name is not None
    assert len(seen_arguments) > 0
112+
113+
114+
@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
    """First turn elicits a tool call; the second turn continues the chat.

    The assistant's tool call and a synthetic tool result are echoed back
    into the history per the OpenAI chat tool-calling protocol before the
    follow-up question is asked.
    """
    tools = [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string"
                    },
                    "state": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }]

    messages = [
        {
            "role": "system",
            "content": "you are a helpful assistant"
        },
        {
            "role": "user",
            "content": "What is the weather in Dallas, TX?"
        },
    ]

    first = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    first_msg = first.choices[0].message
    assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
    tc = first_msg.tool_calls[0]
    assert tc.function is not None and tc.function.name == "get_current_weather"
    args1 = tc.function.arguments
    assert args1 is not None and len(args1) > 0

    # Bug fix: the original appended the raw arguments string as assistant
    # *content*, which is not valid multi-turn tool-calling history. Echo the
    # assistant turn with its tool_calls, then answer with a tool message.
    messages.append({
        "role": "assistant",
        "tool_calls": [{
            "id": tc.id,
            "type": "function",
            "function": {
                "name": tc.function.name,
                "arguments": args1,
            },
        }],
    })
    messages.append({
        "role": "tool",
        "tool_call_id": tc.id,
        "content": '{"temperature": 85, "unit": "fahrenheit"}',
    })
    messages.append({
        "role": "user",
        "content": "Now convert to celsius and return JSON only"
    })

    second = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    second_msg = second.choices[0].message
    # The follow-up may be answered directly or via another tool call.
    assert (second_msg.content is not None and len(second_msg.content) > 0) or \
        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
179+
180+
20181
MODEL_NAME = "openai-community/gpt2"
21182
CHAT_TEMPLATE = "Dummy chat template for testing {}"
22183
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
6+
import pytest
7+
from openai_harmony import (Conversation, DeveloperContent,
8+
HarmonyEncodingName, Message, Role, SystemContent,
9+
load_harmony_encoding)
10+
11+
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
12+
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
13+
from vllm.transformers_utils.tokenizer import get_tokenizer
14+
15+
MODEL = "gpt2"
16+
17+
18+
@pytest.fixture(scope="module")
19+
def openai_tokenizer():
20+
# The parser does not use the tokenizer, but the constructor requires it.
21+
return get_tokenizer(MODEL)
22+
23+
24+
@pytest.fixture
def openai_tool_parser(openai_tokenizer):
    """A fresh OpenAIToolParser for each test."""
    parser = OpenAIToolParser(openai_tokenizer)
    return parser
27+
28+
29+
@pytest.fixture(scope="module")
30+
def harmony_encoding():
31+
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
32+
33+
34+
def assert_tool_calls(
    actual_tool_calls: list[ToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Compare parsed tool calls against the expected ones.

    Call IDs are generated by the parser, so only their shape is checked;
    the function payloads must match exactly.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16  # Default from protocol.py
        assert actual.type == "function"
        assert actual.function == expected.function
46+
47+
48+
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
    """A plain final-channel reply yields content and no tool calls."""
    msgs = [
        Message.from_role_and_content(
            Role.SYSTEM,
            SystemContent.new(),
        ),
        Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions("Talk like a pirate!")),
        Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
        Message.from_role_and_content(Role.ASSISTANT,
                                      "This is a test").with_channel("final"),
    ]
    token_ids = harmony_encoding.render_conversation_for_completion(
        Conversation.from_messages(msgs), Role.ASSISTANT)

    result = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )

    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == "This is a test"
71+
72+
73+
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
    """One commentary-channel function message parses into one tool call."""
    weather_call = Message.from_role_and_content(Role.ASSISTANT,
                                                 '{"location": "Tokyo"}')
    weather_call = weather_call.with_channel("commentary")
    weather_call = weather_call.with_recipient("functions.get_current_weather")
    weather_call = weather_call.with_content_type("json")

    convo = Conversation.from_messages([
        Message.from_role_and_content(Role.USER,
                                      "What is the weather in Tokyo?"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.',  # noqa: E501
        ).with_channel("analysis"),
        weather_call,
    ])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT)

    result = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )

    assert result.tools_called
    expected = [
        ToolCall(function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"location": "Tokyo"}),
        ))
    ]
    assert_tool_calls(result.tool_calls, expected)
    # Analysis-channel reasoning is not surfaced as content.
    assert result.content is None
103+
104+
105+
def test_extract_tool_calls_multiple_tools(
    openai_tool_parser,
    harmony_encoding,
):
    """Two commentary-channel function messages parse into two tool calls."""

    def _tool_msg(recipient: str) -> Message:
        # Both calls in this conversation carry the same JSON payload.
        return Message.from_role_and_content(
            Role.ASSISTANT,
            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
                recipient).with_content_type("json")

    convo = Conversation.from_messages([
        Message.from_role_and_content(
            Role.USER, "What is the weather in Tokyo based on where I'm at?"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
        ).with_channel("analysis"),
        _tool_msg("functions.get_current_weather"),
        _tool_msg("functions.get_user_location"),
    ])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )

    result = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )

    assert result.tools_called
    expected = [
        ToolCall(function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"location": "Tokyo"}),
        )),
        ToolCall(function=FunctionCall(
            name="get_user_location",
            arguments=json.dumps({"location": "Tokyo"}),
        )),
    ]
    assert_tool_calls(result.tool_calls, expected)
    # Analysis-channel reasoning is not surfaced as content.
    assert result.content is None

0 commit comments

Comments
 (0)