Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vllm_ascend/ops/moe/experts_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def _select_experts_with_fusion_ops(
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
if is_deepseek_v3_r1:
is_kimi = global_num_experts - global_redundant_expert_num == 384
# NOTE: now npu_moe_gating_top_k can support `group_count=256` pattern, and `group_count=384` pattern in cann8.3
if is_deepseek_v3_r1 or (is_kimi and torch.version.cann.startswith("8.3")):
Comment on lines 182 to +185
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic that identifies the deepseek_v3_r1 and kimi models via magic numbers (256, 384), together with the CANN 8.3 version check, is duplicated across multiple files (experts_selector.py, torchair_fused_moe.py, torchair_w8a8_dynamic.py, and torchair_w4a8_dynamic.py). This duplication makes the code harder to maintain and increases the risk of inconsistency when adding support for new models or CANN versions. Consider centralizing this logic in a helper function or a configuration object for better maintainability and readability.

topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
Expand Down
6 changes: 4 additions & 2 deletions vllm_ascend/torchair/ops/torchair_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,10 @@ def apply(
global_redundant_expert_num = get_ascend_config(
).init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
is_kimi = global_num_experts - global_redundant_expert_num == 384
# NOTE: now npu_moe_gating_top_k can support `group_count=256` pattern, and `group_count=384` pattern in cann8.3
if is_deepseek_v3_r1 or (is_kimi
and torch.version.cann.startswith("8.3")):
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
Expand Down
4 changes: 3 additions & 1 deletion vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def apply(
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

if global_num_experts == 256:
# NOTE: now npu_moe_gating_top_k can support `group_count=256` pattern, and `group_count=384` pattern in cann8.3
if global_num_experts == 256 or (global_num_experts == 384 and
torch.version.cann.startswith("8.3")):
Comment on lines +326 to +327
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is an inconsistency in how the model type is determined here. This file checks global_num_experts directly, whereas the other files in this PR (e.g., torchair_w8a8_dynamic.py) check the effective number of experts (global_num_experts - global_redundant_expert_num). If global_redundant_expert_num is non-zero, the two checks disagree and the wrong kernel path may be selected — a bug. The logic should be consistent across all files; the original `if global_num_experts == 256:` check was likely incorrect for the same reason.

Suggested change
if global_num_experts == 256 or (global_num_experts == 384 and
torch.version.cann.startswith("8.3")):
if (global_num_experts - global_redundant_expert_num == 256) or \
((global_num_experts - global_redundant_expert_num == 384) and torch.version.cann.startswith("8.3")):

topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
Expand Down
6 changes: 4 additions & 2 deletions vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ def apply(
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
is_kimi = global_num_experts - global_redundant_expert_num == 384

fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
Expand All @@ -948,8 +949,9 @@ def apply(
with super_kernel(prefix,
"stream-fusion=1",
enabled=running_in_super_kernel):
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
# NOTE: now npu_moe_gating_top_k can support `group_count=256` pattern, and `group_count=384` pattern in cann8.3
if is_deepseek_v3_r1 or (is_kimi
and torch.version.cann.startswith("8.3")):
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
Expand Down
Loading