@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 }
 
 // Use UE4M3 by default.
-template <class Type, bool UE8M0_SF = false>
+template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
 __global__ void
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 __launch_bounds__(512, 4) cvt_fp16_to_fp4(
@@ -240,58 +240,191 @@ cvt_fp16_to_fp4(
 #endif
     int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
     uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
-    uint32_t* output_scale_offset_by_experts, int n_experts) {
+    uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");
 
-  // Input tensor row/col loops.
-  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
-    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
-         colIdx += blockDim.x) {
-      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-      // Get the output tensor offset.
-      // Same as inOffset because 8 elements are packed into one uint32_t.
-      int64_t outOffset = inOffset;
-      auto& out_pos = out[outOffset];
-
-      // Find index within the experts.
-      int rowIdx_in_expert = 0;
-      int expert_idx = 0;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+  // Each global thread processes one element
+  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+       globalIdx += gridDim.x * blockDim.x) {
+    // Calculate which row and column this global thread should process
+    int rowIdx = globalIdx / colsPerRow;
+    int colIdx = globalIdx % colsPerRow;
+
+    int64_t inOffset = rowIdx * colsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    // Get the output tensor offset.
+    // Same as inOffset because 8 elements are packed into one uint32_t.
+    int64_t outOffset = inOffset;
+    auto& out_pos = out[outOffset];
+
+    // Find index within the experts using different strategies based on expert
+    // count
+    int rowIdx_in_expert = 0;
+    int expert_idx = 0;
+
+    if constexpr (SMALL_NUM_EXPERTS) {
       for (int i = 0; i < n_experts; i++) {
-        if (rowIdx >= input_offset_by_experts[i] &&
-            rowIdx < input_offset_by_experts[i + 1]) {
-          rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
+        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
+        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
+        if (rowIdx >= current_offset && rowIdx < next_offset) {
+          rowIdx_in_expert = rowIdx - current_offset;
           expert_idx = i;
           break;
         }
       }
+    } else {
+      // Load input offsets into registers first, then do the computation.
+      // Local array size set to 17 because of register limit.
+      uint32_t local_offsets[17];
+      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
+        *reinterpret_cast<int4*>(local_offsets) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start]));
+        *reinterpret_cast<int4*>(local_offsets + 4) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 4]));
+        *reinterpret_cast<int4*>(local_offsets + 8) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 8]));
+        *reinterpret_cast<int4*>(local_offsets + 12) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 12]));
+        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
+
+        // Check against the 16 loaded offsets
+#pragma unroll
+        for (int i = 0; i < 16; i++) {
+          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
+            rowIdx_in_expert = rowIdx - local_offsets[i];
+            expert_idx = chunk_start + i;
+            break;
+          }
+        }
+      }
+    }
+
+    // Get the global scaling factor, which will be applied to the SF.
+    // Note SFScale is the same as next GEMM's alpha, which is
+    // (448.f / (Alpha_A / 6.f)).
+    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    // The actual output_scales dim is computed from the padded numCols.
+    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+    uint32_t* SFout_in_expert =
+        SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+    auto sf_out =
+        cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                           CVT_FP4_NUM_THREADS_PER_SF>(
+            rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+  }
+#endif
+}
 
-      // Get the global scaling factor, which will be applied to the SF.
-      // Note SFScale is the same as next GEMM's alpha, which is
-      // (448.f / (Alpha_A / 6.f)).
-      float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
-
-      int factor = CVT_FP4_SF_VEC_SIZE * 4;
-      // The actual output_scales dim is computed from the padded numCols.
-      int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
-      int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
-      uint32_t* SFout_in_expert =
-          SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
-
-      auto sf_out =
-          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
-                                             CVT_FP4_NUM_THREADS_PER_SF>(
-              rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
-
-      out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
+template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
+#else
+cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
+    uint32_t* output_scale_offset_by_experts, int n_experts) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
+  extern __shared__ uint32_t shared_input_offsets[];
+
+  // Load input offsets into shared memory.
+  // If n_experts is larger than 4, use vectorized int4 to save instructions.
+  // If n_experts is smaller than 4, read directly.
+  if constexpr (SMALL_NUM_EXPERTS) {
+    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
+      shared_input_offsets[i] = input_offset_by_experts[i];
+    }
+  } else {
+    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
+      *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
+          *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
+    }
+    if (threadIdx.x == 0) {
+      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
     }
   }
+
+  __syncthreads();
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+  // Each global thread processes one element
+  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+       globalIdx += gridDim.x * blockDim.x) {
+    // Calculate which row and column this global thread should process
+    int rowIdx = globalIdx / colsPerRow;
+    int colIdx = globalIdx % colsPerRow;
+
+    int64_t inOffset = rowIdx * colsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    int64_t outOffset = inOffset;
+    auto& out_pos = out[outOffset];
+
+    // Find expert using binary search for better performance with large m_topk
+    int rowIdx_in_expert = 0;
+    int expert_idx = 0;
+
+    // Binary search through experts using shared memory
+    int left = 0, right = n_experts - 1;
+    while (left <= right) {
+      int mid = (left + right) / 2;
+      // Get offsets: shared_input_offsets[i] corresponds to
+      // input_offset_by_experts[i]
+      uint32_t mid_offset = shared_input_offsets[mid];
+      uint32_t next_offset = shared_input_offsets[mid + 1];
+
+      if (rowIdx >= mid_offset && rowIdx < next_offset) {
+        rowIdx_in_expert = rowIdx - mid_offset;
+        expert_idx = mid;
+        break;
+      } else if (rowIdx < mid_offset) {
+        right = mid - 1;
+      } else {
+        left = mid + 1;
+      }
+    }
+
+    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+    uint32_t* SFout_in_expert =
+        SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+    auto sf_out =
+        cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                           CVT_FP4_NUM_THREADS_PER_SF>(
+            rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+  }
 #endif
 }
 
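For intuition, the expert lookup that both kernel variants perform can be exercised on the host: given the (n_experts + 1)-entry prefix-sum array input_offset_by_experts, find the expert whose row range contains a global row index and the row's offset inside that expert. The standalone C++ sketch below is illustrative only (it is not part of the patch; the function name and the example offsets are made up) and mirrors the binary search used by the shared-memory kernel above.

// Standalone sketch of the expert lookup used by the kernels above.
// offsets has n_experts + 1 entries: offsets[e] is the first row of expert e,
// and offsets[n_experts] is the total number of rows.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Returns {expert_idx, rowIdx_in_expert} for a global row index.
static std::pair<int, int> find_expert(const std::vector<uint32_t>& offsets,
                                       uint32_t rowIdx) {
  int left = 0, right = static_cast<int>(offsets.size()) - 2;
  while (left <= right) {
    int mid = (left + right) / 2;
    if (rowIdx >= offsets[mid] && rowIdx < offsets[mid + 1]) {
      return {mid, static_cast<int>(rowIdx - offsets[mid])};
    } else if (rowIdx < offsets[mid]) {
      right = mid - 1;
    } else {
      left = mid + 1;
    }
  }
  return {0, 0};  // Unreachable when rowIdx < offsets.back().
}

int main() {
  // 4 experts with 3, 0, 5, and 2 rows respectively.
  std::vector<uint32_t> offsets = {0, 3, 3, 8, 10};
  auto [expert, row_in_expert] = find_expert(offsets, 6);
  std::printf("expert=%d row_in_expert=%d\n", expert, row_in_expert);
  return 0;
}

Compiled with -std=c++17, the example prints expert=2 row_in_expert=3, since row 6 falls in the half-open range [3, 8) owned by expert 2; the linear scan in the SMALL_NUM_EXPERTS path relies on the same invariant.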
@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,
 
   // Grid, Block size.
   // Each thread converts 8 values.
-  dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
+  int const workSizePerRow = k / ELTS_PER_THREAD;
+  int const totalWorkSize = m_topk * workSizePerRow;
+  dim3 block(std::min(workSizePerRow, 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = 2048 / block.x;
-  dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
-
-  cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
-      m_topk, k, reinterpret_cast<T*>(input),
-      reinterpret_cast<float*>(input_global_scale),
-      reinterpret_cast<uint32_t*>(output),
-      reinterpret_cast<uint32_t*>(output_scale),
-      reinterpret_cast<uint32_t*>(input_offset_by_experts),
-      reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
+  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+                     multiProcessorCount * numBlocksPerSM));
+  while (grid.x <= multiProcessorCount && block.x > 64) {
+    grid.x *= 2;
+    block.x = (block.x + 1) / 2;
+  }
+
+  int const blockRepeat =
+      (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
+  if (blockRepeat > 1) {
+    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
+    if (n_experts >= 4) {
+      cvt_fp16_to_fp4<T, false, false>
+          <<<grid, block, shared_mem_size, stream>>>(
+              m_topk, k, reinterpret_cast<T*>(input),
+              reinterpret_cast<float*>(input_global_scale),
+              reinterpret_cast<uint32_t*>(output),
+              reinterpret_cast<uint32_t*>(output_scale),
+              reinterpret_cast<uint32_t*>(input_offset_by_experts),
+              reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+              n_experts);
+    } else {
+      cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts);
+    }
+  } else {
+    if (n_experts >= 16) {
+      cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts, /* bool low_latency */ true);
+    } else {
+      cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts, /* bool low_latency */ true);
+    }
+  }
 }
 
 /* Quantization entry for fp4 experts quantization*/
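As a sanity check on the new launch heuristic, the arithmetic in quant_impl can be reproduced on the host. The sketch below is illustrative only: ELTS_PER_THREAD, the shape (m_topk, k), and multiProcessorCount are assumed example values rather than values from the patch, while the block/grid/blockRepeat formulas mirror the diff.

// Host-side sketch (illustrative only) of the launch-size heuristic in
// quant_impl above.
#include <algorithm>
#include <cstdio>

int main() {
  const int ELTS_PER_THREAD = 8;        // Each thread converts 8 values.
  const int m_topk = 4096, k = 7168;    // Example shapes (assumed).
  const int multiProcessorCount = 132;  // Example SM count (assumed).

  const int workSizePerRow = k / ELTS_PER_THREAD;
  const int totalWorkSize = m_topk * workSizePerRow;
  int block = std::min(workSizePerRow, 512);
  const int numBlocksPerSM = 2048 / block;  // Assume full SM occupancy.
  int grid = std::min((totalWorkSize + block - 1) / block,
                      multiProcessorCount * numBlocksPerSM);

  // If the grid cannot cover all SMs, trade block size for more blocks.
  while (grid <= multiProcessorCount && block > 64) {
    grid *= 2;
    block = (block + 1) / 2;
  }

  // blockRepeat > 1 means each thread loops over several elements.
  const int blockRepeat = (totalWorkSize + block * grid - 1) / (block * grid);
  std::printf("block=%d grid=%d blockRepeat=%d\n", block, grid, blockRepeat);
  return 0;
}

In the dispatcher above, blockRepeat > 1 selects the shared-memory binary-search kernel and blockRepeat == 1 selects the low-latency kernel, with the SMALL_NUM_EXPERTS template argument chosen from n_experts in both branches.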