
Commit 6dcb8ed

[MTP] follow custom deepseek modeling changes to support graph mode

Signed-off-by: mengwei805 <mengwei25@huawei.com>
Parent: 2e20797

File tree: 15 files changed, +293 -39 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 8 additions & 3 deletions
@@ -138,13 +138,18 @@ jobs:
             speculative_tests_changed:
               - "tests/singlecard/spec_decode/**"
               - "tests/multicard/spec_decode_e2e/**"
+              - "vllm_ascend/worker/worker.py"
+              - "vllm_ascend/worker/model_runner.py"
               - "vllm_ascend/worker/multi_step_runner.py"
               - "vllm_ascend/worker/multi_step_worker.py"
-              - "vllm_ascend/patch/patch_rejection_sampler.py"
-              - "vllm_ascend/patch/patch_spec_decode_worker.py"
-              - "vllm_ascend/patch/patch_multi_step_worker.py"
+              - "vllm_ascend/worker/draft_model_runner.py"
+              - "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
+              - "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
+              - "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"
 
       - name: Run vllm-project/vllm-ascend Speculative Decode test
+        env:
+          VLLM_USE_V1: 0
         if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from vllm_ascend.patch.worker import patch_common  # noqa: F401
+from vllm_ascend.patch.worker import patch_main  # noqa: F401
+from vllm_ascend.utils import vllm_version_is
+
+if vllm_version_is("0.8.4"):
+    from vllm_ascend.patch.worker import patch_0_8_4  # noqa: F401
tests/singlecard/spec_decode/e2e/conftest.py

Lines changed: 18 additions & 0 deletions
@@ -17,7 +17,9 @@
 # limitations under the License.
 #
 
+import shutil
 from itertools import cycle
+from pathlib import Path
 from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
@@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled(
         assert spec_pos_logprob_token_id in baseline_pos_logprobs
 
 
+def _clean_torchair_cache():
+    cache_path = Path.cwd() / '.torchair_cache'
+    if cache_path.exists() and cache_path.is_dir():
+        shutil.rmtree(cache_path)
+
+
 def run_equality_correctness_test(
         vllm_runner,
         common_llm_kwargs,
@@ -219,10 +227,20 @@ def run_equality_correctness_test(
                                      logprobs=logprobs,
                                      prompt_logprobs=prompt_logprobs)
 
+    # TODO current torchair graph mode needs clean torchair cache.
+    # if do not clean, it will raise error
+    additional_config = common_llm_kwargs.get("additional_config")
+    enable_graph_mode = additional_config.get(
+        "enable_graph_mode") if additional_config else False
+
     with vllm_runner(**org_args) as vllm_model:
+        if enable_graph_mode:
+            _clean_torchair_cache()
         org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     with vllm_runner(**sd_args) as vllm_model:
+        if enable_graph_mode:
+            _clean_torchair_cache()
         if ensure_all_accepted or expected_acceptance_rate is not None:
             # Force log interval to be 0 to catch all metrics.
             stat_logger = vllm_model.model.llm_engine.stat_loggers[
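The helper above is called by hand before each vllm_runner context when graph mode is on. A hypothetical alternative (not in this commit; the fixture name fresh_torchair_cache is illustrative) would express the same cleanup as a pytest fixture so graph-mode tests always start from an empty .torchair_cache:

import shutil
from pathlib import Path

import pytest


@pytest.fixture
def fresh_torchair_cache():
    """Remove any stale .torchair_cache before and after a graph-mode test."""
    cache_path = Path.cwd() / ".torchair_cache"
    if cache_path.exists() and cache_path.is_dir():
        shutil.rmtree(cache_path)  # stale compiled graphs would otherwise raise errors
    yield
    if cache_path.exists() and cache_path.is_dir():
        shutil.rmtree(cache_path)  # leave the working directory clean afterwards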

tests/singlecard/spec_decode/e2e/test_mtp_correctness.py

Lines changed: 125 additions & 23 deletions
@@ -36,23 +36,32 @@
 With those tests, we can say at least, mtp would not break the
 correctess for the target model outputs.
 """
+import os
 
 import pytest
 
 from .conftest import run_equality_correctness_test
 
-# main model
-# NOTE vLLM use fp8 model, vllm-ascend use bf16 model
-MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
+# NOTE both main model and MTP are bfloat16
+FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
+
+# NOTE main model is w8a8, MTP is bfloat16
+QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
+
+# TODO when msmodelslim can quantify both main and MTP model
+# This UT should use w8a8 fully weights.
 
 # max. number of speculative tokens: this corresponds to
 # num_nextn_predict_layers in the config.json of the speculator model.
 MAX_SPEC_TOKENS = 1
 
 # precision
 PRECISION = "bfloat16"
+os.environ["VLLM_USE_MODELSCOPE"] = "True"
 
 
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -66,7 +75,7 @@
         "dtype": PRECISION,
 
         # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": FLOAT_MODEL,
 
         # GPU memory utilization
         "gpu_memory_utilization": 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.skipif(True, reason="quant model is not ready.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
         "dtype": PRECISION,
 
         # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": QUANT_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.85
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_config": {
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                          per_test_common_llm_kwargs,
+                                          baseline_llm_kwargs, test_llm_kwargs,
+                                          batch_size: int, output_len: int,
+                                          seed: int):
+
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="mtp is not supported on v1")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": FLOAT_MODEL,
 
         # GPU memory utilization
         "gpu_memory_utilization": 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                  ["disable_logprobs"])
 
 
-@pytest.mark.skipif(
-    True,
-    reason=
-    "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
-)
+@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "enforce_eager": False,
+        "additional_config": {
+            'enable_graph_mode': True,
+        },
 
         # Print spec metrics.
         "disable_log_stats": False,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         "dtype": PRECISION,
 
         # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": FLOAT_MODEL,
         "gpu_memory_utilization": 0.85
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
-def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
-                                               per_test_common_llm_kwargs,
-                                               baseline_llm_kwargs,
-                                               test_llm_kwargs,
-                                               batch_size: int,
-                                               output_len: int, seed: int):
-    """Verify greedy equality with cuda graph enabled and different
-    batch sizes."""
+def test_mtp_e2e_greedy_correctness_torchair_graph(
+        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
+        seed: int):
+    """Verify greedy equality with torchair graph enabled and different
+    batch sizes using bfloat16 weights."""
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
+
+
+@pytest.mark.skipif(True, reason="quant model is not ready.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "additional_config": {
+            'enable_graph_mode': True,
+        },
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": QUANT_MODEL,
+        "gpu_memory_utilization": 0.85
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_config": {
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
+        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
+        seed: int):
+    """Verify greedy equality with torchair graph enabled and different
+    batch sizes using quant weights."""
     run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
         "dtype": PRECISION,
 
         # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": FLOAT_MODEL,
 
         # GPU memory utilization
         "gpu_memory_utilization": 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
         "dtype": PRECISION,
 
         # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": FLOAT_MODEL,
 
         # GPU memory utilization
         "gpu_memory_utilization": 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
         "dtype": PRECISION,
 
        # Main model
-        "model_name": MAIN_MODEL,
+        "model_name": FLOAT_MODEL,
 
         # GPU memory utilization
         "gpu_memory_utilization": 0.9

tests/singlecard/spec_decode/test_dynamic_spec_decode.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@
 
 from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
 from tests.singlecard.spec_decode.utils import create_batch, mock_worker
-from vllm_ascend.patch.worker import patch_common  # noqa: F401
 
 
 @pytest.mark.parametrize('queue_size', [4])

tests/singlecard/spec_decode/test_multi_step_worker.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
     assert_logprobs_dict_allclose, create_batch,
     create_seq_group_metadata_from_prompts, create_worker,
     patch_execute_model_with_seeds, zero_kv_cache)
-from vllm_ascend.patch.worker import patch_common  # noqa: F401
 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
 from vllm_ascend.worker.worker import NPUWorker
 

tests/singlecard/spec_decode/test_ngram_worker.py

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@
 
 from tests.singlecard.spec_decode.utils import (
     create_seq_group_metadata_from_prompts, create_worker)
-from vllm_ascend.patch.worker import patch_common  # noqa: F401
 
 
 def test_ngram_algo_correctness_for_single_no_match():

tests/singlecard/spec_decode/test_spec_decode_worker.py

Lines changed: 0 additions & 2 deletions
@@ -39,8 +39,6 @@
 from tests.singlecard.spec_decode.utils import (create_batch,
                                                 create_sampler_output_list,
                                                 create_worker, mock_worker)
-# patch SpecDecodeWorker, AsyncMetricsCollector
-from vllm_ascend.patch.worker import patch_common  # noqa: F401
 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
 from vllm_ascend.worker.worker import NPUWorker
 
