From 7de3736b424d9b64a97afd41bf4f07bf21e91433 Mon Sep 17 00:00:00 2001 From: zhuyilin <809721801@qq.com> Date: Sat, 28 Jun 2025 16:26:29 +0800 Subject: [PATCH] support pangu moe w8a8c8 Signed-off-by: zhuyilin <809721801@qq.com> --- docs/source/user_guide/additional_config.md | 1 + vllm_ascend/attention/attention_v1.py | 23 +- vllm_ascend/models/pangu_moe.py | 49 +- vllm_ascend/platform.py | 4 + vllm_ascend/quantization/quant_config.py | 34 +- vllm_ascend/quantization/quantizer.py | 14 +- vllm_ascend/quantization/w8a8.py | 583 +++++++++++++++++++- vllm_ascend/worker/model_runner_v1.py | 31 +- 8 files changed, 689 insertions(+), 50 deletions(-) diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index 79b4d9047e..e755b93796 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | +| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | The details of each config option are as follows: diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5451508c81..985997e1bf 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -69,6 +69,15 @@ def get_kv_cache_shape( 16) return (2, num_blocks, block_size, num_kv_heads, head_size) + @staticmethod + def get_bsh_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads * head_size) + @staticmethod def swap_blocks( src_kv_cache: List[torch.Tensor], @@ -279,6 +288,13 @@ def forward( value=value, output=output, layer_name=layer.layer_name) + + elif hasattr(layer, 'quant_method'): + output = layer.quant_method.apply(layer, query, key, value, + kv_cache, attn_metadata, + self.attn_type, self.scale, + output) + else: if attn_metadata is None: return output.view(num_tokens, self.hidden_size) @@ -308,11 +324,8 @@ def forward( value_cache=self.value_cache, slot_indices=slots) - if hasattr(layer, 'quant_method'): - # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata - pass # V0-Style scheduler situation. - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: assert attn_metadata is not None assert attn_metadata.attn_mask is not None mask = attn_metadata.attn_mask @@ -414,6 +427,8 @@ def forward( out=output) # to make in-place change to the output tensor + if hasattr(layer, 'quant_method'): + output = output.view(num_tokens, self.num_heads, self.head_size) ori_output[:, :, :] = output[:num_tokens, :, :] return output.view(num_tokens, self.hidden_size) diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index 644a00ef41..e01e409989 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -505,7 +505,7 @@ def forward( # native FusedMoE. 
here we need to design a better FusedMoE # (maybe using AscendFusedMoE) to enable these different # communication schema. - final_hidden_states = self.experts.quant_method( + final_hidden_states = self.experts.quant_method.apply( layer=self.experts, x=hidden_states, router_logits=router_logits, @@ -937,6 +937,8 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -972,6 +974,51 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "module" in name: continue + if name.endswith('kv_cache_offset'): + continue + + if name.endswith("k_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "k_proj.kv_cache_scale", "attn.key_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if name.endswith("v_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "v_proj.kv_cache_scale", "attn.value_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 881b73246f..4c92abffb5 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -124,6 +124,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config + kv_cache_dtype = vllm_config.additional_config.get( + "kv_cache_dtype", None) + if kv_cache_dtype is not None: + vllm_config.cache_config.cache_dtype = kv_cache_dtype if parallel_config: # Default value for expert tensor parallel size diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3567dba355..7c7ee58033 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -98,6 +98,9 @@ def get_quant_method(self, layer: torch.nn.Module, 'fa_quant_type' in self.quant_description.keys() and \ self.quant_description['fa_quant_type'] is not None: return AscendKVCacheMethod(self, prefix) + elif isinstance(layer, Attention) and self.quant_description.get( + 'kv_quant_type') == 'C8': + return AscendKVCacheMethod(self, prefix) elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): @@ -235,32 +238,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) - def apply(self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - k_cache: List[torch.Tensor], - v_cache: List[torch.Tensor], - scale: torch.Tensor, - block_tables: torch.Tensor, - isPrefill: bool, - attn_metadata, - output, - seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor: - return self.quant_method.apply(layer, - query, - key, - value, - k_cache, - v_cache, - scale, - block_tables, - isPrefill, - attn_metadata.attn_mask, - attn_metadata.slot_mapping, - output, - seq_lens_tensor_cpu=seq_lens_tensor_cpu) + def apply(self, layer: torch.nn.Module, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + return self.quant_method.apply(layer, query, key, value, kv_cache, + attn_metadata, attn_type, scale, output) class AscendFusedMoEMethod(FusedMoEMethodBase): diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index ea1297bf35..81326fb03f 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -24,7 +24,8 @@ from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init) -from .w8a8 import AscendW8A8LinearMethod +from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) @@ -250,6 +251,8 @@ def get_quantizer(cls, # Attention if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): quant_type = quant_description['fa_quant_type'] + if '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): + quant_type = quant_description['kv_quant_type'] # Linear else: quant_type = cls.get_linear_quant_type(quant_description, prefix, @@ -269,6 +272,14 @@ class W8A8Quantizer(VLLMAscendQuantizer): def build_linear_method(): return AscendW8A8LinearMethod() + @staticmethod + def build_moe_method(): + return 
AscendW8A8FusedMoEMethod() + + @staticmethod + def build_attention_method(): + return AscendC8KVCacheMethod() + class W8A8DYNAMICQuantizer(VLLMAscendQuantizer): @@ -284,4 +295,5 @@ def build_moe_method(): SUPPORT_ASCEND_QUANTIZER_TYPE = { "W8A8": W8A8Quantizer, "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, + "C8": W8A8Quantizer, } diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index db23cb024d..6a2f4039de 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -15,16 +15,23 @@ # limitations under the License. # -from typing import Any, Dict, Optional +import os +from typing import Any, Callable, Dict, Optional import torch import torch_npu +from vllm.attention.backends.abstract import AttentionType +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.distributed.parallel_state import get_ep_group -def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, - input_offset: torch.Tensor): + +def quant_per_tensor(in_tensor: torch.Tensor, + input_scale: torch.Tensor, + input_offset: torch.Tensor, + function=False): return torch_npu.npu_quantize(in_tensor, input_scale, input_offset, - torch.qint8, -1, False) + torch.qint8, -1, function) class AscendW8A8LinearMethod: @@ -86,19 +93,17 @@ def apply( ) -> torch.Tensor: original_dtype = x.dtype if original_dtype != torch.int8: - x = quant_per_tensor( - x, - layer.aclnn_input_scale, - layer.aclnn_input_offset, - ) + x = quant_per_tensor(x, layer.aclnn_input_scale, + layer.aclnn_input_offset) quant_bias = layer.quant_bias if tp_rank == 0 else None - return torch_npu.npu_quant_matmul( + output = torch_npu.npu_quant_matmul( x, layer.weight, layer.deq_scale, bias=quant_bias, output_dtype=original_dtype, ) + return output def process_weights_after_loading(self, layer): expanding_factor = layer.weight.data.shape[1] @@ -113,3 +118,561 @@ def process_weights_after_loading(self, layer): layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + + +class AscendW8A8FusedMoEMethod: + """FusedMoe method for Ascend W8A8. 
+ """ + + def __init__(self): + self.transpose_weight = True + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8, + requires_grad=False) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8, + requires_grad=False) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float16) + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float16) + param_dict["w2_deq_scale"] = torch.empty(num_experts, + hidden_sizes, + dtype=torch.float32) + param_dict["w13_deq_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32) + param_dict["w2_input_scale"] = torch.empty(num_experts, + 1, + dtype=torch.float32) + param_dict["w13_input_scale"] = torch.empty(num_experts, + 1, + dtype=torch.float32) + param_dict["w2_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + param_dict["w13_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + param_dict["quant_bias"] = torch.empty(num_experts, + hidden_sizes, + dtype=torch.int32) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts, "Number of global experts mismatch" + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, + bias=e_score_correction_bias, + k_group=topk_group, + group_count=num_expert_group, + group_select_mode=1, + renorm=0, + norm_type=1, + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: + raise 
NotImplementedError("W8A8FusedMoe are not " + "mplemented for VLLM_ENABLE_MC2") + + else: + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w1_input_scale=layer.w13_input_scale, + w1_input_offset=layer.w13_input_offset, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + w2_input_offset=layer.w2_input_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + expert_map=expert_map) + + def process_weights_after_loading(self, layer): + # torch.npu.config.allow_internal_format = True + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, + 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( + layer.w13_weight_scale.data.shape[0], -1).to(torch.float32) + + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1).to(torch.float16) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( + layer.w2_weight_scale.data.shape[0], -1).to(torch.float32) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1).to(torch.float16) + expanding_factor_w13 = layer.w13_weight.data.shape[1] + expanding_factor_w2 = layer.w2_weight.data.shape[1] + layer.w13_input_scale.data = torch.nn.Parameter( + layer.w13_input_scale.data.repeat( + 1, expanding_factor_w13)[0:1]).to(torch.float16) + + layer.w2_input_scale.data = torch.nn.Parameter( + layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to( + torch.float16) + layer.w13_input_offset.data = torch.nn.Parameter( + layer.w13_input_scale.data.repeat( + 1, expanding_factor_w13)[0:1]).to(torch.int8) + layer.w2_input_offset.data = torch.nn.Parameter( + layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to( + torch.int8) + + # NZ + # layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous() + # layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous() + + +class AscendC8KVCacheMethod: + + def __init__(self) -> None: + self.antiquant_scale_comb = None + + @staticmethod + def create_weights(layer) -> None: + param_dict = {} # num_kv_heads * head_size + param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads * + layer.head_size, + dtype=torch.float16, + requires_grad=False) + param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads * + layer.head_size, + dtype=torch.float16, + requires_grad=False) + for weight_name, weight_param in param_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + layer.register_parameter(weight_name, param) + + def process_weights_after_loading(self, layer): + self.antiquant_scale_comb = torch.cat( + (layer.key_antiquant_scale.data.unsqueeze(0), + layer.value_antiquant_scale.data.unsqueeze(0)), + dim=0).to(torch.float16).contiguous() + + def apply(self, layer, query, key, value, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + num_tokens = query.shape[0] + if attn_metadata is None: + return output.view(num_tokens, layer.num_heads * layer.head_size) + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + # C8 + 
quant_key = quant_per_tensor( + key.view(-1, layer.num_kv_heads * layer.head_size), + layer.key_antiquant_scale.data.view(-1), None, True) + quant_value = quant_per_tensor( + value.view(-1, layer.num_kv_heads * layer.head_size), + layer.value_antiquant_scale.data.view(-1), None, True) + + # View q k v to BSH. + query = query.view(-1, layer.num_heads, layer.head_size) + key = key.view(-1, layer.num_kv_heads, layer.head_size) + value = value.view(-1, layer.num_kv_heads, layer.head_size) + # TODO: Remove this contiguous in the future. + value = value.contiguous() + + if kv_cache[0].numel() > 0: + # if key_cache is None: + key_cache, value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + + block_size = key_cache.shape[1] + slots_indices = slots.reshape(-1, 1) + block_indices = slots_indices // block_size + slots_indices = slots_indices % block_size + indices = torch.cat((block_indices, slots_indices), dim=1) + + # C8 + torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key) + torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value) + + # V0-Style scheduler situation. + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=scale, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + out=output.reshape(query.shape)) + + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + raise NotImplementedError("kv cache int8 are not " + "implemented for " + "PrefillCacheHit") + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly + # torch_air + # decode_meta = attn_metadata.decode + # seq_lens = decode_meta.seq_lens_list + seq_lens = attn_metadata.seq_lens + block_size = key_cache.shape[1] + query = query.view(num_tokens, 1, layer.num_heads * + layer.head_size).contiguous() # changed + + # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D] + key = key_cache + value = value_cache + + output = torch_npu.npu_incre_flash_attention( + query, + key, + value, + num_key_value_heads=layer.num_kv_heads, + num_heads=layer.num_heads, + actual_seq_lengths=seq_lens, + scale_value=scale, + input_layout='BSH', + block_size=block_size, + block_table=attn_metadata.block_tables, + antiquant_scale=self.antiquant_scale_comb, + ) + + # Normal V1 situation. + else: + raise NotImplementedError("kv cache int8 are not " + "implemented for " + "other case") + return output + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w1_input_scale: torch.Tensor, + w1_input_offset: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w2_input_offset: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + global_num_experts: int, + expert_map: torch.Tensor = None, +) -> torch.Tensor: + """ + Fused experts with top-k routing. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). 
+ top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + hidden_states: Hidden states after routing. + """ + """ + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + """ + + original_dtype = hidden_states.dtype + ep_size = get_ep_group().world_size + local_num_experts = global_num_experts // ep_size + w1_input_scale, _ = w1_input_scale.max(0) + quant_sorted_hidden_states = quant_per_tensor( + hidden_states, + w1_input_scale, + None, + True, + ) + if expert_map is not None: + expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2( + quant_sorted_hidden_states, + topk_ids, + scale=None, + active_num=topk_ids.numel(), + expert_capacity=-1, + expert_num=local_num_experts, + drop_pad_mode=0, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + quant_mode=-1, + active_expert_range=[0, local_num_experts], + row_idx_type=0, + ) + + else: + raise NotImplementedError( + "The quantified version of MOE class models " + "currently does not support tensor parallelism") + if expanded_x.dtype != w1.dtype: + w1_input_scale, _ = w1_input_scale.max(0) + quant_sorted_hidden_states = quant_per_tensor( + expanded_x, + w1_input_scale, + None, + True, + ) + else: + quant_sorted_hidden_states = expanded_x + gate_up_out = torch_npu.npu_grouped_matmul( + x=[quant_sorted_hidden_states], + weight=[w1], + scale=[w1_scale * w1_input_scale[0]], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_count, + output_dtype=original_dtype, + )[0] + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + + if gate_up_out.dtype != w2.dtype: + w2_input_scale, _ = w2_input_scale.max(0) + quant_gate_up_out = quant_per_tensor( + gate_up_out, + w2_input_scale, + None, + True, + ) + else: + quant_gate_up_out = gate_up_out + + down_out = torch_npu.npu_grouped_matmul( + x=[quant_gate_up_out], + weight=[w2], + scale=[w2_scale * w2_input_scale[0]], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_count, + output_dtype=original_dtype, + )[0] + + if expert_map is not None: + final_hidden_states = torch_npu.npu_moe_finalize_routing( + down_out, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights.to(down_out.dtype), + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + drop_pad_mode=2, + ) + else: + raise NotImplementedError( + "The quantified version of MOE class models " + "currently does not support tensor parallelism") + + return final_hidden_states + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts=-1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Select top-k experts based on router logits. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + router_logits: Router logits of shape (num_tokens, num_experts). + top_k: Number of experts to select. 
+ use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + + Returns: + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + + Raises: + ValueError: If an unsupported scoring function is provided. + """ + + if scoring_func == "softmax": + # NOTE: vLLM use dtype=torch.float here + topk_weights = router_logits.softmax(dim=-1) + elif scoring_func == "sigmoid": + topk_weights = router_logits.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_weights = topk_weights + topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) + + # TODO: Change to npu_group_topk when the latest CANN and NNAL is available + # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) + topk_weights = native_grouped_topk(topk_weights, num_expert_group, + topk_group) + # TODO bfloat16 is not supported in torch.topk with ge graph. + if e_score_correction_bias is not None: + topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_weights.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False) + elif custom_routing_function is None: + topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) + topk_weights = topk_weights.to(hidden_states.dtype) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts, + ) + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + return topk_weights, topk_ids + + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def native_grouped_topk( + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], +): + topk_group = 0 if topk_group is None else topk_group + num_expert_group = 0 if num_expert_group is None else num_expert_group + + num_token = topk_weights.shape[0] + grouped_weights = topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + topk_group_mask = torch.zeros_like(grouped_weights) + topk_group_mask.scatter_(1, topk_group_indices, 1) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) + + return topk_weights diff --git a/vllm_ascend/worker/model_runner_v1.py 
b/vllm_ascend/worker/model_runner_v1.py index f919595287..c75021a702 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -49,7 +49,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors -from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LazyLoader, cdiv) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) @@ -169,6 +170,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): else: self.chunked_prefill_enabled = True + if self.cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + self.is_multimodal_model = self.model_config.is_multimodal_model if self.is_multimodal_model: self.inputs_embeds = torch.zeros( @@ -1924,10 +1931,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) if self.torchair_graph_enabled: layer_kv_cache_nope = torch.zeros( kv_cache_shape[:-1] + @@ -1951,9 +1965,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: acl_format), ) else: - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) + kv_caches[layer_name] = torch.zeros( + kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device) kv_caches[layer_name] = \ torch_npu.npu_format_cast(kv_caches[layer_name], acl_format) else:
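Usage note: a minimal sketch of how the new `kv_cache_dtype` additional-config option is meant to be consumed, based on the `additional_config.md` and `platform.py` changes above. The model path and the `quantization="ascend"` argument are illustrative assumptions (an already Ascend-quantized checkpoint and the Ascend quantization backend); only the `additional_config={"kv_cache_dtype": "int8"}` key comes from this patch, and only `int8` is currently supported.

```python
# Hypothetical example: serve an already W8A8C8-quantized checkpoint with the
# int8 KV cache enabled through vllm-ascend's additional_config.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/pangu-moe-w8a8c8",            # assumed: Ascend-quantized checkpoint
    quantization="ascend",                        # assumed: Ascend quantization backend
    additional_config={"kv_cache_dtype": "int8"}, # option added by this patch
)

outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```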
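For reviewers unfamiliar with the C8 path: the sketch below illustrates the per-tensor int8 round trip that the `key_antiquant_scale` / `value_antiquant_scale` parameters stand for. It is plain-torch pseudocode of the usual quantization arithmetic, not the actual kernels; the real path goes through `torch_npu.npu_quantize` and `npu_incre_flash_attention`, whose exact scale convention is defined by torch_npu, and the tensor shapes and scale choice here are made up.

```python
# Illustrative only: per-tensor int8 quantize / "antiquant" round trip on CPU.
import torch

def int8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Map float values to int8 using a single scale for the whole tensor.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

def int8_antiquant(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Recover an approximation of the original values from the int8 cache.
    return q.to(torch.float16) * scale

key = torch.randn(4, 8, dtype=torch.float16)  # (tokens, num_kv_heads * head_size)
scale = key.abs().max().to(torch.float16) / 127
q_key = int8_quant(key, scale)
error = (int8_antiquant(q_key, scale) - key).abs().max()
print(q_key.dtype, error)  # torch.int8, small reconstruction error
```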