4444
4545VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
4646USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
47+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
4748
4849
4950def fused_experts_with_mc2 (
@@ -58,6 +59,14 @@ def fused_experts_with_mc2(
5859) -> torch .Tensor :
5960 global_bs = 0
6061 moe_expert_num = len (expert_map )
62+
63+ rank = torch .distributed .get_rank ()
64+ if VLLM_ENABLE_FIX_ROUTE :
65+ step = hidden_states .shape [0 ] * top_k
66+ uniform_topk_list = [
67+ (i + rank ) % moe_expert_num for i in range (rank * step , (rank + 1 ) * step )
68+ ]
69+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (hidden_states .shape [0 ], - 1 ).npu ()
6170 kwargs = {
6271 "x" : hidden_states ,
6372 "expert_ids" : topk_ids ,
@@ -67,8 +76,6 @@ def fused_experts_with_mc2(
6776 "global_bs" : global_bs ,
6877 }
6978
70- rank = torch .distributed .get_rank ()
71-
7279 quant_mode = 0
7380 ep_group = get_ep_group ().device_group
7481 local_rank = torch .distributed .get_rank (group = ep_group )
@@ -97,15 +104,17 @@ def fused_experts_with_mc2(
97104 0 :5 ]
98105
99106 w1 = w1 .transpose (1 , 2 )
100- expert_token_nums = torch .cumsum (expert_token_nums ,
101- dim = 0 ,
102- dtype = torch .int64 )
103- group_list = expert_token_nums .to (torch .int64 )
107+
108+ if VLLM_ENABLE_FIX_ROUTE :
109+ uniform_group_list = hidden_states .shape [0 ] * all_to_all_group_size * top_k // moe_expert_num
110+ group_list = torch .Tensor ([uniform_group_list ] * w1 .shape [0 ]).long ().npu ()
111+ else :
112+ group_list = expert_token_nums
104113 gate_up_out_list = torch_npu .npu_grouped_matmul (
105114 x = [expand_x ],
106115 weight = [w1 ],
107116 split_item = 2 ,
108- group_list_type = 0 ,
117+ group_list_type = 1 ,
109118 group_type = 0 ,
110119 group_list = group_list ,
111120 )
@@ -119,7 +128,7 @@ def fused_experts_with_mc2(
119128 x = [gate_up_out ],
120129 weight = [w2 ],
121130 split_item = 2 ,
122- group_list_type = 0 ,
131+ group_list_type = 1 ,
123132 group_type = 0 ,
124133 group_list = group_list ,
125134 )
0 commit comments