Commit a226864

Merge branch 'main' into luka/custom-op-matching-2
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
2 parents e99a759 + 0a9ef0c commit a226864

37 files changed: +1176 / -477 lines

.github/CODEOWNERS

Lines changed: 0 additions & 3 deletions
@@ -5,9 +5,7 @@
 /vllm/attention @LucasWilkinson
 /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
 /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
-/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
 /vllm/model_executor/layers/fused_moe @mgoin
-/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
 /vllm/model_executor/layers/mamba @tdoublep
 /vllm/model_executor/model_loader @22quinn
@@ -26,7 +24,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
 /vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345

 # vLLM V1
-/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
 /vllm/v1/attention @LucasWilkinson
 /vllm/v1/attention/backends/flashinfer.py @mgoin
 /vllm/v1/attention/backends/triton_attn.py @tdoublep

csrc/layernorm_kernels.cu

Lines changed: 15 additions & 3 deletions
@@ -2,6 +2,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
+  const scalar_t* input_row = input + blockIdx.x * input_stride;

-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
     variance += x * x;
-  }
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
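The shape of the rms_norm_kernel change above: each row's sum of squares is now accumulated by handing aligned groups of VEC_SIZE elements to a vector callback and any leftover elements to a scalar callback, instead of reading one element per loop iteration. Below is a minimal, hypothetical Python sketch of that split; the function name sum_of_squares and the plain-list input are illustrative only, and the real vllm::vectorize_read_with_alignment helper is also given threadIdx.x and blockDim.x so it can stride the work across the block's threads and (judging by its name) handle pointer alignment.

# Illustrative analogue of the vectorized read pattern in rms_norm_kernel.
# Assumptions: a flat Python list stands in for one input row; thread striding
# and alignment handling done by vectorize_read_with_alignment are omitted.
VEC_SIZE = 8


def sum_of_squares(row: list[float]) -> float:
    variance = 0.0  # running sum of squares, as in the kernel

    def vec_op(vec: list[float]) -> None:
        nonlocal variance
        for x in vec:  # the kernel unrolls this 8-wide loop
            variance += x * x

    def scalar_op(val: float) -> None:
        nonlocal variance
        variance += val * val

    n_full = len(row) // VEC_SIZE
    for i in range(n_full):  # bulk of the row, VEC_SIZE elements at a time
        vec_op(row[i * VEC_SIZE : (i + 1) * VEC_SIZE])
    for val in row[n_full * VEC_SIZE :]:  # remainder handled one element at a time
        scalar_op(val)
    return variance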

csrc/layernorm_quant_kernels.cu

Lines changed: 16 additions & 3 deletions
@@ -10,6 +10,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   __shared__ float s_variance;
   float variance = 0.0f;

-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+  const scalar_t* input_row = input + blockIdx.x * input_stride;
+
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
     variance += x * x;
-  }
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;

docs/features/tool_calling.md

Lines changed: 10 additions & 0 deletions
@@ -352,6 +352,16 @@ Supported models:

 Flags: `--tool-call-parser qwen3_xml`

+### Olmo 3 Models (`olmo3`)
+
+Olmo 3 models output tool calls in a format very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but parallel tool calls are newline-delimited, and the calls are wrapped in XML tags as `<function_calls>..</function_calls>`. The parser also accepts the JSON boolean and null literals (`true`, `false`, and `null`) in addition to the pythonic ones (`True`, `False`, and `None`).
+
+Supported models:
+
+* TODO (will be updated after Olmo 3 release)
+
+Flags: `--tool-call-parser olmo3`
+
 ### Models with Pythonic Tool Calls (`pythonic`)

 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
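The Olmo 3 section added above describes the format in prose. To make it concrete, here is a small sketch, based on the fixtures in the new parser tests added later in this commit, that feeds a raw Olmo 3 style completion (two newline-delimited pythonic calls, one using the JSON literals, wrapped in `<function_calls>` tags) through the `olmo3` parser. It uses the test helper run_tool_extraction and a mock tokenizer, so treat it as a test-style sketch rather than a serving-path example.

# Sketch: run an Olmo 3 style completion through the "olmo3" tool parser,
# mirroring the new tests in this commit. run_tool_extraction is a test
# helper, not part of the public serving API.
from unittest.mock import MagicMock

from tests.entrypoints.openai.tool_parsers.utils import run_tool_extraction
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

model_output = (
    "<function_calls>"
    "get_weather(city='San Francisco', metric='celsius')\n"
    "register_user(name='John Doe', age=37, passed_test=true, role=null)"
    "</function_calls>"
)

tool_parser = ToolParserManager.get_tool_parser("olmo3")(MagicMock())
content, tool_calls = run_tool_extraction(tool_parser, model_output, streaming=False)

assert content is None  # the whole completion was parsed as tool calls
assert [call.function.name for call in tool_calls] == ["get_weather", "register_user"]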

tests/compile/test_fusion_attn.py

Lines changed: 14 additions & 2 deletions
@@ -14,6 +14,7 @@
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
+from vllm.compilation.matcher_utils import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
@@ -28,6 +29,7 @@
 )
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
 )
@@ -305,7 +307,6 @@ def test_attention_quant_pattern(
     backend: _Backend,
     use_inductor_graph_partition: bool,
     dist_init,
-    caplog_vllm,
 ):
     """Test AttentionStaticQuantPattern fusion pass"""
     if backend == _Backend.FLASHINFER and (
@@ -423,7 +424,7 @@
     )

     # Check attn fusion support
-    quant_key = model_class.quant_key
+    quant_key: QuantKey = model_class.quant_key
     attn_fusion_supported = [
         layer.impl.fused_output_quant_supported(quant_key)
         for key, layer in vllm_config.compilation_config.static_forward_context.items()
@@ -432,6 +433,17 @@
         "All layers should support attention fusion"
     )

+    # Check quantization ops in the graph before and after fusion
+    quant_op = (
+        torch.ops.aten.reciprocal
+        if "-quant_fp8" in custom_ops_list
+        else QUANT_OPS[quant_key]
+    )
+
+    # Note: for fp8, fully_replaced=False because query quant ops remain in graph.
+    # Only output quant ops are fused into attention.
+    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
+
     # access the underlying `AttnFusionPass` on the `LazyInitPass`
     assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import MagicMock, patch

import pytest

from tests.entrypoints.openai.tool_parsers.utils import (
    run_tool_extraction,
    run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=None, "
    "passed_test=True, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=null, "
    "passed_test=true, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "John Doe", '
    '"age": 37, '
    '"address": {"city": "San Francisco", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)


@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
    model_output = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    assert content == model_output
    assert len(tool_calls) == 0


TEST_CASES = [
    pytest.param(
        True,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL],
        id="simple_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming_json_literals",
    ),
    pytest.param(
        False,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming_json_literals",
    ),
    pytest.param(
        True,
        f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_nonstreaming",
    ),
]


@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected


def test_streaming_tool_call_with_large_steps():
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
    model_output_deltas = [
        "<function_calls>get_weather(city='San",
        " Francisco', metric='celsius')\n"
        f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
        f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_output_deltas, assert_one_tool_per_delta=False
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
    assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
    assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL


@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool):
    """test regex timeout is handled gracefully"""
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)

    fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # create a mock regex that raises TimeoutError
    mock_regex = MagicMock()
    mock_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, fake_problematic_input, streaming=streaming
        )

    # should treat as regular text when regex times out
    assert content == fake_problematic_input
    assert len(tool_calls) == 0
    mock_regex.match.assert_called_once()
