                            get_rm_router_logits_state, is_310p)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
-SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
 
 
 def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -859,39 +858,6 @@ def fused_experts(
     return final_hidden_states
 
 
-def select_gating_top_k_softmax_experts(
-        hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
-        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-    only supports float16, bfloat16, float32
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        renormalize: Whether to renormalize the routing weights.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-    topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
-        router_logits, None, k=top_k)
-
-    # # Required by npu_moe_init_routing
-    # topk_weights = topk_weights.to(hidden_states.dtype)
-    # topk_ids = topk_ids.to(torch.int32)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
 def native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -953,8 +919,24 @@ def select_experts(
         ValueError: If an unsupported scoring function is provided.
     """
 
+    def _renormalize_topk_weights(
+        topk_weights: torch.Tensor,
+        renormalize: bool,
+    ):
+        if renormalize:
+            topk_weights = topk_weights / topk_weights.sum(dim=-1,
+                                                           keepdim=True)
+        return topk_weights
+
     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
+        if not use_grouped_topk and custom_routing_function is None:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                x=router_logits, finished=None, k=top_k)
+            topk_ids = topk_ids.to(torch.int32)
+            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
+            return topk_weights, topk_ids
+
         topk_weights = router_logits.softmax(dim=-1)
     elif scoring_func == "sigmoid":
         topk_weights = router_logits.sigmoid()
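For reference, the fast path added in the hunk above is intended to produce the same routing result as plain softmax gating followed by a top-k selection. A minimal pure-PyTorch sketch (illustrative only; on Ascend NPUs the fused torch_npu.npu_moe_gating_top_k_softmax kernel performs the softmax and top-k selection in a single op, and the reference function name below is hypothetical):

import torch

def softmax_topk_reference(router_logits: torch.Tensor, top_k: int,
                           renormalize: bool):
    # Softmax over the expert dimension, then pick each token's top_k experts.
    weights = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = weights.topk(top_k, dim=-1)
    topk_ids = topk_ids.to(torch.int32)  # required by npu_moe_init_routing
    if renormalize:
        # Rescale so each token's selected weights sum to 1.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids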
@@ -988,10 +970,11 @@ def select_experts(
                                             k=top_k,
                                             dim=-1,
                                             sorted=False)
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-        topk_weights = topk_weights.to(hidden_states.dtype)
-    else:
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
+        return topk_weights, topk_ids
+
+    if custom_routing_function is not None:
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -1002,11 +985,12 @@ def select_experts(
         topk_ids = topk_ids.to(torch.int32)
         return topk_weights, topk_ids
 
+    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
+    topk_weights = topk_weights.to(hidden_states.dtype)
+
     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids
 
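The shared _renormalize_topk_weights helper introduced in this change only rescales the already-selected weights so that each token's top-k weights sum to 1. A quick pure-PyTorch sanity check (illustrative values, not from the source):

import torch

topk_weights = torch.tensor([[0.5, 0.3]])  # one token, top_k = 2
renormalized = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
print(renormalized)  # tensor([[0.6250, 0.3750]])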
@@ -1070,23 +1054,18 @@ def apply(
         if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
-                k=top_k,  # topk当前写8
+                k=top_k,  # topk is currently 8
                 bias=e_score_correction_bias,
                 k_group=topk_group,  # fix: 4
                 group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: group中的最大; 1: topk2.sum(fix)
+                group_select_mode=
+                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
                 renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
                 norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; 第三个输出是否输出
-                # y2_flag=False, # old api; 第三个输出是否输出
+                # out_flag=False, # todo: new api; whether to emit the third output
+                # y2_flag=False, # old api; whether to emit the third output
                 routed_scaling_factor=1,
                 eps=float(1e-20))
-        elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
-            topk_weights, topk_ids = select_gating_top_k_softmax_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                renormalize=renormalize)
         else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,