Commit cfd4207

mengwei805 authored and committed
[v0.9.1-dev][CI/UT][bugfix]fix v0 spec decode
Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent f1353d5 commit cfd4207

20 files changed: 377 additions, 29 deletions

.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 4 additions & 4 deletions

@@ -95,11 +95,11 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           # spec decode test
-          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
           # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed
-          # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
-          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
-          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+          # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_spec_decode.py
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
+          pytest -sv tests/long_term/spec_decode_v0 --ignore=tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
           pytest -sv tests/long_term/test_accuracy.py
         else
           VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py
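
For reference, a rough local equivalent of the CI step above, as a sketch rather than part of the commit (it assumes the repository root is the working directory and that the environment variable is set before vllm is imported):

# Hypothetical local equivalent of the CI step: run the relocated v0
# spec-decode suite, skipping the MTP test that needs a clean process.
import os

import pytest

os.environ["VLLM_USE_MODELSCOPE"] = "True"
pytest.main([
    "-sv", "tests/long_term/spec_decode_v0",
    "--ignore=tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py",
])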
tests/long_term/spec_decode_v0/e2e/test_eagle_correctness.py (new file)

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens.

With those tests, we can say that, at a minimum, EAGLE does not break
correctness of the target model outputs.
"""

import pytest

from tests.long_term.spec_decode_v0.e2e.conftest import \
    run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
PRECISION = "float32"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
                             seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    pytest.main([__file__])
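
For context, the "greedy equality" check these tests rely on boils down to comparing greedy outputs with and without a draft model. A minimal, hypothetical sketch follows; it is not the real run_equality_correctness_test from conftest (which also handles batching, seeds, and logprobs), and it assumes vLLM's LLM entry point accepts the same speculative_config dict used in the test kwargs above:

# Hypothetical helper illustrating "greedy equality"; not the conftest
# implementation.
from vllm import LLM, SamplingParams


def check_greedy_equality(prompt: str, max_tokens: int = 32) -> None:
    greedy = SamplingParams(temperature=0.0, max_tokens=max_tokens)

    # Baseline: the target model decoding normally.
    baseline = LLM(model="JackFram/llama-68m", enforce_eager=True)
    expected = baseline.generate([prompt], greedy)[0].outputs[0].text
    del baseline  # release the device before starting a second engine

    # The same target model, with the EAGLE draft proposing tokens.
    spec = LLM(model="JackFram/llama-68m",
               enforce_eager=True,
               speculative_config={
                   "model": "abhigoyal/vllm-eagle-llama-68m-random",
                   "num_speculative_tokens": 4,
               })
    actual = spec.generate([prompt], greedy)[0].outputs[0].text

    # Rejection sampling preserves the target model's distribution, so
    # greedy outputs must match token for token.
    assert actual == expected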

tests/long_term/spec_decode/e2e/test_medusa_correctness.py renamed to tests/long_term/spec_decode_v0/e2e/test_medusa_correctness.py

Lines changed: 3 additions & 2 deletions

@@ -41,9 +41,10 @@
 
 import pytest
 
-from tests.long_term.spec_decode.e2e.conftest import \
+from tests.long_term.spec_decode_v0.e2e.conftest import \
     run_equality_correctness_test
-from tests.long_term.spec_decode.utils import maybe_enable_chunked_prefill
+from tests.long_term.spec_decode_v0.utils import \
+    maybe_enable_chunked_prefill
 
 # main model
 # lmsys/vicuna-7b-v1.3 was to be used but it's causing

tests/long_term/spec_decode/e2e/test_mlp_correctness.py renamed to tests/long_term/spec_decode_v0/e2e/test_mlp_correctness.py

Lines changed: 3 additions & 2 deletions

@@ -41,9 +41,10 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import \
     pad_vocab_size  # noqa: F401
 
-from tests.long_term.spec_decode.e2e.conftest import \
+from tests.long_term.spec_decode_v0.e2e.conftest import \
     run_equality_correctness_test
-from tests.long_term.spec_decode.utils import maybe_enable_chunked_prefill
+from tests.long_term.spec_decode_v0.utils import \
+    maybe_enable_chunked_prefill
 
 # main model
 MAIN_MODEL = "JackFram/llama-160m"

tests/long_term/spec_decode/e2e/test_ngram_correctness.py renamed to tests/long_term/spec_decode_v0/e2e/test_ngram_correctness.py

Lines changed: 3 additions & 2 deletions

@@ -44,9 +44,10 @@
 
 import pytest
 
-from tests.long_term.spec_decode.e2e.conftest import \
+from tests.long_term.spec_decode_v0.e2e.conftest import \
     run_equality_correctness_test
-from tests.long_term.spec_decode.utils import maybe_enable_chunked_prefill
+from tests.long_term.spec_decode_v0.utils import \
+    maybe_enable_chunked_prefill
 
 
 @pytest.mark.parametrize(
