
Commit 257e200

[V1][Frontend] Add Testing For V1 Runtime Parameters (#14159)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
1 parent 47d4a7e commit 257e200

File tree: 3 files changed, +201 -17 lines changed
New test module (Lines changed: 150 additions & 0 deletions)
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from vllm import LLM, SamplingParams

if os.getenv("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "meta-llama/Llama-3.2-1B"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
    return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(model):
    """Parallel sampling (n > 1) is supported."""

    params = SamplingParams(n=3)
    outputs = model.generate(PROMPT, params)
    assert len(outputs[0].outputs) == 3


def test_best_of(model):
    """Raise a ValueError since best_of is deprecated."""

    params = SamplingParams(n=2, best_of=3)
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, params)


def test_penalties(model):
    """Check that penalties and sampling knobs apply without errors."""

    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    _ = model.generate(PROMPT, params)


def test_stop(model):
    """Check that we respect the stop words."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(temperature=0,
                            stop=split_text[STOP_IDX],
                            include_stop_str_in_output=True)
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(model):
    """Check that we respect the stop token ids."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))

    stop_token_id_0 = output[0].outputs[0].token_ids[5]
    stop_token_id_1 = output[0].outputs[0].token_ids[6]

    stop_token_ids = [stop_token_id_1, stop_token_id_0]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_bad_words(model):
    """Check that bad_words is rejected (not yet supported on V1)."""

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))


def test_logits_processor(model):
    """Check that we reject per-request logits processors."""

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(model):
    """Check that we can use allowed_token_ids."""

    TOKEN_ID = 10
    allowed_token_ids = [TOKEN_ID]
    output = model.generate(
        PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
    assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out-of-vocabulary token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(allowed_token_ids=[10000000]))


def test_priority(model):
    """Check that we reject requests with priority."""

    # Reject requests that set a non-zero priority.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, priority=[1])


def test_seed(model):
    """Check that seed impacts randomness."""

    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
    out_3 = model.generate(PROMPT, SamplingParams(seed=43))

    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
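
For readers unfamiliar with the return type these assertions index into, the sketch below (not part of the commit) shows the RequestOutput / CompletionOutput structure that generate() returns, reusing the same model and prompt as the tests above.

# Illustrative sketch only (not part of the commit). generate() returns one
# RequestOutput per prompt; each RequestOutput.outputs holds n CompletionOutput
# entries, whose .text and .token_ids the tests above assert on.
from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enforce_eager=True)
outs = llm.generate("Hello my name is Robert and I", SamplingParams(n=3, seed=42))
for completion in outs[0].outputs:      # n = 3 parallel samples
    print(completion.text, completion.token_ids[-1])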

vllm/v1/engine/processor.py

Lines changed: 46 additions & 17 deletions
@@ -55,11 +55,8 @@ def __init__(
 
     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,17 +76,10 @@ def _validate_logprobs
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")
 
-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -99,6 +89,42 @@ def _validate_allowed_token_ids(
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
@@ -114,14 +140,17 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")
 
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
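
To illustrate how the relocated validation surfaces to callers, the hedged sketch below (not part of the diff) shows unsupported knobs now failing fast with a ValueError in the Processor rather than tripping an assert later, mirroring the new tests.

# Illustrative sketch only (not part of the diff): unsupported parameters are
# rejected with ValueError before any generation work is scheduled.
import pytest
from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enforce_eager=True)

# Rejected by _validate_supported_sampling_params.
params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError):
    llm.generate("Hello my name is Robert and I", params)

# Rejected in process_inputs (priority is not yet supported on V1).
with pytest.raises(ValueError):
    llm.generate("Hello my name is Robert and I", priority=[1])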

vllm/v1/worker/gpu_input_batch.py

Lines changed: 5 additions & 0 deletions
@@ -298,6 +298,11 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
+        # FIXME: this implementation is incorrect. We create this mask
+        # then apply -inf to these specific tokens, which means we never
+        # select the allowed tokens! We cannot do the reverse, since
+        # this will impact the requests that do not have allowed_token_ids.
+        # This feature is currently disabled on V1 (we reject in Processor).
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
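
The FIXME above describes a direction bug in the allowed-token mask. Below is a standalone sketch (assumed shapes, not the vLLM implementation) of the difference between the buggy and the intended masking it points at.

# Standalone sketch of the masking pitfall the FIXME describes (not vLLM code).
# Row 0 requests allowed_token_ids = {2, 5}; row 1 requests no restriction and
# must be left untouched, which is why simply inverting the mask is not an option.
import torch

batch, vocab_size = 2, 8
logits = torch.randn(batch, vocab_size)

# Buggy direction called out in the FIXME: -inf is applied to the *allowed*
# ids, so they can never be sampled.
buggy_mask = torch.zeros(batch, vocab_size, dtype=torch.bool)
buggy_mask[0, [2, 5]] = True
buggy_logits = logits.masked_fill(buggy_mask, float("-inf"))

# Intended behavior: ban every token *except* the allowed ids, but only for
# rows that actually set allowed_token_ids.
fixed_mask = torch.zeros(batch, vocab_size, dtype=torch.bool)
fixed_mask[0] = True
fixed_mask[0, [2, 5]] = False
fixed_logits = logits.masked_fill(fixed_mask, float("-inf"))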

0 commit comments