From b624c4b7cd7d233a8571d0d95da5f61a51ab0156 Mon Sep 17 00:00:00 2001
From: underfitc
Date: Fri, 30 May 2025 14:58:59 +0800
Subject: [PATCH 1/2] A2_DeepSeek_prefill_opt

---
 vllm_ascend/distributed/parallel_state.py     |  17 +
 vllm_ascend/model_executor/__init__.py        |   0
 vllm_ascend/model_executor/layers/__init__.py |   0
 vllm_ascend/model_executor/layers/linear.py   | 659 ++++++++++++++
 vllm_ascend/models/__init__.py                |  23 +-
 vllm_ascend/models/deepseek_v2_a2.py          | 847 ++++++++++++++++++
 6 files changed, 1545 insertions(+), 1 deletion(-)
 create mode 100644 vllm_ascend/model_executor/__init__.py
 create mode 100644 vllm_ascend/model_executor/layers/__init__.py
 create mode 100644 vllm_ascend/model_executor/layers/linear.py
 create mode 100644 vllm_ascend/models/deepseek_v2_a2.py

diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 895b7ffca0..e14dc1203f 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -20,6 +20,10 @@ def get_etp_group() -> GroupCoordinator:
         "expert tensor parallel group is not initialized")
     return _ETP
 
+def get_wp_group() -> GroupCoordinator:
+    assert _WP is not None, (
+        "world group is not initialized")
+    return _WP
 
 def init_ascend_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -59,6 +63,14 @@ def init_ascend_model_parallel(
                                      backend,
                                      group_name="etp")
 
+    global _WP
+    all_ranks = torch.arange(world_size)
+    group_ranks = all_ranks.view(-1, world_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _WP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="wp")
 
 def destory_ascend_model_parallel():
     global _EP
@@ -70,3 +82,8 @@ def destory_ascend_model_parallel():
     if _ETP:
         _ETP.destroy()
         _ETP = None
+
+    global _WP
+    if _WP:
+        _WP.destroy()
+        _WP = None
\ No newline at end of file
diff --git a/vllm_ascend/model_executor/__init__.py b/vllm_ascend/model_executor/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/vllm_ascend/model_executor/layers/__init__.py b/vllm_ascend/model_executor/layers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/vllm_ascend/model_executor/layers/linear.py b/vllm_ascend/model_executor/layers/linear.py
new file mode 100644
index 0000000000..12108b95c9
--- /dev/null
+++ b/vllm_ascend/model_executor/layers/linear.py
@@ -0,0 +1,659 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# +import itertools +from abc import abstractmethod +from typing import Any, Literal, Optional, Union + +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter) +# yapf: enable +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + adjust_bitblas_shard, + adjust_marlin_shard, + adjust_bitsandbytes_4bit_shard, + adjust_scalar_to_fused_array, + LinearBase) +from vllm_ascend.distributed.parallel_state import get_wp_group +logger = init_logger(__name__) + + + +class RowParallelLinearWp(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the first dimension. 
+ self.tp_rank = get_wp_group().rank_in_group + self.tp_size = get_wp_group().world_size + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_wp_group().rank_in_group + tp_size = get_wp_group().world_size + input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_wp_group().rank_in_group + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = get_wp_group().all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s + +class ColumnParallelLinearWp(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the last dimension. + self.tp_size = get_wp_group().world_size + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_wp_group().rank_in_group + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + tp_size = get_wp_group().world_size + assert final_shape[output_dim] % tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. 
+ output = get_wp_group().all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_wp_group().world_size}" + s += f", gather_output={self.gather_output}" + return s + +class MergedColumnParallelLinearWp(ColumnParallelLinearWp): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.output_sizes = output_sizes + tp_size = get_wp_group().world_size + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + if loaded_shard_id is not None: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = { + i: loaded_weight.item() + for i, _ in enumerate(self.output_sizes) + } + return + + if is_gguf_weight: + tp_size = get_wp_group().world_size + tp_rank = get_wp_group().rank_in_group + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 2: + self.qweight = param.materialize_nested() + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. 
+ is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) + for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(shard_id)) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_wp_group().rank_in_group + tp_size = get_wp_group().world_size + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. 
+ elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_wp_group().world_size + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) // tp_size + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + 
shard_size=shard_size) \ No newline at end of file diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e7f021fdb5..369c5ea3b9 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,13 +1,17 @@ from vllm import ModelRegistry - +import torch_npu def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 + from .deepseek_v2_a2 import CustomDeepseekV2ForCausalLM # noqa: F401 + from .deepseek_v2_a2 import CustomDeepseekV3ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 + torch_npu.npu._lazy_init() + soc_version = torch_npu._C.npu_get_soc_version ModelRegistry.register_model( "DeepSeekMTPModel", @@ -30,6 +34,23 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + if soc_version == 223 or soc_version == 224: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2_a2:CustomDeepseekV2ForCausalLM") + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2_a2:CustomDeepseekV3ForCausalLM") + else: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") diff --git a/vllm_ascend/models/deepseek_v2_a2.py b/vllm_ascend/models/deepseek_v2_a2.py new file mode 100644 index 0000000000..2f1d80d2f4 --- /dev/null +++ b/vllm_ascend/models/deepseek_v2_a2.py @@ -0,0 +1,847 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +import os +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch_npu +import torchair +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_dp_group, get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, + DeepseekV2DecoderLayer, + DeepseekV2MLAAttention) +from vllm.model_executor.models.utils import ( + PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.sequence import IntermediateTensors + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor +from vllm_ascend.distributed.parallel_state import get_wp_group +from vllm_ascend.model_executor.layers.linear import ( + MergedColumnParallelLinearWp, RowParallelLinearWp) +from torch.nn import functional as F +from vllm.logger import logger +@dataclass +class DPMetadataForPadding: + cu_tokens_across_dp_cpu: torch.Tensor + lengths: torch.Tensor + max_length: int + pad_size: torch.Tensor + atten_unpad_mask: torch.Tensor + +_dp_metadata_for_padding: Optional[DPMetadataForPadding] = None + +def padding_aligned_tp(dp_rank, data: torch.Tensor) -> torch.Tensor: + + + pad_size = _dp_metadata_for_padding.pad_size + + if pad_size[dp_rank] == 0: + return data + + return F.pad(data, (0, 0, 0, pad_size[dp_rank])) + +def padding_aligned_wp(data: torch.Tensor, is_prefill, layer_idx) -> torch.Tensor: + + lengths = _dp_metadata_for_padding.lengths + max_length = _dp_metadata_for_padding.max_length + + merged_data = torch.zeros((max_length*len(lengths), data.shape[1]), + dtype=data.dtype, device=data.device) + padded_starts = 0 + current_pos = 0 + for dp_rank in range(len(lengths)): + seq_len = lengths[dp_rank].item() + + merged_data[current_pos:current_pos + seq_len] = 
data[padded_starts:padded_starts + seq_len] + + current_pos += max_length + padded_starts += seq_len + return merged_data + + +def unpadding_aligned_tp(padded_data: torch.Tensor) -> torch.Tensor: + + atten_unpad_mask = _dp_metadata_for_padding.atten_unpad_mask + merged_data = padded_data[atten_unpad_mask, :] + return merged_data + +def unpadding_aligned_wp(dp_rank, padded_data: torch.Tensor) -> torch.Tensor: + + lengths = _dp_metadata_for_padding.lengths + max_length = _dp_metadata_for_padding.max_length + seq_len = lengths[dp_rank].item() + + padded_data = padded_data[max_length * dp_rank :max_length * dp_rank + seq_len] + return padded_data + +class CustomDeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + + def forward(self, x, is_prefill: bool = False, reduce_results: bool = True) -> torch.Tensor: + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if reduce_results and self.down_proj.tp_size > 1: + x = tensor_model_parallel_all_reduce(x) + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class CustomDeepseekV2SharedExpertMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinearWp( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinearWp(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + + def forward(self, x, is_prefill: bool = False, reduce_results: bool = True) -> torch.Tensor: + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if reduce_results and self.down_proj.tp_size > 1: + x = tensor_model_parallel_all_reduce(x) + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class CustomDeepseekV2MoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = AscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = CustomDeepseekV2SharedExpertMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + CustomDeepseekV2MoE.top_k = config.num_experts_per_tok + + self.params_dtype = torch.get_default_dtype() + self.tp_rank_in_group = get_tp_group().rank_in_group + self.tp_group = get_tp_group().device_group + self.dp_size = get_dp_group().world_size + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + + def forward(self, + hidden_states: torch.Tensor, + is_prefill: bool = False) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + # hidden_states = hidden_states.view(-1, hidden_dim) + + # MC2 no mc2 + # prefill_req allreduce+allreduce allreduce+allreduce + # decode_req all_gather+allreduce allreduce+allreduce + + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill and self.tp_size > 1: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank_in_group] + + if self.dp_size > 1 and self.enable_graph_mode and not is_prefill: + stream_ctx = torchair.scope.npu_stream_switch( + "CustomDeepseekV2MoE_dp_graph_decode") + else: + stream_ctx = nullcontext() + + with stream_ctx: + # router_logits: (num_tokens, n_experts) + # gating after all_gather + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekV2MoE.top_k + ) * self.routed_scaling_factor + + return hidden_states + + +class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = 
q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. 
+ # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + if self.enable_graph_mode: + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = hidden_states.shape + output = torch.empty(output_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + forward_kwargs['output'] = output + + output = self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata, + **forward_kwargs) + if envs.VLLM_USE_V1: + output = output.view(-1, output_shape[-1]) + return output + else: + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) + + +class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
+ layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + self.config = config + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.is_moe = config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace \ + and layer_idx % config.moe_layer_freq == 0 + if self.is_moe: + self.mlp = CustomDeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.shared_experts = self.mlp.shared_experts if config.n_shared_experts is not None else None + else: + self.mlp = CustomDeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + self.tp_rank_in_group = get_tp_group().rank_in_group + self.tp_size = get_tp_group().world_size + self.dp_size = get_dp_group().world_size + self.dp_rank = (0 if self.dp_size == 1 else get_dp_group().rank_in_group) + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + is_prefill: bool = False, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. 
/ self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_process(hidden_states, residual, is_prefill) + if self.is_moe: + shared_output = None + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states, is_prefill = is_prefill, reduce_results=False) + + hidden_states = self.mlp(hidden_states, is_prefill) + + if shared_output is not None: + hidden_states = hidden_states + shared_output + + hidden_states, residual = self.post_moe_process(shared_output, hidden_states, residual, is_prefill) + + else: + hidden_states = self.mlp(hidden_states, is_prefill) + hidden_states = hidden_states + residual + residual = None + + if isinstance( + self.mlp, + CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. / self.routed_scaling_factor + + return hidden_states, residual + +class CustomDeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomDeepseekV2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, is_prefill) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = 
self.norm(hidden_states) + return hidden_states + + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.dp_size = get_dp_group().world_size + self.dp_rank = (0 if self.dp_size == 1 else get_dp_group().rank_in_group) + self.tp_size = get_tp_group().world_size + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> Union[torch.Tensor, IntermediateTensors]: + if is_prefill or not self.enable_graph_mode: + cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu + # get padding data + lengths = torch.cat([cu_tokens_across_dp_cpu[:1], cu_tokens_across_dp_cpu[1:] - cu_tokens_across_dp_cpu[:-1]]) + max_length = lengths.max().item() + max_length = ((max_length + self.tp_size - 1) // self.tp_size) * self.tp_size + pad_size = -(lengths - max_length) + + # generate padding index + position_matrix = torch.arange(max_length).expand(self.dp_size, max_length) + lengths_tensor = lengths.view(-1, 1) + atten_unpad_mask = (position_matrix < lengths_tensor).view(-1).to("npu", non_blocking=True) + + global _dp_metadata_for_padding + _dp_metadata_for_padding = DPMetadataForPadding(cu_tokens_across_dp_cpu, lengths, max_length, pad_size, atten_unpad_mask) + + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds, is_prefill) + if is_prefill or not self.enable_graph_mode: + del atten_unpad_mask + return hidden_states + + +class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): + pass \ No newline at end of file From 2fe44e221b5eae6231867baa62852e1c95834de0 Mon Sep 17 00:00:00 2001 From: liziyu <56102866+liziyu179@users.noreply.github.com> Date: Fri, 30 May 2025 15:34:16 +0800 Subject: [PATCH 2/2] DeepSeek prefill optimization using flash_comm_v1 (A2) --- vllm_ascend/models/deepseek_v2_a2.py | 80 +++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/models/deepseek_v2_a2.py b/vllm_ascend/models/deepseek_v2_a2.py index 2f1d80d2f4..a548e60f16 100644 --- a/vllm_ascend/models/deepseek_v2_a2.py +++ b/vllm_ascend/models/deepseek_v2_a2.py @@ -625,6 +625,84 @@ def 
__init__(
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
 
+    def post_attention_process(self, hidden_states, residual, is_prefill):
+        if is_prefill or not self.enable_graph_mode:
+            if self.dp_size <= 1 or self.layer_idx < self.config.first_k_dense_replace:
+                hidden_states = get_tp_group().all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual)
+            else:
+                # padding hidden_states
+                hidden_states = padding_aligned_tp(self.dp_rank, hidden_states)
+                # reduce-scatter hidden_states
+                hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_tp_group().device_group)
+                if self.layer_idx == self.config.first_k_dense_replace:
+                    # padding and slice residual
+                    reduce_scatter_tokens = hidden_states.size(0)
+                    residual = F.pad(residual, (0, 0, 0, reduce_scatter_tokens * self.tp_size - residual.size(0)))
+                    start = self.tp_rank_in_group * reduce_scatter_tokens
+                    residual = residual[start:start + reduce_scatter_tokens]
+                # post layernorm
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual)
+                # global all_gather
+                hidden_states = get_wp_group().all_gather(hidden_states, 0)
+                # unpad
+                hidden_states = unpadding_aligned_tp(hidden_states)
+
+        else:
+            if self.tp_size > 1:
+                hidden_states = get_tp_group().all_reduce(hidden_states)
+            if self.enable_graph_mode and not envs_ascend.VLLM_ENABLE_MC2:
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+            hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        return hidden_states, residual
+
+    def post_moe_process(self, shared_output, hidden_states, residual, is_prefill):
+        if is_prefill or not self.enable_graph_mode:
+            hidden_states = shared_output + hidden_states
+            if self.dp_size <= 1:
+                hidden_states = get_wp_group().all_reduce(hidden_states)
+                return hidden_states, None
+            hidden_states = padding_aligned_wp(hidden_states, is_prefill, self.layer_idx)
+
+            # reduce-scatter hidden_states
+            hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                hidden_states,
+                "sum",
+                scatter_dim=0,
+                group=get_wp_group().device_group)
+            # add residual
+            hidden_states = hidden_states + residual
+            residual = hidden_states
+            # global all_gather
+            hidden_states = get_wp_group().all_gather(hidden_states, 0)
+            # unpad
+            hidden_states = unpadding_aligned_wp(self.dp_rank, hidden_states)
+            if self.layer_idx == self.config.num_hidden_layers - 1:
+                residual = None
+            return hidden_states, residual
+        else:
+            if envs_ascend.VLLM_ENABLE_MC2:
+                shared_output = self.shared_experts(hidden_states, is_prefill=False, reduce_results=True)
+                num_tokens, hidden_dim = hidden_states.shape
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
+                hidden_states = shared_output + final_hidden_states
+                hidden_states = hidden_states + residual
+            else:
+                hidden_states = shared_output + hidden_states
+                hidden_states = get_wp_group().all_reduce(hidden_states)
+                hidden_states = hidden_states + residual
+            return hidden_states, None
 
     def forward(
         self,
@@ -844,4 +922,4 @@ def forward(
 
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
-    pass
\ No newline at end of file
+    pass
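
Reviewer note: the sketch below is not part of the patch. It is a minimal, CPU-only illustration of the pad/unpad round-trip that padding_aligned_wp and unpadding_aligned_wp perform around the reduce-scatter and all-gather over the wp group; the tp_size, hidden size and per-rank token counts are made-up values for demonstration.

# Illustrative sketch (assumed shapes, no NPU, no collectives): each DP rank's
# tokens are copied into a fixed slot of max_length rows, where max_length is
# the per-rank maximum rounded up to a multiple of tp_size so the later
# reduce-scatter splits evenly; unpadding slices the valid rows back out.
import torch

tp_size = 4                        # alignment unit for max_length (illustrative)
hidden = 8
lengths = torch.tensor([5, 3, 7])  # tokens per DP rank (made up)

max_length = int(lengths.max().item())
max_length = (max_length + tp_size - 1) // tp_size * tp_size

# Per-rank activations as they would look after attention, concatenated.
per_rank = [torch.randn(n, hidden) for n in lengths.tolist()]
data = torch.cat(per_rank, dim=0)

# padding_aligned_wp: merge into one buffer with max_length rows per rank.
merged = torch.zeros(max_length * len(per_rank), hidden, dtype=data.dtype)
src = 0
for dp_rank, n in enumerate(lengths.tolist()):
    merged[dp_rank * max_length:dp_rank * max_length + n] = data[src:src + n]
    src += n

# ...in the layer, `merged` is reduce-scattered over the wp group, the residual
# is added, and the result is all-gathered back before unpadding...

# unpadding_aligned_wp: recover one rank's valid rows from the padded buffer.
for dp_rank, n in enumerate(lengths.tolist()):
    recovered = merged[dp_rank * max_length:dp_rank * max_length + n]
    assert torch.equal(recovered, per_rank[dp_rank])

print("pad/unpad round-trip ok, max_length =", max_length)

Rounding max_length up to a multiple of tp_size mirrors the forward() of CustomDeepseekV2ForCausalLM above, which computes the same lengths, pad_size and atten_unpad_mask before the decoder layers run.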