#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
    SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
    TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner

logger = init_logger(__name__)


def create_worker(
    cls,
    scorer_worker: WorkerBase,
    draft_worker_kwargs: Dict[str, Any],
    disable_mqa_scorer: bool,
    disable_by_batch_size: Optional[int],
    draft_token_acceptance_method: str,
    typical_acceptance_sampler_posterior_threshold: float,
    typical_acceptance_sampler_posterior_alpha: float,
    disable_logprobs: bool,
    disable_log_stats: bool,
    num_speculative_tokens: int,
) -> "SpecDecodeWorker":
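    """Create a SpecDecodeWorker configured for Ascend NPUs.

    Selects a proposer worker based on the draft model configuration
    (NGramWorker, MLPSpeculatorWorker, MedusaWorker, or MultiStepWorker),
    optionally wraps it for a smaller draft tensor-parallel size, picks the
    draft-token acceptance sampler, and decides whether the MQA scorer can
    be used before assembling the final SpecDecodeWorker.
    """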

    allow_zero_draft_token_step = True
    enable_lm_head_weight_load = False
    num_spec_prefill_steps = 1
    ngram_prompt_lookup_max = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
    ngram_prompt_lookup_min = (
        draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
    draft_model_config = draft_worker_kwargs["vllm_config"].model_config
    draft_parallel_config: ParallelConfig = draft_worker_kwargs[
        'vllm_config'].parallel_config
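    # An n-gram proposer needs no draft model: NGramWorker proposes draft
    # tokens by looking up matching n-grams in the prompt, so it is selected
    # whenever a prompt-lookup window is configured.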
    if ngram_prompt_lookup_max > 0:
        draft_worker_kwargs[
            "device_type"] = scorer_worker.device_config.device.type
        proposer_worker = NGramWorker(**draft_worker_kwargs)
        proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                              ngram_prompt_lookup_max)
    else:
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = scorer_worker.parallel_config.tensor_parallel_size

        if draft_model_config.hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
        elif draft_model_config.hf_config.model_type == "medusa":
            proposer_worker = MedusaWorker(**draft_worker_kwargs)
        else:
            # Note: the current version of the MTP module does not support
            # the use of TP1DraftModelRunner.
            if draft_tp == 1 and draft_model_config.hf_config.model_type != \
                    "deepseek_mtp":
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            else:
                if draft_model_config.hf_config.model_type == "eagle":
                    raise NotImplementedError(
                        f"{draft_model_config.hf_config.model_type} "
                        "does not support TP > 1 yet")

                allow_zero_draft_token_step = False

            # Load the lm_head weight for eagle models in init_device.
            if draft_model_config.hf_config.model_type == "eagle":
                enable_lm_head_weight_load = True

            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
            if draft_model_config.hf_config.model_type == "deepseek_mtp":
                num_spec_prefill_steps = num_speculative_tokens
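
        # When the draft model runs with fewer tensor-parallel ranks than
        # the target model, wrap the proposer so it executes only on that
        # smaller subset of ranks.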
        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)

    logger.info("Configuring SpecDecodeWorker with proposer=%s",
                type(proposer_worker))

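    # Choose how draft tokens are accepted: rejection sampling preserves the
    # target model's output distribution, while typical acceptance sampling
    # trades some fidelity for a higher acceptance rate.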
    spec_decode_sampler: Optional[SpecDecodeBaseSampler] = None
    if draft_token_acceptance_method == "rejection_sampler":
        spec_decode_sampler = RejectionSampler()
    elif draft_token_acceptance_method == "typical_acceptance_sampler":
        spec_decode_sampler = TypicalAcceptanceSampler(
            posterior_threshold=typical_acceptance_sampler_posterior_threshold,
            posterior_alpha=typical_acceptance_sampler_posterior_alpha,
        )
    logger.info(
        "[Speculative Decoding] Configuring"
        " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

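    # The MQA scorer is only usable under specific conditions; when any check
    # below fails, fall back to the default batch-expansion scorer.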
    if not disable_mqa_scorer:
        if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as "
                        "MQA is only available with the flash attn backend.")

        if draft_model_config and \
                draft_model_config.max_model_len < \
                scorer_worker.model_config.max_model_len:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "draft model max_model_len is smaller than the target "
                        "model max_model_len.")

        if not scorer_worker.model_runner.model_config.enforce_eager:
            disable_mqa_scorer = True
            logger.info("[Speculative Decoding] Disabling MQA scorer as the "
                        "target model is not running in eager mode.")

    return SpecDecodeWorker(
        proposer_worker,
        scorer_worker,
        disable_mqa_scorer=disable_mqa_scorer,
        disable_logprobs=disable_logprobs,
        disable_log_stats=disable_log_stats,
        disable_by_batch_size=disable_by_batch_size,
        spec_decode_sampler=spec_decode_sampler,
        allow_zero_draft_token_step=allow_zero_draft_token_step,
        enable_lm_head_weight_load=enable_lm_head_weight_load,
        num_spec_prefill_steps=num_spec_prefill_steps)


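# Monkey-patch vLLM's SpecDecodeWorker factory with this Ascend-aware
# version so that speculative decoding uses vllm_ascend's
# TP1DraftModelRunner for the draft model.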
SpecDecodeWorker.create_worker = classmethod(create_worker)