Skip to content

Commit eff3e5f

Browse files
authored
[FEAT] Refactor spec decode to support efficient padded speculation (#3528)
### What this PR does / why we need it? 1. Refactor the file `mtp_proposer.py`, splits torchair related codes into `mtp_torchair_proposer.py` 2. According to vllm-project/vllm#24539, implements padded speculative decoding as described in vllm-project/vllm#21984. ### Does this PR introduce _any_ user-facing change? User can use `disable_padded_drafter_batch` to disable/enable padded speculation, default is `False`. offline example: ``` speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} ``` ### How was this patch tested? - [x] eager with pad/unpad: - [x] aclgraph with pad/unpad - [x] torchair with pad/unpad performance test of deepseek-r1 with tp16, dp1 aclgraph with pad ITL: 168ms aclgraph with unpad ITL: 169ms original: 178ms - vLLM version: v0.11.0rc3 - vLLM main: vllm-project/vllm@83f478b --------- Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 10772d9 commit eff3e5f

File tree

7 files changed

+1207
-444
lines changed

7 files changed

+1207
-444
lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

3+
import os
4+
35
import pytest
46
from vllm import SamplingParams
57
from vllm.config import CompilationConfig, CUDAGraphMode
68

79
from tests.e2e.conftest import VllmRunner
810

11+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
12+
913

1014
@pytest.fixture
1115
def sampling_config():
@@ -17,12 +21,12 @@ def model_name():
1721
return "wemaster/deepseek_mtp_main_random_bf16"
1822

1923

20-
def mtp_correctness(
21-
sampling_config: SamplingParams,
22-
model_name: str,
23-
num_speculative_tokens: int,
24-
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
25-
):
24+
def mtp_correctness(sampling_config: SamplingParams,
25+
model_name: str,
26+
num_speculative_tokens: int,
27+
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
28+
enforce_eager=False,
29+
disable_padded_drafter_batch=True):
2630
example_prompts = [
2731
"Hello, my name is",
2832
"The president of the United States is",
@@ -37,7 +41,7 @@ def mtp_correctness(
3741
tensor_parallel_size=1,
3842
gpu_memory_utilization=0.7,
3943
max_model_len=256,
40-
enforce_eager=False) as ref_llm:
44+
enforce_eager=enforce_eager) as ref_llm:
4145
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
4246

4347
graph_mode_str = "PIECEWISE"
@@ -54,8 +58,9 @@ def mtp_correctness(
5458
speculative_config={
5559
"method": "deepseek_mtp",
5660
"num_speculative_tokens": num_speculative_tokens,
61+
"disable_padded_drafter_batch": disable_padded_drafter_batch,
5762
},
58-
enforce_eager=False,
63+
enforce_eager=enforce_eager,
5964
max_model_len=2000,
6065
compilation_config=CompilationConfig(
6166
cudagraph_mode=graph_mode_str),
@@ -82,6 +87,20 @@ def mtp_correctness(
8287
del spec_llm
8388

8489

90+
def test_mtp1_correctness_eager(
91+
sampling_config: SamplingParams,
92+
model_name: str,
93+
):
94+
mtp_correctness(sampling_config, model_name, 1, enforce_eager=True)
95+
96+
97+
def test_mtp2_correctness_eager(
98+
sampling_config: SamplingParams,
99+
model_name: str,
100+
):
101+
mtp_correctness(sampling_config, model_name, 2, enforce_eager=True)
102+
103+
85104
@pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
86105
def test_mtp1_correctness_piecewise_graph(
87106
sampling_config: SamplingParams,
@@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph(
110129
model_name: str,
111130
):
112131
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
132+
133+
134+
def test_mtp1_correctness_eager_with_pad(
135+
sampling_config: SamplingParams,
136+
model_name: str,
137+
):
138+
mtp_correctness(sampling_config,
139+
model_name,
140+
1,
141+
enforce_eager=True,
142+
disable_padded_drafter_batch=False)
143+
144+
145+
def test_mtp2_correctness_eager_with_pad(
146+
sampling_config: SamplingParams,
147+
model_name: str,
148+
):
149+
mtp_correctness(sampling_config,
150+
model_name,
151+
2,
152+
enforce_eager=True,
153+
disable_padded_drafter_batch=False)
154+
155+
156+
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
157+
def test_mtp1_correctness_piecewise_graph_with_pad(
158+
sampling_config: SamplingParams,
159+
model_name: str,
160+
):
161+
mtp_correctness(sampling_config,
162+
model_name,
163+
1,
164+
disable_padded_drafter_batch=False)
165+
166+
167+
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
168+
def test_mtp2_correctness_piecewise_graph_with_pad(
169+
sampling_config: SamplingParams,
170+
model_name: str,
171+
):
172+
mtp_correctness(sampling_config,
173+
model_name,
174+
2,
175+
disable_padded_drafter_batch=False)

vllm_ascend/spec_decode/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,21 @@
1919
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
2020
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
2121
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
22+
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
2223

2324

24-
def get_spec_decode_method(method, vllm_config, device, runner):
25+
def get_spec_decode_method(method,
26+
vllm_config,
27+
device,
28+
runner,
29+
is_torchair_graph=False):
2530
if method == "ngram":
2631
return NgramProposer(vllm_config, device, runner)
2732
elif method in ["eagle", "eagle3"]:
2833
return EagleProposer(vllm_config, device, runner)
2934
elif method == 'deepseek_mtp':
35+
if is_torchair_graph:
36+
return TorchairMtpProposer(vllm_config, device, runner)
3037
return MtpProposer(vllm_config, device, runner)
3138
else:
3239
raise ValueError("Unknown speculative decoding method: "

0 commit comments

Comments
 (0)