-
Notifications
You must be signed in to change notification settings - Fork 555
[v0.7.3]support MTP in deepseek w8a8 quant model #502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
wangxiyuan
merged 1 commit into
vllm-project:v0.7.3-dev
from
mengwei805:v0.7.3-mtp-quant
Apr 11, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| # | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # Adapted from vllm/model_executor/models/deepseek_mtp.py | ||
| # Copyright 2023 The vLLM team. | ||
| # | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from transformers import PretrainedConfig | ||
| from vllm.attention.backends.abstract import AttentionMetadata | ||
| from vllm.config import CacheConfig, ModelConfig, VllmConfig | ||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.sampler import get_sampler | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import \ | ||
| VocabParallelEmbedding | ||
| from vllm.model_executor.models.deepseek_mtp import ( | ||
| DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, | ||
| SharedHead) | ||
| from vllm.model_executor.models.utils import maybe_prefix | ||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
|
|
||
| from .deepseek_v2 import CustomDeepseekV2DecoderLayer | ||
|
|
||
|
|
||
| class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: PretrainedConfig, | ||
| prefix: str, | ||
| model_config: ModelConfig, | ||
| cache_config: Optional[CacheConfig] = None, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| ) -> None: | ||
| nn.Module.__init__(self) | ||
| self.embed_tokens = VocabParallelEmbedding( | ||
| config.vocab_size, | ||
| config.hidden_size, | ||
| ) | ||
|
|
||
| self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
| self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
| self.eh_proj = nn.Linear(config.hidden_size * 2, | ||
| config.hidden_size, | ||
| bias=False) | ||
| self.shared_head = SharedHead(config=config, quant_config=quant_config) | ||
| self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, | ||
| model_config, | ||
| cache_config, | ||
| quant_config) | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| kv_cache: torch.Tensor, | ||
| attn_metadata: AttentionMetadata, | ||
| previous_hidden_states: torch.Tensor, | ||
| inputs_embeds: Optional[torch.Tensor] = None, | ||
| spec_step_index: int = 0, | ||
| ) -> torch.Tensor: | ||
| if inputs_embeds is None: | ||
| inputs_embeds = self.embed_tokens(input_ids) | ||
| assert inputs_embeds is not None | ||
| # masking inputs at position 0, as not needed by MTP | ||
| inputs_embeds = torch.where((positions == 0).unsqueeze(-1), | ||
| torch.zeros_like(inputs_embeds), | ||
| inputs_embeds) | ||
| inputs_embeds = self.enorm(inputs_embeds) | ||
| previous_hidden_states = self.hnorm(previous_hidden_states) | ||
|
|
||
| hidden_states = self.eh_proj( | ||
| torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) | ||
|
|
||
| hidden_states, residual = self.mtp_block(positions=positions, | ||
| hidden_states=hidden_states, | ||
| kv_cache=kv_cache, | ||
| attn_metadata=attn_metadata, | ||
| residual=None) | ||
| hidden_states = residual + hidden_states | ||
| return hidden_states | ||
|
|
||
|
|
||
| class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): | ||
|
|
||
| def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
| nn.Module.__init__(self) | ||
| config = vllm_config.model_config.hf_config | ||
| self.mtp_start_layer_idx = config.num_hidden_layers | ||
| self.num_mtp_layers = config.num_nextn_predict_layers | ||
| # to map the exact layer index from weights | ||
| self.layers = torch.nn.ModuleDict({ | ||
| str(idx): CustomDeepSeekMultiTokenPredictorLayer( | ||
| config, | ||
| f"{prefix}.layers.{idx}", | ||
| model_config=vllm_config.model_config, | ||
| cache_config=vllm_config.cache_config, | ||
| quant_config=vllm_config.quant_config, | ||
| ) | ||
| for idx in range(self.mtp_start_layer_idx, | ||
| self.mtp_start_layer_idx + self.num_mtp_layers) | ||
| }) | ||
|
|
||
| # Note: torch._dynamo.exc.Unsupported: builtin: str | ||
| self.layers_list = [ | ||
| self.layers[str(idx)] | ||
| for idx in range(self.mtp_start_layer_idx, | ||
| self.mtp_start_layer_idx + self.num_mtp_layers) | ||
| ] | ||
| self.logits_processor = LogitsProcessor(config.vocab_size) | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| kv_caches: List[torch.Tensor], | ||
| attn_metadata: AttentionMetadata, | ||
| previous_hidden_states: torch.Tensor, | ||
| inputs_embeds: Optional[torch.Tensor] = None, | ||
| spec_step_idx: int = 0, | ||
| ) -> torch.Tensor: | ||
| current_step_idx = (spec_step_idx % self.num_mtp_layers) | ||
| return self.layers_list[current_step_idx]( | ||
| input_ids, | ||
| positions, | ||
| kv_caches[current_step_idx], | ||
| attn_metadata, | ||
| previous_hidden_states, | ||
| inputs_embeds, | ||
| current_step_idx, | ||
| ) | ||
|
|
||
| def compute_logits( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| sampling_metadata: SamplingMetadata, | ||
| spec_step_idx: int = 0, | ||
| ) -> torch.Tensor: | ||
| current_step_idx = (spec_step_idx % self.num_mtp_layers) | ||
| mtp_layer = self.layers_list[current_step_idx] | ||
| logits = self.logits_processor(mtp_layer.shared_head.head, | ||
| mtp_layer.shared_head(hidden_states), | ||
| sampling_metadata) | ||
| return logits | ||
|
|
||
|
|
||
| class CustomDeepSeekMTP(DeepSeekMTP): | ||
| # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; | ||
| # NOTE 2.The description file generated by the current msmodelslim tool does not have | ||
| # MTP layer info. Please manually add it and set the value to FLOAT. | ||
| packed_modules_mapping = { | ||
| "gate_up_proj": ["gate_proj", "up_proj"], | ||
| "experts": | ||
| ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] | ||
| } | ||
|
|
||
| def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
| nn.Module.__init__(self) | ||
| self.config = vllm_config.model_config.hf_config | ||
| self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config, | ||
| prefix=maybe_prefix( | ||
| prefix, "model")) | ||
|
|
||
| self.sampler = get_sampler() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that this masking is done like the following in vllm. why we use
torch.wherehere? Is there any benifits?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original vLLM method cannot be used in torchair, so it is replaced with an equivalent method.
And the writing of
inputs_embeds[positions == 0]has poor performance on Ascend devicesThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, thanks for this explanation