3636
3737VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
3838USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
39+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
3940
4041
4142def fused_experts_with_mc2 (
@@ -50,6 +51,14 @@ def fused_experts_with_mc2(
5051) -> torch .Tensor :
5152 global_bs = 0
5253 moe_expert_num = len (expert_map )
54+
55+ rank = torch .distributed .get_rank ()
56+ if VLLM_ENABLE_FIX_ROUTE :
57+ step = hidden_states .shape [0 ] * top_k
58+ uniform_topk_list = [(i + rank ) % moe_expert_num
59+ for i in range (rank * step , (rank + 1 ) * step )]
60+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (
61+ hidden_states .shape [0 ], - 1 ).to (hidden_states .device )
5362 kwargs = {
5463 "x" : hidden_states ,
5564 "expert_ids" : topk_ids ,
@@ -59,8 +68,6 @@ def fused_experts_with_mc2(
5968 "global_bs" : global_bs ,
6069 }
6170
62- rank = torch .distributed .get_rank ()
63-
6471 quant_mode = 0
6572 ep_group = get_ep_group ().device_group
6673 local_rank = torch .distributed .get_rank (group = ep_group )
@@ -89,15 +96,20 @@ def fused_experts_with_mc2(
8996 0 :5 ]
9097
9198 w1 = w1 .transpose (1 , 2 )
92- expert_token_nums = torch .cumsum (expert_token_nums ,
93- dim = 0 ,
94- dtype = torch .int64 )
95- group_list = expert_token_nums .to (torch .int64 )
99+
100+ if VLLM_ENABLE_FIX_ROUTE :
101+ uniform_group_list = hidden_states .shape [0 ] * \
102+ all_to_all_group_size * top_k // moe_expert_num
103+ group_list = torch .Tensor ([uniform_group_list ] *
104+ w1 .shape [0 ]).long ().to (hidden_states .device )
105+ else :
106+ group_list = expert_token_nums
96107 gate_up_out_list = torch_npu .npu_grouped_matmul (
97108 x = [expand_x ],
98109 weight = [w1 ],
99110 split_item = 2 ,
100- group_list_type = 0 ,
111+ # group_list_type=1 selects count mode: group_list holds per-expert token counts, so no cumulative sum of the group list is needed
112+ group_list_type = 1 ,
101113 group_type = 0 ,
102114 group_list = group_list ,
103115 )
@@ -111,7 +123,7 @@ def fused_experts_with_mc2(
111123 x = [gate_up_out ],
112124 weight = [w2 ],
113125 split_item = 2 ,
114- group_list_type = 0 ,
126+ group_list_type = 1 ,
115127 group_type = 0 ,
116128 group_list = group_list ,
117129 )
0 commit comments