@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         )


+@triton.jit
+def compute_identity_kernel(
+    top_k: int,
+    hidden_states_ptr: tl.tensor,
+    expert_scales_ptr: tl.tensor,
+    num_tokens: int,
+    output_ptr: tl.tensor,
+    hidden_dim: int,
+    scales_stride: int,
+    BLOCK_SIZE: tl.constexpr,
+) -> None:
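+    """Sum a token's zero-expert (identity) contributions.
+
+    Each program handles one BLOCK_SIZE-wide slice of one token's hidden
+    dimension and computes h * sum(scales) over the token's top_k routing
+    slots, accumulating in float32.
+    """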
+    pid = tl.program_id(0)
+
+    # Map the 1-D grid onto (token, hidden-dim block) coordinates. Ceil-divide
+    # so a hidden_dim that is not a multiple of BLOCK_SIZE still gets a
+    # (masked) tail block.
+    num_blocks = (hidden_dim + BLOCK_SIZE - 1) // BLOCK_SIZE
+    batch_id = pid // num_blocks
+    dim_offset = (pid % num_blocks) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    offsets = dim_offset + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + offsets, mask=mask)
+
+    # Accumulate in float32: result = h * sum of the token's top_k scales.
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(output_ptr + batch_id * hidden_dim + offsets, result, mask=mask)
+
+
+def zero_experts_compute_triton(expert_indices: torch.Tensor,
+                                expert_scales: torch.Tensor, num_experts: int,
+                                zero_expert_type: str,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
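+    """Compute the output of "zero experts" (slots with id >= num_experts).
+
+    For zero_expert_type == "identity" a zero expert contributes the input
+    hidden state scaled by its routing weight. This helper sums those
+    contributions with the Triton kernel above and, as a side effect,
+    rewrites expert_indices/expert_scales in-place so the regular fused-MoE
+    path treats the zero-expert slots as no-ops.
+    """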
+    top_k = expert_indices.size(-1)
+
+    if zero_expert_type == "identity":
+        # Keep only the scales that belong to zero experts
+        # (ids >= num_experts); zero out the regular experts' scales.
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+    else:
+        raise ValueError(f"Unsupported zero_expert_type: {zero_expert_type}")
+
+    # Rewrite the zero-expert slots in-place so the regular MoE kernels see
+    # expert 0 with weight 0.0, i.e. a no-op.
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    # One program per (token, BLOCK_SIZE-wide slice of hidden_dim).
+    grid = lambda meta: (num_tokens *
+                         triton.cdiv(hidden_dim, meta['BLOCK_SIZE']), )
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
+
+
 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
 def get_config_file_name(E: int,
                          N: int,
@@ -940,6 +1010,25 @@ def fused_topk(
     return topk_weights, topk_ids, token_expert_indices


+def fused_topk_bias(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    e_score_correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
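+    """Top-k routing with a per-expert score-correction bias.
+
+    The bias is added to the softmax scores only for expert *selection*;
+    routing weights are gathered from the unbiased scores (the aux-loss-free
+    load-balancing scheme used by DeepSeek-V3-style routers).
+    hidden_states is unused here; it is accepted to match fused_topk's
+    signature.
+    """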
+    n_routed_experts = gating_output.shape[-1]
+    scores = gating_output.softmax(dim=-1)
+    scores_for_choice = scores.view(
+        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
+                              sorted=False)[1]
+    # Gather weights from the *unbiased* scores; the bias only steers which
+    # experts get picked.
+    topk_weights = scores.gather(1, topk_indices)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
+
+
 # This is used by the Deepseek-V2 and Deepseek-V3 model
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(