
Commit 2a40cef

[3/N][CI/UT] add spec decode e2e UT
Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent edeadde commit 2a40cef

File tree

13 files changed: +834 -33 lines changed


tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 483 additions & 0 deletions
Large diffs are not rendered by default.
tests/spec_decode/e2e/test_mtp_correctness.py

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various numbers of speculative tokens.

With those tests, we can say that, at the very least, MTP does not break
the correctness of the target model outputs.
"""

import pytest

from .conftest import run_equality_correctness_test

# main model
# NOTE vLLM uses an fp8 model here; vllm-ascend uses a bf16 model.
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# NOTE vLLM uses bfloat16 here; once vllm-ascend supports e2e float32,
# this should be set to "float32".
# precision
PRECISION = "bfloat16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": False,
    },
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": True,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int,
                                 logprobs: int):

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


# TODO: Enable this test once vllm-ascend supports graph mode, i.e. once
# the model can run with enforce_eager set to False.
# @pytest.mark.parametrize(
#     "common_llm_kwargs",
#     [{
#         "enforce_eager": False,

#         # Print spec metrics.
#         "disable_log_stats": False,

#         # Precision
#         "dtype": PRECISION,

#         # Main model
#         "model_name": MAIN_MODEL,
#         "gpu_memory_utilization": 0.85
#     }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
#     {
#         "num_speculative_tokens": MAX_SPEC_TOKENS,
#     },
# ])
# @pytest.mark.parametrize("output_len", [
#     128,
# ])
# @pytest.mark.parametrize("batch_size", [1, 32])
# @pytest.mark.parametrize("seed", [1])
# def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
#                                                per_test_common_llm_kwargs,
#                                                baseline_llm_kwargs,
#                                                test_llm_kwargs,
#                                                batch_size: int,
#                                                output_len: int, seed: int):
#     """Verify greedy equality with cuda graph enabled and different
#     batch sizes."""
#     run_equality_correctness_test(vllm_runner, common_llm_kwargs,
#                                   per_test_common_llm_kwargs,
#                                   baseline_llm_kwargs, test_llm_kwargs,
#                                   batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int):
    """Verify that mtp speculative decoding produces the same output as
    normal decoding for different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that mtp speculative decoding produces the same output as
    normal decoding when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])

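The helper run_equality_correctness_test is imported from the local conftest and is not part of this commit. As a rough, illustrative sketch only (the function below is hypothetical, not the actual conftest implementation), the greedy-equality check it performs amounts to generating with and without speculative decoding at temperature 0 and asserting token-for-token equality:

# Hypothetical sketch of a greedy-equality check; not the real conftest helper.
from typing import List


def assert_greedy_equality(baseline_outputs: List[List[int]],
                           spec_outputs: List[List[int]]) -> None:
    """Assert that spec-decode outputs match the baseline token-for-token.

    Each argument holds one greedily generated (temperature=0) token-id
    sequence per prompt, so the comparison is deterministic.
    """
    assert len(baseline_outputs) == len(spec_outputs)
    for i, (expected, actual) in enumerate(zip(baseline_outputs,
                                               spec_outputs)):
        assert expected == actual, (
            f"prompt {i}: baseline {expected} != spec decode {actual}")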
vllm_ascend/device_allocator/camem.py

Lines changed: 1 addition & 3 deletions
@@ -23,7 +23,7 @@
 
 import torch
 from acl.rt import memcpy  # type: ignore # noqa: F401
-from vllm.logger import init_logger
+from vllm.logger import logger
 
 try:
     import torch_npu  # noqa: F401
@@ -32,8 +32,6 @@
 
 from vllm.utils import is_pin_memory_available
 
-logger = init_logger(__name__)
-
 
 def find_loaded_library(lib_name) -> Optional[str]:
     """

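This change, repeated in the two files below, drops the per-module init_logger(__name__) call in favor of the logger object that vllm.logger already exposes. A minimal before/after sketch of the pattern (the logger.info call is purely illustrative):

# Before: each module built its own logger instance.
#   from vllm.logger import init_logger
#   logger = init_logger(__name__)

# After: reuse the shared logger exported by vllm.logger.
from vllm.logger import logger

logger.info("camem allocator initialized")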
vllm_ascend/patch/patch_spec_decode_worker.py

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,7 @@
 from typing import Any, Dict, Optional
 
 from vllm.config import ParallelConfig
-from vllm.logger import init_logger
+from vllm.logger import logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.spec_decode_base_sampler import \
     SpecDecodeBaseSampler
@@ -34,8 +34,6 @@
 
 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
 
-logger = init_logger(__name__)
-
 
 def create_worker(
     cls,

vllm_ascend/platform.py

Lines changed: 1 addition & 3 deletions
@@ -23,7 +23,7 @@
 import torch_npu  # noqa: F401
 import vllm.envs as envs
 from vllm.config import CompilationLevel, VllmConfig
-from vllm.logger import init_logger
+from vllm.logger import logger
 
 try:
     # register custom ops into torch_library here
@@ -46,8 +46,6 @@
 
 os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 
-logger = init_logger(__name__)
-
 
 class NPUPlatform(Platform):
 