23 changes: 20 additions & 3 deletions  vllm_ascend/models/deepseek_mtp.py

@@ -28,8 +28,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
@@ -40,6 +40,20 @@
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
 
+class CustomDeepSeekShareHead(SharedHead):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        nn.Module.__init__(self)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.head = ParallelLMHead(config.vocab_size,
+                                   config.hidden_size,
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
+
+
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
 
     def __init__(
@@ -61,7 +75,10 @@ def __init__(
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = CustomDeepSeekShareHead(config=config,
+                                                   quant_config=quant_config,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "shared_head"))
         self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                       model_config,
                                                       cache_config,
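The point of CustomDeepSeekShareHead is that the MTP shared head now forwards a fully qualified module name to ParallelLMHead, so the quantization config can match (or skip) the head by prefix. Below is a minimal, self-contained sketch of how those dotted names compose; maybe_prefix is re-implemented here for illustration, and the layer index is hypothetical.

# Standalone sketch: how the prefixes threaded through the constructors above
# compose into the module name the quant config later matches on.
def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behavior of the vLLM helper: join non-empty parts with a dot.
    return name if not prefix else f"{prefix}.{name}"

# Hypothetical prefix of one MTP predictor layer.
layer_prefix = "model.layers.61"
shared_head_prefix = maybe_prefix(layer_prefix, "shared_head")
head_prefix = maybe_prefix(shared_head_prefix, "head")

print(head_prefix)  # model.layers.61.shared_head.head

With that name available, get_quant_method in quant_config.py (below) can decide per module whether the head is quantized or left in floating point.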
4 changes: 3 additions & 1 deletion  vllm_ascend/models/deepseek_v2.py

@@ -863,7 +863,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
23 changes: 23 additions & 0 deletions  vllm_ascend/quantization/quant_config.py

@@ -34,6 +34,8 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
@@ -107,6 +109,12 @@ def get_quant_method(self, layer: torch.nn.Module,
                 return AscendUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
+        elif isinstance(layer, VocabParallelEmbedding):
+            if self.is_layer_skipped_ascend(prefix,
+                                            self.packed_modules_mapping):
+                return UnquantizedEmbeddingMethod()
+            return AscendEmbeddingMethod(self, prefix,
+                                         self.packed_modules_mapping)
         return None
 
     def is_layer_skipped_ascend(
@@ -319,3 +327,18 @@ def apply(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
+
+
+class AscendEmbeddingMethod(AscendLinearMethod):
+    """Embedding method for Ascend quantization.
+    This class calls AscendQuantizer to search for a quantization
+    implementation supported on Ascend hardware for embedding methods.
+    Args:
+        quant_config: The Ascend quantization config.
+    """
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]) -> None:
+        self.quantizer = AscendQuantizer.get_quantizer(
+            quant_config.quant_description, prefix, packed_modules_mapping)
+        self.quant_method = self.quantizer.build_linear_method()
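The new elif branch in get_quant_method is what makes the prefixes above matter: a VocabParallelEmbedding (or ParallelLMHead, which subclasses it) either keeps the stock UnquantizedEmbeddingMethod when its prefix is on the skip list, or gets an AscendEmbeddingMethod built via AscendQuantizer. The toy sketch below mirrors only that decision; the "skip_modules" key and module names are assumptions for illustration, not the real quant_description format.

from typing import Any, Dict


# Toy stand-in for the dispatch added above; not the vllm-ascend implementation.
def choose_embedding_method(prefix: str,
                            quant_description: Dict[str, Any]) -> str:
    # Hypothetical skip list of module prefixes that must stay in float.
    skipped = quant_description.get("skip_modules", [])
    if any(prefix == s or prefix.startswith(s + ".") for s in skipped):
        return "UnquantizedEmbeddingMethod"
    # Otherwise AscendEmbeddingMethod asks AscendQuantizer for the concrete
    # implementation supported on the NPU and reuses the linear-method path.
    return "AscendEmbeddingMethod"


if __name__ == "__main__":
    desc = {"skip_modules": ["model.embed_tokens"]}
    print(choose_embedding_method("model.embed_tokens", desc))  # unquantized
    print(choose_embedding_method("lm_head", desc))  # Ascend-quantized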