11#
22# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3- # Adapted from vllm/model_executor/models/qwen2_vl .py
3+ # Adapted from vllm/model_executor/models/deepseek_mtp .py
44# Copyright 2023 The vLLM team.
55#
66# This file is a part of the vllm-ascend project.
1717# See the License for the specific language governing permissions and
1818# limitations under the License.
1919
20- from typing import List , Optional
20+ from typing import Optional
2121
2222import torch
2323import torch .nn as nn
2424from transformers import PretrainedConfig
25- from vllm .attention .backends .abstract import AttentionMetadata
2625from vllm .config import CacheConfig , ModelConfig , VllmConfig
2726from vllm .model_executor .layers .layernorm import RMSNorm
2827from vllm .model_executor .layers .logits_processor import LogitsProcessor
@@ -70,8 +69,6 @@ def forward(
7069 self ,
7170 input_ids : torch .Tensor ,
7271 positions : torch .Tensor ,
73- kv_cache : torch .Tensor ,
74- attn_metadata : AttentionMetadata ,
7572 previous_hidden_states : torch .Tensor ,
7673 inputs_embeds : Optional [torch .Tensor ] = None ,
7774 spec_step_index : int = 0 ,
@@ -91,8 +88,6 @@ def forward(
9188
9289 hidden_states , residual = self .mtp_block (positions = positions ,
9390 hidden_states = hidden_states ,
94- kv_cache = kv_cache ,
95- attn_metadata = attn_metadata ,
9691 residual = None )
9792 hidden_states = residual + hidden_states
9893 return hidden_states
@@ -130,8 +125,6 @@ def forward(
130125 self ,
131126 input_ids : torch .Tensor ,
132127 positions : torch .Tensor ,
133- kv_caches : List [torch .Tensor ],
134- attn_metadata : AttentionMetadata ,
135128 previous_hidden_states : torch .Tensor ,
136129 inputs_embeds : Optional [torch .Tensor ] = None ,
137130 spec_step_idx : int = 0 ,
@@ -140,8 +133,6 @@ def forward(
140133 return self .layers_list [current_step_idx ](
141134 input_ids ,
142135 positions ,
143- kv_caches [current_step_idx ],
144- attn_metadata ,
145136 previous_hidden_states ,
146137 inputs_embeds ,
147138 current_step_idx ,
@@ -162,6 +153,14 @@ def compute_logits(
162153
163154
164155class CustomDeepSeekMTP (DeepSeekMTP ):
156+ # NOTE 1. In a quantized DeepSeek model on the NPU, the MTP layer itself is not quantized.
157+ # NOTE 2. The description file generated by the current msmodelslim tool does not contain
158+ # MTP layer info. Please add it manually and set its value to FLOAT.
159+ packed_modules_mapping = {
160+ "gate_up_proj" : ["gate_proj" , "up_proj" ],
161+ "experts" :
162+ ["experts.0.gate_proj" , "experts.0.up_proj" , "experts.0.down_proj" ]
163+ }
165164
166165 def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" ):
167166 nn .Module .__init__ (self )
0 commit comments