1 change: 1 addition & 0 deletions vllm_ascend/patch/__init__.py
@@ -17,5 +17,6 @@
 import vllm_ascend.patch.patch_cache_dtype  # noqa
 import vllm_ascend.patch.patch_metrics  # noqa
 import vllm_ascend.patch.patch_minicpm  # noqa
+import vllm_ascend.patch.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.patch_rejection_sampler  # noqa
 import vllm_ascend.patch.patch_spec_decode_worker  # noqa
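For context, this added import has a side effect: loading `vllm_ascend.patch.patch_multi_step_worker` rebinds a method on an upstream vLLM class at import time. Below is a minimal, self-contained sketch of that monkey-patching pattern; `Upstream` and `_patched_step` are hypothetical stand-ins, not names from vLLM or vllm-ascend.

```python
# Minimal sketch of the import-time monkey-patching pattern used by the
# vllm_ascend.patch package. All names here are illustrative.


class Upstream:

    def step(self) -> str:
        return "original behavior"


def _patched_step(self) -> str:
    return "ascend-specific behavior"


# Assigning onto the class rebinds the method for every existing and future
# instance; importing the patch module is what triggers this assignment.
Upstream.step = _patched_step

assert Upstream().step() == "ascend-specific behavior"
```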
87 changes: 87 additions & 0 deletions vllm_ascend/patch/patch_multi_step_worker.py
@@ -0,0 +1,87 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Set, Tuple

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns a list of
    sampler outputs, one per model forward pass, along with an indicator of
    whether the torch tensors in the sampler output need to be transposed in
    the later sampler_output_to_torch logic.

    For the multi-step worker, this indicator shall be True.
    """
    self._raise_if_unsupported(execute_model_req)
    # Expand the batch for sequences with a bonus token.
    # Perform a forward pass on the expanded batch and filter the
    # response to retain only the original sequences' responses.
    expanded_request, indices_of_seq_with_bonus_tokens = \
        self._expand_execute_model_request(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    # Run the model sample_len times.
    model_outputs: List[SamplerOutput] = []

    # TODO: supports_gpu_multi_step is currently False on Ascend, so the
    # CPU-prepared branch below is the one taken today.
    if isinstance(self.model_runner, TP1DraftModelRunner) and \
            self.model_runner.supports_gpu_multi_step(expanded_request):
        # Here we run the draft_model_runner with multi-step prepare
        # on the GPU directly.
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(
            indices_of_seq_with_bonus_tokens)
        model_outputs = self.execute_model(
            execute_model_req=expanded_request)
    else:
        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and the other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..).
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]

            self._append_new_tokens(model_output,
                                    expanded_request.seq_group_metadata_list,
                                    indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

    # Move the indices to the device to avoid a stream sync.
    indices_of_seq_with_bonus_tokens = torch.tensor(
        indices_of_seq_with_bonus_tokens, device=self.device)
    filtered_model_outputs = self._filter_model_output(
        model_outputs, indices_of_seq_with_bonus_tokens)
    return filtered_model_outputs, True


# Rebind the upstream method, keeping the inference-mode wrapping so gradients
# stay disabled while drafting.
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
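To make the expand-and-filter flow above easier to follow, here is a conceptual sketch (not the vLLM implementation) of how the outputs of an expanded batch can be reduced back to the original sequences, with the index tensor built on the target device once so the per-step loop avoids host-device syncs. `filter_rows` and its arguments are illustrative names, not vLLM APIs.

```python
# Conceptual sketch of the "expand batch, then keep only the original rows"
# step. All names are illustrative.
from typing import List

import torch


def filter_rows(expanded_logits: torch.Tensor,
                original_row_indices: List[int],
                device: torch.device) -> torch.Tensor:
    # Build the index tensor on the target device up front, mirroring the
    # "move the indices to the device to avoid a stream sync" comment above.
    idx = torch.tensor(original_row_indices, device=device)
    return expanded_logits.index_select(0, idx)


# Example: a 5-row expanded batch where rows 0, 2 and 4 belong to the
# original sequences (the others were added for bonus tokens).
logits = torch.randn(5, 8)
kept = filter_rows(logits, [0, 2, 4], torch.device("cpu"))
assert kept.shape == (3, 8)
```

The closing assignment wraps the function with `torch.inference_mode()`, which is equivalent to applying it as a decorator at definition time, so the patched method runs with gradients disabled just like an `@torch.inference_mode()`-decorated function would.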
13 changes: 10 additions & 3 deletions vllm_ascend/worker/draft_model_runner.py
@@ -113,8 +113,10 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
 attn_metadata=attn_metadata,
 seq_lens=attn_metadata.seq_lens,
 query_lens=model_input.query_lens,
-lora_mapping=model_input.lora_mapping,
-lora_requests=model_input.lora_requests,
+# Note: if vllm_ascend supports LoRA, we need to add back
+# the following two params.
+# lora_mapping=model_input.lora_mapping,
+# lora_requests=model_input.lora_requests,
 multi_modal_kwargs=model_input.multi_modal_kwargs,
 sampling_metadata=model_input.sampling_metadata,
 is_prompt=False,
@@ -156,7 +158,8 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
 if seq_group.is_prompt:
 return False
 
-# TODO: Add support for other attn backends
+# TODO: Add support for ASCEND once the outer multi_step_worker
+# can work correctly.
 if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
 return False
 
@@ -266,6 +269,10 @@ def execute_model(
 compute_logits_kwargs["spec_step_idx"] = spec_step_idx
 with set_forward_context(model_input.attn_metadata,
 self.vllm_config):
+
+if model_input.attn_metadata is not None:
+    model_input.attn_metadata.input_positions = model_input.input_positions
+
 hidden_states = model_executable(
 input_ids=model_input.input_tokens,
 positions=model_input.input_positions,
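The last hunk copies `input_positions` onto the attention metadata right before the forward pass. Below is a minimal sketch of that pattern, under the assumption that the Ascend attention path reads positions from the metadata object rather than from the `positions` argument; `FakeAttnMetadata` and `prepare_positions` are hypothetical stand-ins, not the real vLLM/vllm-ascend types.

```python
# Sketch of mirroring positions onto the attention metadata before a forward
# pass. All names here are illustrative stand-ins.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class FakeAttnMetadata:
    seq_lens: List[int]
    input_positions: Optional[torch.Tensor] = None


def prepare_positions(input_positions: torch.Tensor,
                      attn_metadata: Optional[FakeAttnMetadata]) -> None:
    # Keep the metadata's view of the positions in sync with the tensor the
    # model is about to consume, so a backend that only inspects the metadata
    # sees the advanced positions on every multi-step iteration.
    if attn_metadata is not None:
        attn_metadata.input_positions = input_positions


meta = FakeAttnMetadata(seq_lens=[4, 7])
prepare_positions(torch.tensor([3, 6]), meta)
assert torch.equal(meta.input_positions, torch.tensor([3, 6]))
```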