3636
3737VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
3838USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
39+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
3940
4041
4142def fused_experts_with_mc2 (
@@ -50,6 +51,14 @@ def fused_experts_with_mc2(
5051) -> torch .Tensor :
5152 global_bs = 0
5253 moe_expert_num = len (expert_map )
54+
55+ rank = torch .distributed .get_rank ()
56+ if VLLM_ENABLE_FIX_ROUTE :
57+ step = hidden_states .shape [0 ] * top_k
58+ uniform_topk_list = [(i + rank ) % moe_expert_num
59+ for i in range (rank * step , (rank + 1 ) * step )]
60+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (
61+ hidden_states .shape [0 ], - 1 ).to (hidden_states .device )
5362 kwargs = {
5463 "x" : hidden_states ,
5564 "expert_ids" : topk_ids ,
@@ -59,8 +68,6 @@ def fused_experts_with_mc2(
5968 "global_bs" : global_bs ,
6069 }
6170
62- rank = torch .distributed .get_rank ()
63-
6471 quant_mode = 0
6572 ep_group = get_ep_group ().device_group
6673 local_rank = torch .distributed .get_rank (group = ep_group )
@@ -88,15 +95,20 @@ def fused_experts_with_mc2(
8895 0 :5 ]
8996
9097 w1 = w1 .transpose (1 , 2 )
91- expert_token_nums = torch .cumsum (expert_token_nums ,
92- dim = 0 ,
93- dtype = torch .int64 )
94- group_list = expert_token_nums .to (torch .int64 )
98+
99+ if VLLM_ENABLE_FIX_ROUTE :
100+ uniform_group_list = hidden_states .shape [0 ] * \
101+ all_to_all_group_size * top_k // moe_expert_num
102+ group_list = torch .Tensor ([uniform_group_list ] *
103+ w1 .shape [0 ]).long ().to (hidden_states .device )
104+ else :
105+ group_list = expert_token_nums
95106 gate_up_out_list = torch_npu .npu_grouped_matmul (
96107 x = [expand_x ],
97108 weight = [w1 ],
98109 split_item = 2 ,
99- group_list_type = 0 ,
110+ # group_list_type=1: group_list holds per-expert token counts rather than cumulative sums, so no cumsum is needed
111+ group_list_type = 1 ,
100112 group_type = 0 ,
101113 group_list = group_list ,
102114 )
@@ -110,7 +122,7 @@ def fused_experts_with_mc2(
110122 x = [gate_up_out ],
111123 weight = [w2 ],
112124 split_item = 2 ,
113- group_list_type = 0 ,
125+ group_list_type = 1 ,
114126 group_type = 0 ,
115127 group_list = group_list ,
116128 )
0 commit comments