
Commit a883558

hahazhky and 洪炜杰 authored
[Kernel] Remove cumsum in groupedmatmul (vllm-project#987)
### What this PR does / why we need it

Remove the cumsum operator in MoE to improve performance.

### How was this patch tested?

It should be tested on a case with the mc2 operator and graph mode enabled.

Signed-off-by: zhky <hahazhky@163.com>
Co-authored-by: 洪炜杰 <hongweijie1@huawei.com>
1 parent f75c52f commit a883558

File tree

1 file changed: +4 -5 lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 4 additions & 5 deletions
@@ -88,15 +88,14 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         0:5]

     w1 = w1.transpose(1, 2)
-    expert_token_nums = torch.cumsum(expert_token_nums,
-                                     dim=0,
-                                     dtype=torch.int64)
+
     group_list = expert_token_nums.to(torch.int64)
     gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[expand_x],
         weight=[w1],
         split_item=2,
-        group_list_type=0,
+        # 1 means count mode, to avoid cumulative operation of the group list
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
@@ -110,7 +109,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         x=[gate_up_out],
         weight=[w2],
         split_item=2,
-        group_list_type=0,
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
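For readers unfamiliar with the two group-list modes, here is a minimal sketch (not part of the patch) of what the change amounts to, assuming the semantics stated in the diff comment: with `group_list_type=0` the kernel expects a cumulative (prefix-sum) group list, so the old code ran `torch.cumsum` first, whereas with `group_list_type=1` the per-expert token counts are passed directly and the cumsum is skipped. The tensor values below are hypothetical.

```python
# Illustrative only; assumes group_list_type semantics as described in the
# diff comment (0 = cumulative group list, 1 = per-group token counts).
import torch

# Hypothetical per-expert token counts produced by the MC2 dispatch.
expert_token_nums = torch.tensor([3, 0, 5, 2])

# Old path: build a cumulative group list before npu_grouped_matmul.
cumulative_group_list = torch.cumsum(expert_token_nums, dim=0, dtype=torch.int64)
# -> tensor([ 3,  3,  8, 10])

# New path: pass the counts directly; group_list_type=1 tells the kernel
# to interpret them as counts, so the extra cumsum op is dropped.
count_group_list = expert_token_nums.to(torch.int64)
# -> tensor([3, 0, 5, 2])

# Both forms describe the same partitioning of tokens across experts:
assert torch.equal(torch.cumsum(count_group_list, dim=0), cumulative_group_list)
```

If the kernel indeed accepts counts directly, the explicit cumsum per MoE layer becomes unnecessary, which is the performance saving the description refers to.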
