diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 7a9ebe53b7..478505e533 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -18,3 +18,4 @@
 import vllm_ascend.patch.patch_metrics  # noqa
 import vllm_ascend.patch.patch_minicpm  # noqa
 import vllm_ascend.patch.patch_rejection_sampler  # noqa
+import vllm_ascend.patch.patch_spec_decode_worker  # noqa
diff --git a/vllm_ascend/patch/patch_spec_decode_worker.py b/vllm_ascend/patch/patch_spec_decode_worker.py
new file mode 100644
index 0000000000..223fa3d36f
--- /dev/null
+++ b/vllm_ascend/patch/patch_spec_decode_worker.py
@@ -0,0 +1,151 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Dict, Optional
+
+from vllm.config import ParallelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.spec_decode_base_sampler import \
+    SpecDecodeBaseSampler
+from vllm.model_executor.layers.typical_acceptance_sampler import \
+    TypicalAcceptanceSampler
+from vllm.spec_decode.medusa_worker import MedusaWorker
+from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.ngram_worker import NGramWorker
+from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
+from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
+from vllm.worker.worker_base import WorkerBase
+
+from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
+
+logger = init_logger(__name__)
+
+
+def create_worker(
+    cls,
+    scorer_worker: WorkerBase,
+    draft_worker_kwargs: Dict[str, Any],
+    disable_mqa_scorer: bool,
+    disable_by_batch_size: Optional[int],
+    draft_token_acceptance_method: str,
+    typical_acceptance_sampler_posterior_threshold: float,
+    typical_acceptance_sampler_posterior_alpha: float,
+    disable_logprobs: bool,
+    disable_log_stats: bool,
+    num_speculative_tokens: int,
+) -> "SpecDecodeWorker":
+
+    allow_zero_draft_token_step = True
+    enable_lm_head_weight_load = False
+    num_spec_prefill_steps = 1
+    ngram_prompt_lookup_max = (
+        draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+    ngram_prompt_lookup_min = (
+        draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
+    draft_model_config = draft_worker_kwargs["vllm_config"].model_config
+    draft_parallel_config: ParallelConfig = draft_worker_kwargs[
+        'vllm_config'].parallel_config
+    if ngram_prompt_lookup_max > 0:
+        draft_worker_kwargs[
+            "device_type"] = scorer_worker.device_config.device.type
+        proposer_worker = NGramWorker(**draft_worker_kwargs)
+        proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
+                                              ngram_prompt_lookup_max)
+    else:
+        draft_tp = draft_parallel_config.tensor_parallel_size
+        target_tp = scorer_worker.parallel_config.tensor_parallel_size
+
+        if draft_model_config.hf_config.model_type == "mlp_speculator":
+            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+        elif draft_model_config.hf_config.model_type == "medusa":
+            proposer_worker = MedusaWorker(**draft_worker_kwargs)
+        else:
+            # Note: The current version of the MTP module does not support
+            # the use of TP1DraftModelRunner
+            if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
+                    "deepseek_mtp":
+                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
+            else:
+                if draft_model_config.hf_config.model_type == "eagle":
+                    raise NotImplementedError(
+                        f"{draft_model_config.hf_config.model_type} "
+                        "does not support TP > 1 yet")
+
+                allow_zero_draft_token_step = False
+
+            # Load lm_head weight for eagle in init_device
+            if draft_model_config.hf_config.model_type == "eagle":
+                enable_lm_head_weight_load = True
+
+            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_model_config.hf_config.model_type == "deepseek_mtp":
+                num_spec_prefill_steps = num_speculative_tokens
+
+        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
+            proposer_worker, draft_tp, target_tp)
+
+    logger.info("Configuring SpecDecodeWorker with proposer=%s",
+                type(proposer_worker))
+
+    spec_decode_sampler: SpecDecodeBaseSampler = None
+    if draft_token_acceptance_method == "rejection_sampler":
+        spec_decode_sampler = RejectionSampler()
+    elif draft_token_acceptance_method == "typical_acceptance_sampler":
+        spec_decode_sampler = TypicalAcceptanceSampler(
+            posterior_threshold=\
+            typical_acceptance_sampler_posterior_threshold,
+            posterior_alpha=typical_acceptance_sampler_posterior_alpha,
+        )
+    logger.info(
+        "[Speculative Decoding] Configuring"
+        " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
+
+    if not disable_mqa_scorer:
+        if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
+            disable_mqa_scorer = True
+            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
+                        "MQA is only available with flash attn backend.")
+
+        if draft_model_config and \
+            draft_model_config.max_model_len < \
+                scorer_worker.model_config.max_model_len:
+            disable_mqa_scorer = True
+            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
+                        "draft model max_model_len is smaller than the target "
+                        "model max_model_len.")
+
+        if not scorer_worker.model_runner.model_config.enforce_eager:
+            disable_mqa_scorer = True
+            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
+                        "target model is not running in eager mode.")
+
+    return SpecDecodeWorker(
+        proposer_worker,
+        scorer_worker,
+        disable_mqa_scorer=disable_mqa_scorer,
+        disable_logprobs=disable_logprobs,
+        disable_log_stats=disable_log_stats,
+        disable_by_batch_size=disable_by_batch_size,
+        spec_decode_sampler=spec_decode_sampler,
+        allow_zero_draft_token_step=allow_zero_draft_token_step,
+        enable_lm_head_weight_load=enable_lm_head_weight_load,
+        num_spec_prefill_steps=num_spec_prefill_steps)
+
+
+SpecDecodeWorker.create_worker = classmethod(create_worker)
diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py
new file mode 100644
index 0000000000..f5c76b7d24
--- /dev/null
+++ b/vllm_ascend/worker/draft_model_runner.py
@@ -0,0 +1,315 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+import torch
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.multimodal import MultiModalKwargs
+from vllm.sequence import ExecuteModelRequest, IntermediateTensors
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)
+
+from vllm_ascend.attention import AscendMetadata as FlashAttentionMetadata
+
+logger = init_logger(__name__)
+
+# A flag to enable debug prints for the updated input tensors
+# before each step.
+debug_advance_input = False
+# A flag to allow GPU advance step for draft model runner.
+# Set to False for debugging.
+allow_gpu_advance_step = True
+
+
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
+    """Specialized model runner for the speculative decoding draft model.
+    Since the draft model always executes k forward passes consecutively to
+    generate k speculative tokens in a single speculative decoding step,
+    we can get rid of most CPU-GPU synchronization and data transfer
+    overheads by keeping model input and output tensors on the GPU all the
+    time.
+
+    TODOs:
+    1. Currently supports only flash-attn; add support for other
+       attn_backends.
+    2. Support TP > 1 (this requires some design work because we do not
+       expect any broadcasting inside execute_model).
+    """
+
+    def __init__(self, model_runner: ModelRunnerBase):
+        if hasattr(
+                model_runner,
+                "return_hidden_states") and model_runner.return_hidden_states:
+            raise ValueError(
+                "return_hidden_states is not supported for TP1DraftModelRunner."
+            )
+        super().__init__(model_runner)
+
+        self.indices_of_seq_with_bonus_tokens = None
+
+    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
+                                  num_queries):
+
+        assert sampling_metadata.num_prompts == 0
+        assert len(sampling_metadata.seq_groups) == num_queries
+        assert sampling_metadata.selected_token_indices.shape == (
+            num_queries, )
+        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed  # noqa: E501
+
+        # Verify that all sequences are decodes
+        for i in range(num_queries):
+            seq_group = sampling_metadata.seq_groups[i]
+
+            assert seq_group.is_prompt is False  # No prompt
+            assert seq_group.prompt_logprob_indices == []  # No prompt
+            assert seq_group.sample_indices == [i]  # Simple
+
+    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+                          last_output: SamplerOutput) -> ModelRunnerInputBase:
+        # Currently, we expect "decode mode" only
+        assert not model_input.is_prompt
+
+        # Get num_seqs
+        num_seqs = len(model_input.seq_lens)
+        num_queries = len(model_input.query_lens)
+
+        # Get output tokens GPU tensor
+        sampled_token_ids = last_output.sampled_token_ids
+        assert sampled_token_ids is not None
+
+        # Update attn_metadata
+        attn_metadata = model_input.attn_metadata
+        assert isinstance(attn_metadata, FlashAttentionMetadata)
+
+        attn_metadata.advance_step(model_input, sampled_token_ids,
+                                   self.block_size, num_seqs, num_queries)
+
+        # Update sampling_metadata
+        sampling_metadata = model_input.sampling_metadata
+        self._update_sampling_metadata(sampling_metadata, num_seqs,
+                                       num_queries)
+
+        # Create new input
+        new_model_input = self._model_input_cls(
+            input_tokens=model_input.input_tokens,
+            input_positions=model_input.input_positions,
+            attn_metadata=attn_metadata,
+            seq_lens=attn_metadata.seq_lens,
+            query_lens=model_input.query_lens,
+            lora_mapping=model_input.lora_mapping,
+            lora_requests=model_input.lora_requests,
+            multi_modal_kwargs=model_input.multi_modal_kwargs,
+            sampling_metadata=model_input.sampling_metadata,
+            is_prompt=False,
+        )
+
+        # Ensure we skip CPU samples
+        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
+        # We can reuse sampling tensors since every decode iteration is the same
+        new_model_input.sampling_metadata.reuse_sampling_tensors = True
+
+        if debug_advance_input:
+            logger.debug("NEW INPUT: ")
+            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
+            logger.debug("  input_positions = %s",
+                         new_model_input.input_positions)
+            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
+            logger.debug("  query_lens = %d", new_model_input.query_lens)
+            logger.debug("  attn_metadata:")
+            logger.debug("    seq_lens_tensor: %s",
+                         attn_metadata.seq_lens_tensor)
+            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
+            logger.debug("    block_tables: %s", attn_metadata.block_tables)
+
+        return new_model_input
+
+    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
+        """Determines if draft_model_runner GPU multi-step can be used.
+        Currently required conditions are:
+            1. Only decodes
+            2. Only flash-attn
+            3. No LORA
+            4. No prompt_adapter_config
+        """
+        if not allow_gpu_advance_step:
+            return False
+
+        # We allow multi-step GPU only in decode mode
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            if seq_group.is_prompt:
+                return False
+
+        # TODO: Add support for other attn backends
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+            return False
+
+        # TODO: Add support for LORA
+        if self.lora_config:
+            return False
+
+        # TODO: Add soft-tuning prompt adapter support
+        return not self.prompt_adapter_config
+
+    def set_indices_of_seq_with_bonus_tokens(self,
+                                             indices_of_seq_with_bonus_tokens):
+        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelRunnerInputBase,
+        kv_caches: List[torch.Tensor],
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+        **kwargs,
+    ) -> Optional[List[SamplerOutput]]:
+        """Executes num_steps forward passes with advancement of input tensors
+        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
+
+        Optimizations used:
+            1. Input tensors are updated on the GPU directly
+            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
+               them since we do batch expansion later that uses GPU outputs)
+            3. Reuses sampling tensors (since we run only decodes and they have
+               a repeating sampling logic)
+        """
+
+        # When num_steps == 1, we execute the fallback here for the GPU
+        # advance_step, which runs prepare_inputs on CPU and for each spec
+        # iteration invokes this function only once
+        # (Look at multi-step-worker code)
+        is_fallback = num_steps == 1
+        if not is_fallback:
+            # Since we do not broadcast data inside execute_model anymore,
+            # we need to figure out the best way to support TP > 1 in this
+            # case, because we will at least need to broadcast the sampled
+            # tokens to all workers.
+            if not self.is_driver_worker:
+                raise ValueError("TP1DraftModelRunner only supports TP=1.")
+
+            # Sanity
+            if self.lora_config is not None:
+                raise ValueError("TP1DraftModelRunner has no support for LORA")
+            if self.prompt_adapter_config is not None:
+                raise ValueError("TP1DraftModelRunner has no support for "
+                                 "prompt_adapter_config")
+            if model_input.multi_modal_kwargs:
+                raise ValueError(
+                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
+                )
+        else:
+            if self.lora_config:
+                assert model_input.lora_requests is not None
+                assert model_input.lora_mapping is not None
+                self.set_active_loras(model_input.lora_requests,
+                                      model_input.lora_mapping)
+
+            if self.prompt_adapter_config:
+                assert model_input.prompt_adapter_requests is not None
+                assert model_input.prompt_adapter_mapping is not None
+                self.set_active_prompt_adapters(
+                    model_input.prompt_adapter_requests,
+                    model_input.prompt_adapter_mapping)
+
+        self.attn_state.begin_forward(model_input)
+
+        # Detect exec mode
+        assert model_input.attn_metadata is not None
+        if model_input.attn_metadata.num_prefills > 0:
+            # In this case, execute_model(..) was called directly
+            if num_steps > 1:
+                raise ValueError(
+                    "execute_model(..) of draft_model_runner can be called "
+                    "directly only with a single-step prefill")
+        else:
+            # We can skip CPU samples for spec token generation.
+            # (We do allow CPU samples for num_steps == 1 to support the
+            # fallback case, where supports_gpu_multi_step(..) does not pass)
+            model_input.sampling_metadata.skip_sampler_cpu_output = (
+                not is_fallback)
+
+        model_executable = self.model
+        hidden_states = previous_hidden_states
+
+        outputs: List[SamplerOutput] = []
+        for step in range(num_steps):
+            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+
+            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
+                if previous_hidden_states is not None else {}
+
+            compute_logits_kwargs = {}
+            # Run model
+            if hasattr(self.model.config, "num_nextn_predict_layers"):
+                # For DeepSeek MTP only: use the corresponding layer for
+                # each step
+                spec_step_idx = kwargs.get("spec_step_idx", step)
+                model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
+            with set_forward_context(model_input.attn_metadata,
+                                     self.vllm_config):
+                hidden_states = model_executable(
+                    input_ids=model_input.input_tokens,
+                    positions=model_input.input_positions,
+                    kv_caches=kv_caches,
+                    attn_metadata=model_input.attn_metadata,
+                    intermediate_tensors=intermediate_tensors,
+                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
+                                                 device=self.device),
+                    **model_execute_kwargs,
+                )
+
+            # Compute the logits.
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata,
+                                               **compute_logits_kwargs)
+            if not self.is_driver_worker:
+                return []
+            # Sample the next token.
+            output = self.model.sample(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            outputs.append(output)
+
+            if model_input.attn_metadata.num_prefills == 0 \
+                    and self.indices_of_seq_with_bonus_tokens is not None:
+                assert output.sampled_token_ids is not None
+                # output.sampled_token_ids should be of shape (num_seqs, 1)
+                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
+                assert num_tokens_per_seq == 1
+                count = 0
+                for i in range(nums_seqs):
+                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
+                        count]
+                    if i != bonus_seq_idx:
+                        # The following might cause a cpu->gpu sync.
+                        # However, the performance impact is negligible as we
+                        # benchmarked on H100.
+                        output.sampled_token_ids[
+                            i, :] = model_input.input_tokens[bonus_seq_idx]
+                    else:
+                        count += 1
+
+            # Prepare inputs for the next step
+            if step != num_steps - 1:
+                model_input = self._gpu_advance_step(model_input, outputs[-1])
+
+        return outputs
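
Usage note for reviewers (not part of the patch): a minimal sketch of how the patched SpecDecodeWorker.create_worker path would be exercised once the vllm_ascend patches are imported. It assumes a vLLM release that still accepts the V0-style speculative-decoding engine arguments used below; the model paths and values are placeholders, not taken from this PR.

# Minimal sketch, assuming the vllm-ascend plugin is installed so that its
# patches (including the create_worker override above) are applied, and that
# this vLLM version still exposes the V0 `speculative_model` /
# `num_speculative_tokens` engine arguments. Model paths are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/target-model",             # placeholder target model
    speculative_model="path/to/draft-model",  # placeholder draft model
    num_speculative_tokens=4,                 # k drafted tokens per step
    enforce_eager=True,  # keeps the MQA scorer enabled (see the check above)
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)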