
Commit 11f4971

v0.7.3 support speculative decoding (#252)
### What this PR does / why we need it?

Support speculative decoding on Ascend, including speculating with a draft model, matching n-grams in the prompt, using MLP speculators, and using EAGLE-based draft models.

### Does this PR introduce _any_ user-facing change?

You can refer to https://docs.vllm.ai/en/latest/features/spec_decode.html#

### How was this patch tested?

All four modes of speculative decoding have been tested, with results consistent with GPU devices.

---------

Signed-off-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
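As a minimal usage sketch for one of these modes (n-gram speculation, which needs no extra draft weights), following the v0.7-era engine arguments documented at the link above; the model name and parameter values are illustrative assumptions, not taken from this PR:

```python
from vllm import LLM, SamplingParams

# Hedged sketch: "[ngram]" selects the prompt-lookup (n-gram) proposer,
# so no draft-model weights are needed. Model name and values are
# illustrative assumptions.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    speculative_model="[ngram]",   # n-gram proposer instead of a draft model
    num_speculative_tokens=5,      # draft tokens proposed per step
    ngram_prompt_lookup_max=4,     # largest n-gram window matched in the prompt
    ngram_prompt_lookup_min=1,     # smallest n-gram window
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

Note that `ngram_prompt_lookup_max > 0` is exactly the condition the patched `create_worker` below uses to pick the `NGramWorker` proposer.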
1 parent 2bad9a7 commit 11f4971

File tree

3 files changed: +467 −0 lines changed


vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,3 +18,4 @@
 import vllm_ascend.patch.patch_metrics  # noqa
 import vllm_ascend.patch.patch_minicpm  # noqa
 import vllm_ascend.patch.patch_rejection_sampler  # noqa
+import vllm_ascend.patch.patch_spec_decode_worker  # noqa
```
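Importing the patch module is what activates the override: the new file below ends with `SpecDecodeWorker.create_worker = classmethod(create_worker)`, so the import alone rebinds the upstream factory method. A minimal sketch of that monkey-patch pattern, using a hypothetical class purely for illustration:

```python
class Upstream:
    """Stand-in for a vLLM class whose factory we want to replace."""

    @classmethod
    def create(cls) -> str:
        return "upstream implementation"


# A patch module defines a plain function with the same signature,
# including the leading `cls` parameter ...
def create(cls) -> str:
    return "patched implementation"


# ... and rebinds it at import time; every later call site picks up
# the patched behavior without touching the upstream source.
Upstream.create = classmethod(create)

assert Upstream.create() == "patched implementation"
```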
vllm_ascend/patch/patch_spec_decode_worker.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
    SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
    TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner

logger = init_logger(__name__)


def create_worker(
    cls,
    scorer_worker: WorkerBase,
    draft_worker_kwargs: Dict[str, Any],
    disable_mqa_scorer: bool,
    disable_by_batch_size: Optional[int],
    draft_token_acceptance_method: str,
    typical_acceptance_sampler_posterior_threshold: float,
    typical_acceptance_sampler_posterior_alpha: float,
    disable_logprobs: bool,
    disable_log_stats: bool,
    num_speculative_tokens: int,
) -> "SpecDecodeWorker":

    allow_zero_draft_token_step = True
    enable_lm_head_weight_load = False
    num_spec_prefill_steps = 1
    ngram_prompt_lookup_max = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
    ngram_prompt_lookup_min = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
    draft_model_config = draft_worker_kwargs["vllm_config"].model_config
    draft_parallel_config: ParallelConfig = draft_worker_kwargs[
        'vllm_config'].parallel_config
    if ngram_prompt_lookup_max > 0:
        draft_worker_kwargs[
            "device_type"] = scorer_worker.device_config.device.type
        proposer_worker = NGramWorker(**draft_worker_kwargs)
        proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                              ngram_prompt_lookup_max)
    else:
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = scorer_worker.parallel_config.tensor_parallel_size

        if draft_model_config.hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
        elif draft_model_config.hf_config.model_type == "medusa":
            proposer_worker = MedusaWorker(**draft_worker_kwargs)
        else:
            # Note: The current version of the MTP module does not support
            # the use of TP1DraftModelRunner
            if draft_tp == 1 and draft_model_config.hf_config.model_type != \
                    "deepseek_mtp":
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            else:
                if draft_model_config.hf_config.model_type == "eagle":
                    raise NotImplementedError(
                        f"{draft_model_config.hf_config.model_type} "
                        "does not support TP > 1 yet")

                allow_zero_draft_token_step = False

            # Load lm_head weight for eagle in init_device
            if draft_model_config.hf_config.model_type == "eagle":
                enable_lm_head_weight_load = True

            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
            if draft_model_config.hf_config.model_type == "deepseek_mtp":
                num_spec_prefill_steps = num_speculative_tokens

        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)

    logger.info("Configuring SpecDecodeWorker with proposer=%s",
                type(proposer_worker))

    spec_decode_sampler: SpecDecodeBaseSampler = None
    if draft_token_acceptance_method == "rejection_sampler":
        spec_decode_sampler = RejectionSampler()
    elif draft_token_acceptance_method == "typical_acceptance_sampler":
        spec_decode_sampler = TypicalAcceptanceSampler(
            posterior_threshold=\
            typical_acceptance_sampler_posterior_threshold,
            posterior_alpha=typical_acceptance_sampler_posterior_alpha,
        )
    logger.info(
        "[Speculative Decoding] Configuring"
        " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

    if not disable_mqa_scorer:
        if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "MQA is only available with flash attn backend.")

        if draft_model_config and \
                draft_model_config.max_model_len < \
                scorer_worker.model_config.max_model_len:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "draft model max_model_len is smaller than the target "
                        "model max_model_len.")

        if not scorer_worker.model_runner.model_config.enforce_eager:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "target model is not running in eager mode.")

    return SpecDecodeWorker(
        proposer_worker,
        scorer_worker,
        disable_mqa_scorer=disable_mqa_scorer,
        disable_logprobs=disable_logprobs,
        disable_log_stats=disable_log_stats,
        disable_by_batch_size=disable_by_batch_size,
        spec_decode_sampler=spec_decode_sampler,
        allow_zero_draft_token_step=allow_zero_draft_token_step,
        enable_lm_head_weight_load=enable_lm_head_weight_load,
        num_spec_prefill_steps=num_spec_prefill_steps)


SpecDecodeWorker.create_worker = classmethod(create_worker)
```
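The `draft_token_acceptance_method` branch above chooses between vLLM's `RejectionSampler` (the default) and `TypicalAcceptanceSampler`. A hedged sketch of selecting the latter from user code, assuming the v0.7-era engine-arg spellings (argument names and model names here are assumptions based on upstream vLLM documentation, not part of this PR):

```python
from vllm import LLM

# Hedged sketch: draft-model speculation with the typical acceptance
# sampler. Model names are illustrative; 0.09 and 0.3 mirror the
# upstream defaults for the two posterior parameters.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",             # target (scorer) model
    speculative_model="meta-llama/Llama-2-7b-chat-hf",  # draft (proposer) model
    num_speculative_tokens=4,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    typical_acceptance_sampler_posterior_threshold=0.09,
    typical_acceptance_sampler_posterior_alpha=0.3,
)
```

Note also the gating visible in `create_worker`: the MQA scorer stays enabled only when the target model runs in eager mode on the flash-attention backend, and the EAGLE path requires draft tensor parallelism of 1 (TP > 1 raises `NotImplementedError`), with the lm_head weight loaded during `init_device`.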
