|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | | -from typing import NamedTuple |
5 | | - |
6 | 4 | import openai # use the official client for correctness check |
7 | 5 | import pytest |
8 | 6 | import pytest_asyncio |
@@ -39,53 +37,14 @@ async def client(server): |
39 | 37 | yield async_client |
40 | 38 |
|
41 | 39 |
|
42 | | -class TestCase(NamedTuple): |
43 | | - model_name: str |
44 | | - stream: bool |
45 | | - tool_choice: str |
46 | | - enable_thinking: bool |
47 | | - |
48 | | - |
49 | 40 | @pytest.mark.asyncio |
50 | | -@pytest.mark.parametrize( |
51 | | - "test_case", |
52 | | - [ |
53 | | - TestCase(model_name=MODEL_NAME, |
54 | | - stream=True, |
55 | | - tool_choice="auto", |
56 | | - enable_thinking=False), |
57 | | - TestCase(model_name=MODEL_NAME, |
58 | | - stream=False, |
59 | | - tool_choice="auto", |
60 | | - enable_thinking=False), |
61 | | - TestCase(model_name=MODEL_NAME, |
62 | | - stream=True, |
63 | | - tool_choice="required", |
64 | | - enable_thinking=False), |
65 | | - TestCase(model_name=MODEL_NAME, |
66 | | - stream=False, |
67 | | - tool_choice="required", |
68 | | - enable_thinking=False), |
69 | | - TestCase(model_name=MODEL_NAME, |
70 | | - stream=True, |
71 | | - tool_choice="auto", |
72 | | - enable_thinking=True), |
73 | | - TestCase(model_name=MODEL_NAME, |
74 | | - stream=False, |
75 | | - tool_choice="auto", |
76 | | - enable_thinking=True), |
77 | | - TestCase(model_name=MODEL_NAME, |
78 | | - stream=True, |
79 | | - tool_choice="required", |
80 | | - enable_thinking=True), |
81 | | - TestCase(model_name=MODEL_NAME, |
82 | | - stream=False, |
83 | | - tool_choice="required", |
84 | | - enable_thinking=True), |
85 | | - ], |
86 | | -) |
87 | | -async def test_function_tool_use(client: openai.AsyncOpenAI, |
88 | | - test_case: TestCase): |
| 41 | +@pytest.mark.parametrize("model_name", [MODEL_NAME]) |
| 42 | +@pytest.mark.parametrize("stream", [True, False]) |
| 43 | +@pytest.mark.parametrize("tool_choice", ["auto", "required"]) |
| 44 | +@pytest.mark.parametrize("enable_thinking", [True, False]) |
| 45 | +async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, |
| 46 | + stream: bool, tool_choice: str, |
| 47 | + enable_thinking: bool): |
89 | 48 | tools = [ |
90 | 49 | { |
91 | 50 | "type": "function", |
@@ -174,37 +133,37 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, |
174 | 133 | "forecast for the next 5 days, in fahrenheit?", |
175 | 134 | }, |
176 | 135 | ] |
177 | | - if not test_case.stream: |
| 136 | + if not stream: |
178 | 137 | # Non-streaming test |
179 | 138 | chat_completion = await client.chat.completions.create( |
180 | 139 | messages=messages, |
181 | | - model=test_case.model_name, |
| 140 | + model=model_name, |
182 | 141 | tools=tools, |
183 | | - tool_choice=test_case.tool_choice, |
| 142 | + tool_choice=tool_choice, |
184 | 143 | extra_body={ |
185 | 144 | "chat_template_kwargs": { |
186 | | - "enable_thinking": test_case.enable_thinking |
| 145 | + "enable_thinking": enable_thinking |
187 | 146 | } |
188 | 147 | }) |
189 | 148 |
|
190 | 149 | assert chat_completion.choices[0].message.tool_calls is not None |
191 | 150 | assert len(chat_completion.choices[0].message.tool_calls) > 0 |
192 | 151 | else: |
193 | 152 | # Streaming test |
194 | | - stream = await client.chat.completions.create( |
| 153 | + output_stream = await client.chat.completions.create( |
195 | 154 | messages=messages, |
196 | | - model=test_case.model_name, |
| 155 | + model=model_name, |
197 | 156 | tools=tools, |
198 | | - tool_choice=test_case.tool_choice, |
| 157 | + tool_choice=tool_choice, |
199 | 158 | stream=True, |
200 | 159 | extra_body={ |
201 | 160 | "chat_template_kwargs": { |
202 | | - "enable_thinking": test_case.enable_thinking |
| 161 | + "enable_thinking": enable_thinking |
203 | 162 | } |
204 | 163 | }) |
205 | 164 |
|
206 | 165 | output = [] |
207 | | - async for chunk in stream: |
| 166 | + async for chunk in output_stream: |
208 | 167 | if chunk.choices and chunk.choices[0].delta.tool_calls: |
209 | 168 | output.extend(chunk.choices[0].delta.tool_calls) |
210 | 169 |
|
|
0 commit comments