@@ -207,11 +207,9 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -221,9 +219,8 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    down_out_list = torch.cat(down_out_list, dim=0)
 
     # moeCombine
     kwargs_mc2 = {
@@ -314,9 +311,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     hidden_states = torch_npu.npu_swiglu(hidden_states)
 
     w2 = w2.transpose(1, 2)
@@ -327,9 +323,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     return hidden_states
 
 
@@ -419,23 +414,20 @@ def fused_experts_with_all2all(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    hidden_states = torch.cat(gate_up_out_list, dim=0)
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(
+    hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    hidden_states = torch.cat(down_out_list, dim=0)
 
     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -825,11 +817,9 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -839,9 +829,8 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    down_out_list = torch.cat(down_out_list, dim=0)
 
     if expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
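
Every hunk applies the same pattern: with `split_item=2`, `torch_npu.npu_grouped_matmul` already merges the per-expert outputs along dim 0 and returns them as the single element of its result list, so indexing `[0]` replaces the follow-up `torch.cat(..., dim=0)`. A minimal sketch of that pattern, assuming an Ascend NPU with `torch_npu` available; the helper name `grouped_mlp` and the tensor shapes are illustrative, not part of this patch:

    import torch
    import torch_npu  # registers the NPU backend

    def grouped_mlp(hidden_states: torch.Tensor, w1: torch.Tensor,
                    w2: torch.Tensor, group_list: torch.Tensor) -> torch.Tensor:
        # split_item=2: outputs of all expert groups come back merged into one
        # tensor, wrapped in a one-element list, hence the trailing [0].
        gate_up_out = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            split_item=2,
            group_list_type=0,
            group_type=0,
            group_list=group_list,
        )[0]
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)

        # Same pattern for the down projection; no torch.cat needed afterwards.
        down_out = torch_npu.npu_grouped_matmul(
            x=[gate_up_out],
            weight=[w2.transpose(1, 2)],
            split_item=2,
            group_list_type=0,
            group_type=0,
            group_list=group_list,
        )[0]
        return down_out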