4 changes: 4 additions & 0 deletions vllm_ascend/models/__init__.py
@@ -13,3 +13,7 @@ def register_model():
    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
181 changes: 181 additions & 0 deletions vllm_ascend/models/deepseek_mtp.py
@@ -0,0 +1,181 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from vllm.model_executor.models.deepseek_mtp import (
    DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
    SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .deepseek_v2 import CustomDeepseekV2DecoderLayer


class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        nn.Module.__init__(self)
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                      model_config,
                                                      cache_config,
                                                      quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        # mask inputs at position 0, as they are not needed by MTP
        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                    torch.zeros_like(inputs_embeds),
                                    inputs_embeds)

Collaborator:
I noticed that this masking is done as follows in vLLM. Why do we use torch.where here? Is there any benefit?
inputs_embeds[positions == 0] = 0

Collaborator (Author) @mengwei805 commented on Apr 11, 2025:
The original vLLM method cannot be used in torchair, so it is replaced with an equivalent one. In addition, the indexed write inputs_embeds[positions == 0] = 0 performs poorly on Ascend devices.

Collaborator:
Got it, thanks for the explanation.
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return hidden_states
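To make the review thread above concrete, here is a minimal stand-alone sketch (shapes and values are made up for illustration only) showing that the out-of-place torch.where form used in this PR produces the same result as the indexed write used upstream in vLLM:

import torch

# Toy inputs: 4 token positions, hidden size 8 (illustrative only).
positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.randn(4, 8)

# Indexed write used upstream in vLLM (not usable with torchair per the thread).
indexed = inputs_embeds.clone()
indexed[positions == 0] = 0

# Out-of-place torch.where form used in this PR.
masked = torch.where((positions == 0).unsqueeze(-1),
                     torch.zeros_like(inputs_embeds),
                     inputs_embeds)

assert torch.equal(indexed, masked)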


class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # keyed by the exact layer index used in the checkpoint weights
        self.layers = torch.nn.ModuleDict({
            str(idx): CustomDeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })

        # Note: indexing the ModuleDict with str(idx) inside forward raises
        # torch._dynamo.exc.Unsupported: builtin: str under graph mode, so the
        # layers are cached in a plain list here.
        self.layers_list = [
            self.layers[str(idx)]
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        ]
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            kv_caches[current_step_idx],
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        mtp_layer = self.layers_list[current_step_idx]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       mtp_layer.shared_head(hidden_states),
                                       sampling_metadata)
        return logits


class CustomDeepSeekMTP(DeepSeekMTP):
    # NOTE 1. For quantized DeepSeek checkpoints, the MTP layer itself is kept
    # unquantized on the NPU.
    # NOTE 2. The description file generated by the current msmodelslim tool does
    # not contain MTP layer info. Please add it manually and set the value to
    # FLOAT (a rough patching sketch follows this diff).
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.model_config.hf_config
        self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                       prefix=maybe_prefix(
                                                           prefix, "model"))

        self.sampler = get_sampler()
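As a loose illustration of NOTE 2 above (not part of this PR): one way to patch a msmodelslim-style quantization description file so the MTP layer is treated as FLOAT. The flat {weight_name: quant_type} JSON layout is an assumption, and the caller must supply the real MTP weight names from their checkpoint; verify both before using anything like this.

import json
from typing import Iterable

def add_mtp_float_entries(desc_path: str, mtp_weight_names: Iterable[str]) -> None:
    # Load the per-weight quantization description (assumed to be a flat
    # {weight_name: quant_type} JSON mapping -- check your actual file).
    with open(desc_path) as f:
        desc = json.load(f)
    # msmodelslim omits the MTP layer, so add its weights explicitly and
    # mark them as FLOAT (i.e. unquantized), as NOTE 2 requests.
    for name in mtp_weight_names:
        desc[name] = "FLOAT"
    with open(desc_path, "w") as f:
        json.dump(desc, f, indent=2)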