@@ -722,8 +722,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=topk_ids.numel(),
             group_size=block_shape[1],
-            num_experts=B.shape[0],
-            bit=4 if use_int4_w4a16 else 8)
+            num_experts=B.shape[0])
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
@@ -885,13 +884,19 @@ def get_moe_wna16_block_config(config: Dict[str,
                                num_experts: int, group_size: int,
                                real_top_k: int, block_size_m: int):
     if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
+        # optimal block config is already set
         return {}
     if not use_moe_wna16_cuda:
+        # triton moe wna16 kernel
         if num_valid_tokens // real_top_k == 1:
+            # if bs=1, use a smaller BLOCK_SIZE_N
             return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
         else:
             return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
     else:
+        # cuda moe wna16 kernel
+        # set default block sizes to 128, and increase them when num_blocks
+        # is too large.
         block_size_n = 128
         block_size_k = 128
         if block_size_k <= group_size:
@@ -922,15 +927,18 @@ def get_moe_wna16_block_config(config: Dict[str,
             num_blocks = num_blocks // 2
 
         if size_n <= 1024 and num_blocks >= 1024:
+            # The kernel performance got much better with BLOCK_SIZE_N=1024
+            # when num_blocks is large, even when N is small.
+            # Not sure why; maybe it forces each CUDA SM to process only one
+            # block at a time.
             block_size_n = 1024
 
         return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
 
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
-                              num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 8
+                              num_experts: int):
+    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8
 
 
 def get_default_config(
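
The heuristic above no longer checks the weight bit width. A minimal standalone sketch of the updated check, with hypothetical token and expert counts chosen only to illustrate the two branches:

```python
# Illustrative sketch of the simplified heuristic from this diff; the sample
# inputs below are hypothetical, not taken from the kernel code.
def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int) -> bool:
    # Use the CUDA kernel for small per-expert batches with a supported
    # group size, for both 4-bit and 8-bit weights.
    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8


# 64 routed tokens over 8 experts -> 8 tokens per expert -> CUDA path
print(should_moe_wna16_use_cuda(num_valid_tokens=64, group_size=128,
                                num_experts=8))    # True
# 1024 routed tokens over 8 experts -> 128 tokens per expert -> Triton path
print(should_moe_wna16_use_cuda(num_valid_tokens=1024, group_size=128,
                                num_experts=8))    # False
```
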
@@ -958,9 +966,8 @@ def get_default_config(
         # moe wna16 kernels
         # only set BLOCK_SIZE_M
         # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
-        bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
-                                                       block_shape[1], E, bit)
+                                                       block_shape[1], E)
         if use_moe_wna16_cuda:
             config = {"BLOCK_SIZE_M": min(16, M)}
         elif M <= 20:
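
For reference, a minimal sketch of the Triton-path block-size choice encoded in get_moe_wna16_block_config above; the shapes are hypothetical, and the CUDA branch with its 128 defaults and num_blocks-based adjustments is not reproduced here:

```python
# Sketch of the Triton-path branch of get_moe_wna16_block_config; inputs are
# hypothetical and the CUDA sizing logic is omitted.
from typing import Dict


def pick_triton_block_config(num_valid_tokens: int,
                             real_top_k: int) -> Dict[str, int]:
    if num_valid_tokens // real_top_k == 1:
        # batch size 1: use a smaller BLOCK_SIZE_N
        return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
    return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}


# batch size 1 with top-k 2 -> 2 valid tokens -> smaller BLOCK_SIZE_N
print(pick_triton_block_config(num_valid_tokens=2, real_top_k=2))
# batch size 16 with top-k 2 -> 32 valid tokens -> larger BLOCK_SIZE_N
print(pick_triton_block_config(num_valid_tokens=32, real_top_k=2))
```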