From b78b7a9b658ff3781051e7c95c0f98a73ef209ff Mon Sep 17 00:00:00 2001
From: Levi-JQ
Date: Tue, 2 Dec 2025 14:13:49 +0800
Subject: [PATCH] optimization of kimi-k2

Signed-off-by: Levi-JQ
---
 .../torchair/ops/torchair_fused_moe.py        | 46 ++++++-------
 .../quantization/torchair_w4a8_dynamic.py     | 43 ++++++-----------
 .../quantization/torchair_w8a8_dynamic.py     | 46 ++++++-------
 3 files changed, 42 insertions(+), 93 deletions(-)

diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py
index 7a14ca53684..c07700ca74f 100644
--- a/vllm_ascend/torchair/ops/torchair_fused_moe.py
+++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -857,38 +857,20 @@ def apply(
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        global_redundant_expert_num = get_ascend_config(
-        ).init_redundancy_expert
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
 
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
index c61ddf322c2..b8cc36586d5 100644
--- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
@@ -27,7 +27,6 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
     torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
@@ -322,34 +321,20 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
 
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
index bc0a8d35783..ebfae2cf467 100644
--- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -25,7 +25,6 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
                                         super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
@@ -938,8 +937,6 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-
         fused_moe_state = get_forward_context().fused_moe_state
         if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
             fused_moe_state = FusedMoEState.All2All
@@ -948,35 +945,20 @@
         with super_kernel(prefix,
                           "stream-fusion=1",
                           enabled=running_in_super_kernel):
-            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-            if is_deepseek_v3_r1:
-                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                    router_logits,
-                    k=top_k,  # topk currently is 8
-                    bias=e_score_correction_bias,
-                    k_group=topk_group,  # fix: 4
-                    group_count=num_expert_group,  # fix 8
-                    group_select_mode=
-                    1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                    # out_flag=False, # todo new api; should the third output be output
-                    # y2_flag=False, # old api; should the third output be output
-                    routed_scaling_factor=1,
-                    eps=float(1e-20))
-            else:
-                topk_weights, topk_ids = torchair_select_experts(
-                    hidden_states=x,
-                    router_logits=router_logits,
-                    top_k=top_k,
-                    use_grouped_topk=use_grouped_topk,
-                    renormalize=renormalize,
-                    topk_group=topk_group,
-                    num_expert_group=num_expert_group,
-                    custom_routing_function=custom_routing_function,
-                    scoring_func=scoring_func,
-                    e_score_correction_bias=e_score_correction_bias,
-                )
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,  # topk currently is 8
+                bias=e_score_correction_bias,
+                k_group=topk_group,  # fix: 4
+                group_count=num_expert_group,  # fix 8
+                group_select_mode=
+                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                # out_flag=False, # todo new api; should the third output be output
+                # y2_flag=False, # old api; should the third output be output
+                routed_scaling_factor=1,
+                eps=float(1e-20))
 
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             with npu_stream_switch("moe_secondary", 0):
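
For context on what the consolidated call computes, the sketch below is a plain-PyTorch approximation of the grouped top-k routing that torch_npu.npu_moe_gating_top_k performs with the arguments used in these hunks: norm_type=1 (sigmoid scoring), group_select_mode=1 (each group scored by the sum of its top-2 experts), k_group=topk_group groups kept, bias=e_score_correction_bias used for selection only, routed_scaling_factor=1. This is a reference sketch under those assumptions, not the NPU kernel itself; the function name grouped_topk_reference, its defaults, and the DeepSeek-V3-style bias convention are illustrative assumptions.

import torch


def grouped_topk_reference(
        router_logits: torch.Tensor,             # [num_tokens, num_experts]
        e_score_correction_bias: torch.Tensor,   # [num_experts]
        top_k: int = 8,
        num_expert_group: int = 8,
        topk_group: int = 4,
        routed_scaling_factor: float = 1.0,
        eps: float = 1e-20):
    num_tokens, num_experts = router_logits.shape
    # norm_type=1: score experts with a sigmoid instead of a softmax.
    scores = torch.sigmoid(router_logits.float())
    # Assumed convention: the correction bias only influences which experts
    # are selected, not the routing weights themselves.
    biased = scores + e_score_correction_bias.float()

    # group_select_mode=1: score each group by the sum of its top-2 experts,
    # then keep the best `topk_group` groups per token.
    group_scores = (biased.view(num_tokens, num_expert_group,
                                -1).topk(2, dim=-1).values.sum(dim=-1))
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, -1)

    # Per-token top-k experts, restricted to the selected groups.
    topk_ids = (biased.masked_fill(expert_mask == 0,
                                   float("-inf")).topk(top_k, dim=-1).indices)

    # Weights come from the unbiased scores, normalized and scaled.
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + eps)
    return topk_weights * routed_scaling_factor, topk_ids


if __name__ == "__main__":
    logits = torch.randn(4, 256)
    bias = torch.zeros(256)
    weights, ids = grouped_topk_reference(logits, bias)
    print(weights.shape, ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])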