Skip to content

Commit 5efefb7

Browse files
committed
remove torch.cat and replace it with TensorList[0]
Signed-off-by: huangxialu <huangxialu1@huawei.com>
1 parent f3b50c5 commit 5efefb7

File tree

1 file changed

+12
-25
lines changed

1 file changed

+12
-25
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,10 @@ def fused_experts_with_mc2(
204204
group_list_type=1,
205205
group_type=0,
206206
group_list=group_list,
207-
)
207+
)[0]
208208

209209
# TODO: Remove this in the future.
210-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
211-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
210+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
212211

213212
w2 = w2.transpose(1, 2)
214213
down_out_list = torch_npu.npu_grouped_matmul(
@@ -218,9 +217,7 @@ def fused_experts_with_mc2(
218217
group_list_type=1,
219218
group_type=0,
220219
group_list=group_list,
221-
)
222-
223-
down_out_list = torch.cat(down_out_list, dim=0)
220+
)[0]
224221

225222
# moeCombine
226223
kwargs_mc2 = {
@@ -311,9 +308,8 @@ def apply_mlp(
311308
group_list_type=group_list_type,
312309
group_type=0,
313310
group_list=group_list,
314-
)
311+
)[0]
315312

316-
hidden_states = torch.cat(hidden_states, dim=0)
317313
hidden_states = torch_npu.npu_swiglu(hidden_states)
318314

319315
w2 = w2.transpose(1, 2)
@@ -324,9 +320,8 @@ def apply_mlp(
324320
group_list_type=group_list_type,
325321
group_type=0,
326322
group_list=group_list,
327-
)
323+
)[0]
328324

329-
hidden_states = torch.cat(hidden_states, dim=0)
330325
return hidden_states
331326

332327

@@ -416,23 +411,19 @@ def fused_experts_with_all2all(
416411
group_list_type=0,
417412
group_type=0,
418413
group_list=expert_tokens,
419-
)
414+
)[0]
420415

421-
# TODO: Remove this in the future.
422-
hidden_states = torch.cat(gate_up_out_list, dim=0)
423-
hidden_states = torch_npu.npu_swiglu(hidden_states)
416+
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
424417

425418
w2 = w2.transpose(1, 2)
426-
down_out_list = torch_npu.npu_grouped_matmul(
419+
hidden_states = torch_npu.npu_grouped_matmul(
427420
x=[hidden_states],
428421
weight=[w2],
429422
split_item=2,
430423
group_list_type=0,
431424
group_type=0,
432425
group_list=expert_tokens,
433-
)
434-
435-
hidden_states = torch.cat(down_out_list, dim=0)
426+
)[0]
436427

437428
if expert_map is not None:
438429
resorted_idx = torch.argsort(sorted_idx)
@@ -822,11 +813,9 @@ def fused_experts(
822813
group_list_type=0,
823814
group_type=0,
824815
group_list=expert_tokens,
825-
)
816+
)[0]
826817

827-
# TODO: Remove this in the future.
828-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
829-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
818+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
830819

831820
w2 = w2.transpose(1, 2)
832821
down_out_list = torch_npu.npu_grouped_matmul(
@@ -836,9 +825,7 @@ def fused_experts(
836825
group_list_type=0,
837826
group_type=0,
838827
group_list=expert_tokens,
839-
)
840-
841-
down_out_list = torch.cat(down_out_list, dim=0)
828+
)[0]
842829

843830
if expert_map is not None:
844831
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

0 commit comments

Comments (0)