Skip to content

Commit d256b05

Browse files
committed
remove torch.cat and replace it with TensorList[0]
Signed-off-by: huangxialu <huangxialu1@huawei.com>
1 parent 4b3a210 commit d256b05

File tree

1 file changed

+12
-23
lines changed

1 file changed

+12
-23
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 12 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -207,11 +207,9 @@ def fused_experts_with_mc2(
207207
group_list_type=1,
208208
group_type=0,
209209
group_list=group_list,
210-
)
210+
)[0]
211211

212-
# TODO: Remove this in the future.
213-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
214-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
212+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
215213

216214
w2 = w2.transpose(1, 2)
217215
down_out_list = torch_npu.npu_grouped_matmul(
@@ -221,9 +219,8 @@ def fused_experts_with_mc2(
221219
group_list_type=1,
222220
group_type=0,
223221
group_list=group_list,
224-
)
222+
)[0]
225223

226-
down_out_list = torch.cat(down_out_list, dim=0)
227224

228225
# moeCombine
229226
kwargs_mc2 = {
@@ -314,9 +311,8 @@ def apply_mlp(
314311
group_list_type=group_list_type,
315312
group_type=0,
316313
group_list=group_list,
317-
)
314+
)[0]
318315

319-
hidden_states = torch.cat(hidden_states, dim=0)
320316
hidden_states = torch_npu.npu_swiglu(hidden_states)
321317

322318
w2 = w2.transpose(1, 2)
@@ -327,9 +323,8 @@ def apply_mlp(
327323
group_list_type=group_list_type,
328324
group_type=0,
329325
group_list=group_list,
330-
)
326+
)[0]
331327

332-
hidden_states = torch.cat(hidden_states, dim=0)
333328
return hidden_states
334329

335330

@@ -419,23 +414,20 @@ def fused_experts_with_all2all(
419414
group_list_type=0,
420415
group_type=0,
421416
group_list=expert_tokens,
422-
)
417+
)[0]
423418

424-
# TODO: Remove this in the future.
425-
hidden_states = torch.cat(gate_up_out_list, dim=0)
426-
hidden_states = torch_npu.npu_swiglu(hidden_states)
419+
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
427420

428421
w2 = w2.transpose(1, 2)
429-
down_out_list = torch_npu.npu_grouped_matmul(
422+
hidden_states = torch_npu.npu_grouped_matmul(
430423
x=[hidden_states],
431424
weight=[w2],
432425
split_item=2,
433426
group_list_type=0,
434427
group_type=0,
435428
group_list=expert_tokens,
436-
)
429+
)[0]
437430

438-
hidden_states = torch.cat(down_out_list, dim=0)
439431

440432
if expert_map is not None:
441433
resorted_idx = torch.argsort(sorted_idx)
@@ -825,11 +817,9 @@ def fused_experts(
825817
group_list_type=0,
826818
group_type=0,
827819
group_list=expert_tokens,
828-
)
820+
)[0]
829821

830-
# TODO: Remove this in the future.
831-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
832-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
822+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
833823

834824
w2 = w2.transpose(1, 2)
835825
down_out_list = torch_npu.npu_grouped_matmul(
@@ -839,9 +829,8 @@ def fused_experts(
839829
group_list_type=0,
840830
group_type=0,
841831
group_list=expert_tokens,
842-
)
832+
)[0]
843833

844-
down_out_list = torch.cat(down_out_list, dim=0)
845834

846835
if expert_map is not None:
847836
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

0 commit comments

Comments (0)