@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }

-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {

     using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
   2) This implementation assumes k is small, but will work for any k.
 */

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail

-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         token_expert_indices, num_tokens, topk, 0, num_experts,         \
         stream);

+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-    vllm::moe::topkGatingSoftmaxKernelLauncher(
-        gating_output.data_ptr<float>(),
-        topk_weights.data_ptr<float>(),
-        topk_indices.data_ptr<int>(),
-        token_expert_indices.data_ptr<int>(),
-        softmax_workspace.data_ptr<float>(),
-        num_tokens,
-        num_experts,
-        topk,
-        stream);
+
+    if (topk_indices.scalar_type() == at::ScalarType::Int)
+    {
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
+    else
+    {
+        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<uint32_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
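
The change above is mechanical: every kernel and launcher that previously hard-coded int* index buffers is now templated on IndType, and the host-side topk_softmax branches on the dtype of topk_indices (at::ScalarType::Int vs. at::ScalarType::UInt32) to pick the instantiation. Below is a minimal, stand-alone sketch of the same template-plus-runtime-dispatch pattern; the names DType, launch_topk and run_topk are hypothetical and only illustrate the shape of the change, not the vLLM code itself.

#include <cassert>
#include <cstdint>
#include <cstdio>

enum class DType { Int32, UInt32 };

// Templated "launcher": one body is instantiated per index type, just as
// moeTopK / topkGatingSoftmax are instantiated on IndType in the diff above.
template <typename IndType>
void launch_topk(const float* scores, IndType* indices, int n, int k)
{
    // A real launcher would start a CUDA kernel here; this stub only shows
    // that the same code serves both int32_t and uint32_t index buffers.
    for (int i = 0; i < k && i < n; ++i)
        indices[i] = static_cast<IndType>(i);
    (void) scores;
}

// Runtime dtype dispatch, mirroring the if/else added to topk_softmax.
void run_topk(const float* scores, void* indices, DType idx_dtype, int n, int k)
{
    if (idx_dtype == DType::Int32)
        launch_topk(scores, static_cast<int32_t*>(indices), n, k);
    else
    {
        assert(idx_dtype == DType::UInt32);
        launch_topk(scores, static_cast<uint32_t*>(indices), n, k);
    }
}

int main()
{
    float scores[4] = {0.1f, 0.4f, 0.3f, 0.2f};
    uint32_t idx[2] = {0, 0};
    run_topk(scores, idx, DType::UInt32, 4, 2);
    std::printf("first index: %u\n", idx[0]);
    return 0;
}

In the actual file, the branch in topk_softmax plays the role of run_topk, with data_ptr<int>() / data_ptr<uint32_t>() supplying the typed pointers to the templated launcher.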