Commit e835d88
spec decode: MultiStepWorker now fully supports TP1DraftModelRunner; the draft_model_runner runs with multi-step prepare directly on the NPU and can use MLA.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent 12390af commit e835d88
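
The support is wired in through the project's existing patch mechanism: importing vllm_ascend.patch.patch_multi_step_worker replaces MultiStepWorker.sampler_output at import time, wrapped in torch.inference_mode(). Below is a minimal, self-contained sketch of that pattern; the Worker class and method bodies are hypothetical stand-ins, only the assignment style mirrors the real patch shown in the diff.

import torch


class Worker:
    # Hypothetical stand-in for vllm.spec_decode.multi_step_worker.MultiStepWorker.
    def sampler_output(self) -> str:
        return "original implementation"


def patched_sampler_output(self) -> str:
    # Replacement body; the real patch re-implements the multi-step draft loop
    # so TP1DraftModelRunner can prepare every step on the NPU.
    return "patched implementation"


# Importing the patch module performs this assignment as a side effect, mirroring:
# MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
Worker.sampler_output = torch.inference_mode()(patched_sampler_output)

print(Worker().sampler_output())  # -> "patched implementation"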

File tree

3 files changed: +97 -3 lines changed
vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,5 +17,6 @@
 import vllm_ascend.patch.patch_cache_dtype  # noqa
 import vllm_ascend.patch.patch_metrics  # noqa
 import vllm_ascend.patch.patch_minicpm  # noqa
+import vllm_ascend.patch.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.patch_rejection_sampler  # noqa
 import vllm_ascend.patch.patch_spec_decode_worker  # noqa

vllm_ascend/patch/patch_multi_step_worker.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Set, Tuple
+
+import torch
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+
+from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
+
+
+def sampler_output(
+    self,
+    execute_model_req: ExecuteModelRequest,
+    sample_len: int,
+    seq_ids_with_bonus_token_in_last_step: Set[int],
+) -> Tuple[List[SamplerOutput], bool]:
+    """Run the model forward pass sample_len times. Returns the list of
+    sampler output, one per model forward pass, along with indicator of
+    whether torch tensor in sampler output need to be transposed in latter
+    sampler_output_to_torch logic.
+
+    For multi step worker, this indicator shall be True.
+    """
+    self._raise_if_unsupported(execute_model_req)
+    # Expand the batch for sequences with a bonus token.
+    # Perform a forward pass on the expanded batch and filter the
+    # response to retain only the original sequences' responses.
+    expanded_request, indices_of_seq_with_bonus_tokens =\
+        self._expand_execute_model_request(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    # Run model sample_len times.
+    model_outputs: List[SamplerOutput] = []
+
+    if isinstance(self.model_runner, TP1DraftModelRunner) and \
+            self.model_runner.supports_gpu_multi_step(expanded_request):
+        # Here we run the draft_model_runner with multi-step prepare
+        # on the GPU directly
+        expanded_request.num_steps = sample_len
+        self.model_runner.set_indices_of_seq_with_bonus_tokens(
+            indices_of_seq_with_bonus_tokens)
+        model_outputs = self.execute_model(
+            execute_model_req=expanded_request)
+    else:
+        # Here we run multi-step directly, with every step prepared
+        # on the CPU.
+        # TODO: Remove this branch once DraftModelRunner supports TP>1
+        # and other restrictions that are part of DraftModelRunner's
+        # supports_gpu_multi_step(..)
+        for _ in range(sample_len):
+            model_output: List[SamplerOutput] = self.worker.execute_model(
+                execute_model_req=expanded_request)
+            assert (len(model_output) == 1
+                    ), "composing multistep workers not supported"
+            model_output = model_output[0]
+
+            self._append_new_tokens(model_output,
+                                    expanded_request.seq_group_metadata_list,
+                                    indices_of_seq_with_bonus_tokens)
+            model_outputs.append(model_output)
+
+    # move indices to device to avoid stream sync
+    indices_of_seq_with_bonus_tokens = torch.tensor(
+        indices_of_seq_with_bonus_tokens, device=self.device)
+    filtered_model_outputs = self._filter_model_output(
+        model_outputs, indices_of_seq_with_bonus_tokens)
+    return filtered_model_outputs, True
+
+
+MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
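
For context, the expand/filter pair around the draft steps works roughly as follows. The snippet is a simplified, hypothetical illustration using plain Python lists in place of SequenceGroupMetadata and SamplerOutput, not the actual vLLM helpers.

# Simplified, hypothetical illustration of the expand/filter flow around
# bonus tokens (plain lists stand in for SequenceGroupMetadata / SamplerOutput).
seq_ids = [0, 1, 2]
seq_ids_with_bonus = {1}

# "Expand": a sequence that received a bonus token in the last step is run
# twice, once without the bonus token and once with it; the index of the
# "with bonus" (original) entry is recorded so it can be recovered later.
expanded = []
indices_of_original = []
for seq_id in seq_ids:
    if seq_id in seq_ids_with_bonus:
        expanded.append((seq_id, "without_bonus"))
    expanded.append((seq_id, "original"))
    indices_of_original.append(len(expanded) - 1)

# "Filter": after the draft steps run on the expanded batch, only the recorded
# entries are kept, so the caller sees exactly one output per input sequence.
filtered = [expanded[i] for i in indices_of_original]
assert len(filtered) == len(seq_ids)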

vllm_ascend/worker/draft_model_runner.py

Lines changed: 10 additions & 3 deletions
@@ -113,8 +113,10 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
             attn_metadata=attn_metadata,
             seq_lens=attn_metadata.seq_lens,
             query_lens=model_input.query_lens,
-            lora_mapping=model_input.lora_mapping,
-            lora_requests=model_input.lora_requests,
+            # Notes: If vllm_ascend supports LORA, we need to
+            # add the following two params.
+            # lora_mapping=model_input.lora_mapping,
+            # lora_requests=model_input.lora_requests,
             multi_modal_kwargs=model_input.multi_modal_kwargs,
             sampling_metadata=model_input.sampling_metadata,
             is_prompt=False,
@@ -157,7 +159,8 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
             return False

         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA",
+                                                "ASCEND"):
             return False

         # TODO: Add support for LORA
@@ -266,6 +269,10 @@ def execute_model(
                 compute_logits_kwargs["spec_step_idx"] = spec_step_idx
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config):
+
+                if model_input.attn_metadata is not None:
+                    model_input.attn_metadata.input_positions = model_input.input_positions
+
                 hidden_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,