@@ -226,16 +226,17 @@ def apply(
         input_shape = x.shape
         x = x.view(-1, x.shape[-1])
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                                                             router_logits=router_logits,
-                                                             use_grouped_topk=use_grouped_topk,
-                                                             top_k=top_k,
-                                                             renormalize=renormalize,
-                                                             topk_group=topk_group,
-                                                             num_expert_group=num_expert_group,
-                                                             custom_routing_function=custom_routing_function,
-                                                             scoring_func=scoring_func,
-                                                             e_score_correction_bias=e_score_correction_bias)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
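The hunk above widens `select_experts` from a two-tuple to a three-tuple return. A minimal, self-contained sketch of the softmax-then-top-k routing in the fallback branch, extended to the new contract; `naive_select_experts` and the `None` third slot are illustrative assumptions, not vLLM's actual kernel:

```python
import torch
import torch.nn.functional as F

def naive_select_experts(router_logits: torch.Tensor, top_k: int,
                         renormalize: bool = True):
    # Softmax over experts, then keep the top-k weights per token,
    # mirroring the fallback path in the diff above.
    scores = F.softmax(router_logits, dim=1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Third slot mirrors the new three-value contract; None is an assumed
    # placeholder for models without a zero-expert result.
    return topk_weights, topk_ids, None

logits = torch.randn(4, 8)  # 4 tokens routed over 8 experts
weights, ids, zero_expert_result = naive_select_experts(logits, top_k=2)
```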
@@ -663,18 +664,19 @@ def apply(
         x = x.view(-1, x.shape[-1])
 
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                                                             router_logits=router_logits,
-                                                             use_grouped_topk=use_grouped_topk,
-                                                             top_k=top_k,
-                                                             renormalize=renormalize,
-                                                             topk_group=topk_group,
-                                                             num_expert_group=num_expert_group,
-                                                             custom_routing_function=custom_routing_function,
-                                                             scoring_func=scoring_func,
-                                                             routed_scaling_factor=routed_scaling_factor,
-                                                             e_score_correction_bias=e_score_correction_bias,
-                                                             indices_type=self.topk_indices_dtype)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                routed_scaling_factor=routed_scaling_factor,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
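Because the return arity changed, every call site that unpacks the result has to be migrated in lockstep, which is why both `apply` methods are touched here. A toy stand-in (not vLLM's API) showing the failure mode a stale caller would hit, and the migrated pattern:

```python
def select_experts():
    # Stand-in for the updated function: now returns three values.
    return "weights", "ids", None

try:
    topk_weights, topk_ids = select_experts()  # stale two-value unpack
except ValueError as exc:
    print(f"stale call site: {exc}")  # "too many values to unpack"

# Migrated unpacking; bind to `_` instead if the third value is unused.
topk_weights, topk_ids, zero_expert_result = select_experts()
```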