@@ -204,11 +204,10 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
     # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -218,9 +217,7 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     # moeCombine
     kwargs_mc2 = {
@@ -311,9 +308,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     hidden_states = torch_npu.npu_swiglu(hidden_states)
 
     w2 = w2.transpose(1, 2)
@@ -324,9 +320,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     return hidden_states
 
 
@@ -416,23 +411,19 @@ def fused_experts_with_all2all(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    hidden_states = torch.cat(gate_up_out_list, dim=0)
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(
+    hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    hidden_states = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -822,11 +813,9 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -836,9 +825,7 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
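Every hunk above applies the same pattern: torch_npu.npu_grouped_matmul is assumed, when called with split_item=2, to return a list holding a single already-concatenated tensor, so indexing the result with [0] makes the manual torch.cat(..., dim=0) redundant. A minimal sketch of the before/after shape of the call, using placeholder names (x, w, group_list) rather than the file's real variables:

    # before: outputs were concatenated by hand
    out_list = torch_npu.npu_grouped_matmul(
        x=[x], weight=[w], split_item=2,
        group_list_type=1, group_type=0, group_list=group_list,
    )
    out = torch.cat(out_list, dim=0)

    # after: split_item=2 is assumed to yield [single_tensor],
    # so the first element is taken directly
    out = torch_npu.npu_grouped_matmul(
        x=[x], weight=[w], split_item=2,
        group_list_type=1, group_type=0, group_list=group_list,
    )[0]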