4444
4545VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
4646USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
47+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
4748
4849
4950def fused_experts_with_mc2 (
@@ -58,6 +59,14 @@ def fused_experts_with_mc2(
5859) -> torch .Tensor :
5960 global_bs = 0
6061 moe_expert_num = len (expert_map )
62+
63+ rank = torch .distributed .get_rank ()
64+ if VLLM_ENABLE_FIX_ROUTE :
65+ step = hidden_states .shape [0 ] * top_k
66+ uniform_topk_list = [(i + rank ) % moe_expert_num
67+ for i in range (rank * step , (rank + 1 ) * step )]
68+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (
69+ hidden_states .shape [0 ], - 1 ).to (hidden_states .device )
6170 kwargs = {
6271 "x" : hidden_states ,
6372 "expert_ids" : topk_ids ,
@@ -67,8 +76,6 @@ def fused_experts_with_mc2(
6776 "global_bs" : global_bs ,
6877 }
6978
70- rank = torch .distributed .get_rank ()
71-
7279 quant_mode = 0
7380 ep_group = get_ep_group ().device_group
7481 local_rank = torch .distributed .get_rank (group = ep_group )
@@ -97,15 +104,20 @@ def fused_experts_with_mc2(
97104 0 :5 ]
98105
99106 w1 = w1 .transpose (1 , 2 )
100- expert_token_nums = torch .cumsum (expert_token_nums ,
101- dim = 0 ,
102- dtype = torch .int64 )
103- group_list = expert_token_nums .to (torch .int64 )
107+
108+ if VLLM_ENABLE_FIX_ROUTE :
109+ uniform_group_list = hidden_states .shape [0 ] * \
110+ all_to_all_group_size * top_k // moe_expert_num
111+ group_list = torch .Tensor ([uniform_group_list ] *
112+ w1 .shape [0 ]).long ().to (hidden_states .device )
113+ else :
114+ group_list = expert_token_nums
104115 gate_up_out_list = torch_npu .npu_grouped_matmul (
105116 x = [expand_x ],
106117 weight = [w1 ],
107118 split_item = 2 ,
108- group_list_type = 0 ,
119+ # group_list_type=1 is count mode: group_list holds per-group token counts, so no cumulative sum of the group list is needed
120+ group_list_type = 1 ,
109121 group_type = 0 ,
110122 group_list = group_list ,
111123 )
@@ -119,7 +131,7 @@ def fused_experts_with_mc2(
119131 x = [gate_up_out ],
120132 weight = [w2 ],
121133 split_item = 2 ,
122- group_list_type = 0 ,
134+ group_list_type = 1 ,
123135 group_type = 0 ,
124136 group_list = group_list ,
125137 )
0 commit comments