@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
  2) This implementation assumes k is small, but will work for any k.
 */
 
-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
-__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
 void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
     int* source_rows, const int k, const int start_expert, const int end_expert)
 {
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 
     // Restrictions based on previous section.
     static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
-    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+    static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
     static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
-    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+    static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
 
     // We have NUM_EXPERTS elements per row. We specialize for small #experts
-    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+    static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
     static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
     static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
 
@@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 namespace detail
 {
 // Constructs some constants needed to partition the work across threads at compile time.
-template <int EXPERTS, int BYTES_PER_LDG>
+template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
 struct TopkConstants
 {
     static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
-    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
+    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
     static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
     static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
-    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
+    static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
 
     static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
-    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
     static constexpr int VPT = Constants::VPT;
     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
     const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
     const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
 
-    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
-    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+    dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
+    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                \
-    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(  \
-        gating_output, nullptr, topk_weights, topk_indices,      \
-        token_expert_indices, num_tokens, topk, 0, num_experts,  \
-        stream);
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                                   \
+    switch (warpSize) {                                                             \
+        case 32:                                                                    \
+            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(         \
+                gating_output, nullptr, topk_weights, topk_indices,                 \
+                token_expert_indices, num_tokens, topk, 0, num_experts, stream);    \
+            break;                                                                  \
+        case 64:                                                                    \
+            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(         \
+                gating_output, nullptr, topk_weights, topk_indices,                 \
+                token_expert_indices, num_tokens, topk, 0, num_experts, stream);    \
+            break;                                                                  \
+        default:                                                                    \
+            TORCH_CHECK(false, "Unsupported warp size: ", warpSize);                \
+    }
 
 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
     static constexpr int WARPS_PER_TB = 4;
+    auto warpSize = WARP_SIZE;
     switch (num_experts) {
         case 1:
             LAUNCH_SOFTMAX(1, WARPS_PER_TB);
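
For reference, here is a minimal standalone sketch (not part of the commit) of how the detail::TopkConstants arithmetic above partitions a row of experts for a 32-lane warp versus a 64-lane wavefront. The struct and macro names (TopkConstantsSketch, SKETCH_MIN, SKETCH_MAX) are illustrative only and do not appear in the kernel file.

#include <cstdio>

// Stand-ins for the MIN/MAX helper macros used in the kernel file.
#define SKETCH_MIN(a, b) ((a) < (b) ? (a) : (b))
#define SKETCH_MAX(a, b) ((a) > (b) ? (a) : (b))

template <int EXPERTS, int WARP_SIZE_PARAM>
struct TopkConstantsSketch
{
    // Each thread issues vector loads of up to 16 bytes (4 floats).
    static constexpr int BYTES_PER_LDG   = SKETCH_MIN(16, int(sizeof(float)) * EXPERTS);
    static constexpr int ELTS_PER_LDG    = BYTES_PER_LDG / int(sizeof(float));
    static constexpr int VECS_PER_THREAD = SKETCH_MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
    static constexpr int VPT             = VECS_PER_THREAD * ELTS_PER_LDG;  // values per thread
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP   = WARP_SIZE_PARAM / THREADS_PER_ROW;
};

int main()
{
    // With 8 experts, each row occupies 2 threads, so a 64-lane wavefront
    // covers twice as many rows per warp as a 32-lane warp (32 vs 16).
    using C32 = TopkConstantsSketch<8, 32>;
    using C64 = TopkConstantsSketch<8, 64>;
    static_assert(C32::VPT == 4 && C32::THREADS_PER_ROW == 2 && C32::ROWS_PER_WARP == 16, "");
    static_assert(C64::VPT == 4 && C64::THREADS_PER_ROW == 2 && C64::ROWS_PER_WARP == 32, "");
    std::printf("warp 32: VPT=%d THREADS_PER_ROW=%d ROWS_PER_WARP=%d\n",
                C32::VPT, C32::THREADS_PER_ROW, C32::ROWS_PER_WARP);
    std::printf("warp 64: VPT=%d THREADS_PER_ROW=%d ROWS_PER_WARP=%d\n",
                C64::VPT, C64::THREADS_PER_ROW, C64::ROWS_PER_WARP);
    return 0;
}

Because the warp size is now a template parameter, the rewritten LAUNCH_SOFTMAX macro simply switches on the runtime warpSize and instantiates the launcher with WARP_SIZE_PARAM = 32 or 64, so one source path serves both 32-lane warps and 64-lane wavefronts.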