@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }
 
-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {
 
     using cub_kvp = cub::KeyValuePair<int, float>;
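The hunk above only re-types the `indices` pointer; the reduction built on `cub_kvp` is unchanged. For readers without the full file, this is roughly what the kernel body does — a paraphrased sketch, not the verbatim source, omitting the masking of experts already picked in earlier iterations and the `finished`/expert-range handling:

```cpp
// Paraphrased sketch of moeTopK's reduction: one CTA per token, one block-wide
// arg-max per top-k slot.
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
cub::ArgMax arg_max;

const int token = blockIdx.x;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
    // Each thread proposes its best (expert, prob) pair over a strided slice.
    cub_kvp thread_kvp(0, -1.f);
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
    {
        cub_kvp candidate(expert, inputs_after_softmax[token * num_experts + expert]);
        thread_kvp = arg_max(candidate, thread_kvp);
    }
    // Block-wide reduction to the single best pair for this slot.
    const cub_kvp result = BlockReduce(tmp_storage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0)
    {
        output[token * k + k_idx] = result.value;
        // The only place the new IndType parameter matters: the same kernel
        // can now emit int or uint32_t expert indices.
        indices[token * k + k_idx] = static_cast<IndType>(result.key);
    }
    __syncthreads(); // tmp_storage is reused across top-k iterations
}
```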
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
   2) This implementation assumes k is small, but will work for any k.
 */
 
-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
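The trailing context line mentions compile-time assertions; they are untouched by this diff, and the new `IndType` plays no part in them. For orientation, they are roughly of this shape (paraphrased from the surrounding file, not part of this change):

```cpp
// The tile geometry is derived from the template parameters and rejected at
// compile time if it cannot tile the expert row exactly.
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int THREADS_PER_ROW = NUM_EXPERTS / VPT;

static_assert(VPT % ELTS_PER_LDG == 0,
    "elements per thread must be a multiple of the elements per load");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
    "threads per row must evenly divide the warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
    "threads per row must be a power of two");
```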
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
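Below this hunk (not shown), the helper derives the launch geometry and forwards the index type implicitly. A paraphrased sketch, assuming `detail::TopkConstants` exposes `VPT` and `ROWS_PER_WARP` as in the surrounding file; note that `IndType` is deduced from the `indices` argument, so the kernel launch itself needs no explicit fifth template argument:

```cpp
// Paraphrased continuation of the helper body (not part of this hunk).
static constexpr std::size_t BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
// IndType is deduced from `indices` here, so callers of the helper are unchanged.
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
    input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
```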
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
             token_expert_indices, num_tokens, topk, 0, num_experts, \
             stream);
 
+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-    vllm::moe::topkGatingSoftmaxKernelLauncher(
-        gating_output.data_ptr<float>(),
-        topk_weights.data_ptr<float>(),
-        topk_indices.data_ptr<int>(),
-        token_expert_indices.data_ptr<int>(),
-        softmax_workspace.data_ptr<float>(),
-        num_tokens,
-        num_experts,
-        topk,
-        stream);
+
+    if (topk_indices.scalar_type() == at::ScalarType::Int)
+    {
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
+    else
+    {
+        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<uint32_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
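A hypothetical caller exercising both branches of the new dispatch. It assumes the binding keeps the signature `topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)` (not shown in this hunk), that `num_tokens` and `topk` are in scope, and a PyTorch recent enough to expose `torch::kUInt32`, which the diff's use of `at::ScalarType::UInt32` already requires:

```cpp
// Hypothetical usage sketch; the topk_softmax signature is assumed, not shown above.
const auto opts = gating_output.options();
torch::Tensor topk_weights = torch::empty({num_tokens, topk}, opts.dtype(torch::kFloat32));
torch::Tensor token_expert_indices = torch::empty({num_tokens, topk}, opts.dtype(torch::kInt32));

// Int32 indices take the first branch ...
torch::Tensor idx_i32 = torch::empty({num_tokens, topk}, opts.dtype(torch::kInt32));
topk_softmax(topk_weights, idx_i32, token_expert_indices, gating_output);

// ... and uint32 indices take the second; any other dtype only trips the assert.
torch::Tensor idx_u32 = torch::empty({num_tokens, topk}, opts.dtype(torch::kUInt32));
topk_softmax(topk_weights, idx_u32, token_expert_indices, gating_output);
```

One review note: `assert` is compiled out under `NDEBUG`, so a release build handed any other index dtype would silently reinterpret the buffer as `uint32_t`; a `TORCH_CHECK` on `topk_indices.scalar_type()` would fail loudly in all builds.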