From 8f9b832de0bbb2a5936388130cd3d513dfc0d0c9 Mon Sep 17 00:00:00 2001
From: mengwei805
Date: Thu, 10 Apr 2025 21:18:04 +0800
Subject: [PATCH] [v0.7.3] Support MTP in DeepSeek w8a8 quant model

Signed-off-by: mengwei805
---
 vllm_ascend/models/__init__.py     |   4 +
 vllm_ascend/models/deepseek_mtp.py | 181 +++++++++++++++++++++++++++++
 2 files changed, 185 insertions(+)
 create mode 100644 vllm_ascend/models/deepseek_mtp.py

diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index 0295c01a5e..d68eb8f8d5 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -13,3 +13,7 @@ def register_model():
     ModelRegistry.register_model(
         "DeepseekV3ForCausalLM",
         "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "DeepSeekMTPModel",
+        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py
new file mode 100644
index 0000000000..bb5c171253
--- /dev/null
+++ b/vllm_ascend/models/deepseek_mtp.py
@@ -0,0 +1,181 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Adapted from vllm/model_executor/models/deepseek_mtp.py
+# Copyright 2023 The vLLM team.
+#
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
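+#
+# How this module is wired in: the __init__.py hunk above registers the
+# architecture name "DeepSeekMTPModel" with vLLM's ModelRegistry, so a
+# checkpoint whose config reports that architecture loads CustomDeepSeekMTP
+# below instead of the upstream DeepSeekMTP. A minimal sketch of the lookup
+# (illustrative only; the resolver exists in vLLM, but its exact signature
+# can vary across versions):
+#
+#   from vllm import ModelRegistry
+#   model_cls, arch = ModelRegistry.resolve_model_cls(["DeepSeekMTPModel"])
+#   # -> (CustomDeepSeekMTP, "DeepSeekMTPModel")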
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from vllm.model_executor.models.deepseek_mtp import (
+    DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
+    SharedHead)
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+from .deepseek_v2 import CustomDeepseekV2DecoderLayer
+
+
+class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        model_config: ModelConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+
+        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                 config.hidden_size,
+                                 bias=False)
+        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
+                                                      model_config,
+                                                      cache_config,
+                                                      quant_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        previous_hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_index: int = 0,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        assert inputs_embeds is not None
+        # mask the inputs at position 0, as they are not needed by MTP
+        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
+                                    torch.zeros_like(inputs_embeds),
+                                    inputs_embeds)
+        inputs_embeds = self.enorm(inputs_embeds)
+        previous_hidden_states = self.hnorm(previous_hidden_states)
+
+        hidden_states = self.eh_proj(
+            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+        hidden_states, residual = self.mtp_block(positions=positions,
+                                                 hidden_states=hidden_states,
+                                                 kv_cache=kv_cache,
+                                                 attn_metadata=attn_metadata,
+                                                 residual=None)
+        hidden_states = residual + hidden_states
+        return hidden_states
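+
+
+# Dataflow of the forward pass above, following DeepSeek's multi-token
+# prediction scheme. Illustrative sketch only ([T, H] = num_tokens x
+# hidden_size; names refer to the modules defined in __init__):
+#
+#   e = enorm(embed_tokens(input_ids))      # [T, H]; zeroed where position == 0
+#   h = hnorm(previous_hidden_states)       # [T, H]; from the target model
+#   x = eh_proj(torch.cat([e, h], dim=-1))  # [T, 2H] -> [T, H]
+#   out, res = mtp_block(x, ...)            # one full DeepseekV2 decoder layer
+#   return out + res                        # later fed to shared_head for
+#                                           # draft-token logits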
+
+
+class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        config = vllm_config.model_config.hf_config
+        self.mtp_start_layer_idx = config.num_hidden_layers
+        self.num_mtp_layers = config.num_nextn_predict_layers
+        # keyed by the exact layer index used in the checkpoint weights
+        self.layers = torch.nn.ModuleDict({
+            str(idx): CustomDeepSeekMultiTokenPredictorLayer(
+                config,
+                f"{prefix}.layers.{idx}",
+                model_config=vllm_config.model_config,
+                cache_config=vllm_config.cache_config,
+                quant_config=vllm_config.quant_config,
+            )
+            for idx in range(self.mtp_start_layer_idx,
+                             self.mtp_start_layer_idx + self.num_mtp_layers)
+        })
+
+        # Note: torch._dynamo cannot trace the str(idx) ModuleDict lookup
+        # (torch._dynamo.exc.Unsupported: builtin: str), so cache the layers
+        # in a plain list for index-based access in forward().
+        self.layers_list = [
+            self.layers[str(idx)]
+            for idx in range(self.mtp_start_layer_idx,
+                             self.mtp_start_layer_idx + self.num_mtp_layers)
+        ]
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        previous_hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers_list[current_step_idx](
+            input_ids,
+            positions,
+            kv_caches[current_step_idx],
+            attn_metadata,
+            previous_hidden_states,
+            inputs_embeds,
+            current_step_idx,
+        )
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers_list[current_step_idx]
+        logits = self.logits_processor(mtp_layer.shared_head.head,
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
+        return logits
+
+
+class CustomDeepSeekMTP(DeepSeekMTP):
+    # NOTE 1. The MTP layer of the w8a8-quantized DeepSeek model runs
+    # unquantized (float) on the NPU.
+    # NOTE 2. The quant description file generated by the current msmodelslim
+    # tool has no MTP layer info; add it manually and set the value to FLOAT.
+    packed_modules_mapping = {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config.model_config.hf_config
+        self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
+                                                       prefix=maybe_prefix(
+                                                           prefix, "model"))
+
+        self.sampler = get_sampler()
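
Usage note for NOTE 2 above: the msmodelslim quant description file must be
edited by hand so the MTP layer's weights are marked FLOAT. Below is a minimal
sketch of that edit, assuming the description file is a flat JSON mapping of
weight names to quant types and that the MTP layer sits at index
num_hidden_layers (61 for DeepSeek-V3/R1). The file name and key layout are
illustrative, not guaranteed by msmodelslim; check the file your version
actually emits.

    import json

    # Hypothetical file name; use the description file msmodelslim produced.
    desc_path = "quant_model_description_w8a8_dynamic.json"
    with open(desc_path) as f:
        desc = json.load(f)

    # The MTP layer's weight names mirror the modules defined in
    # deepseek_mtp.py (enorm, hnorm, eh_proj, shared_head, mtp_block).
    mtp_prefix = "model.layers.61."
    for suffix in (
            "embed_tokens.weight",
            "enorm.weight",
            "hnorm.weight",
            "eh_proj.weight",
            "shared_head.norm.weight",
            "shared_head.head.weight",
            # ...plus the decoder-block (mtp_block) weights of layer 61
    ):
        desc.setdefault(mtp_prefix + suffix, "FLOAT")

    with open(desc_path, "w") as f:
        json.dump(desc, f, indent=2)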