@@ -48,18 +48,88 @@ enum class DType {
4848
4949// Data types
5050using e8m0_t = uint8_t ;
51- using bf16 = nv_bfloat16;
51+ using bfloat16 = nv_bfloat16;
5252using fp8e4m3 = __nv_fp8_e4m3;
5353
54- // Constants for dtype conversion
54+ constexpr size_t get_dtype_bits (DType dtype) {
55+ switch (dtype) {
56+ case DType::kFloat32 :
57+ return 32 ;
58+ case DType::kBFloat16 :
59+ return 16 ;
60+ case DType::kFloat8E4M3 :
61+ return 8 ;
62+ default :
63+ // TODO: something smarter than this
64+ return 0 ;
65+ }
66+ }
67+
68+ // FP32 constants
5569constexpr int32_t FP32_MANTISSA_BITS = 23 ;
5670constexpr int32_t FP32_EXPONENT_BIAS = 127 ;
71+
72+ // BF16 constants
73+ constexpr int32_t BF16_MANTISSA_BITS = 7 ;
74+ constexpr int32_t BF16_EXPONENT_BIAS = 127 ;
75+
76+ // FP8E4M3 constants
5777constexpr int32_t F8E4M3_MAX_POW2 = 8 ;
58- constexpr int32_t E8M0_EXPONENT_BIAS= 127 ;
59- constexpr int32_t F32_EXP_BIAS = 127 ;
6078constexpr float F8E4M3_MAX = 448.0 ;
6179
62- // Constants for MXFP8
80+ // FP8E8M0 constants
81+ constexpr int32_t E8M0_EXPONENT_BIAS= 127 ;
82+
83+
84+ // 1. Base template (for unsupported types)
85+ template <typename T>
86+ struct DataTypeTraits {
87+ static constexpr bool is_supported = false ;
88+ };
89+
90+ // 2. Specialization for float32
91+ template <>
92+ struct DataTypeTraits <float > {
93+ static constexpr bool is_supported = true ;
94+ static constexpr int mantissa_bits = 23 ;
95+ static constexpr int exponent_bias = 127 ;
96+
97+ __device__ static __forceinline__ float to_float (const float val) {
98+ return val;
99+ }
100+ };
101+
102+ // 3. Specialization for bfloat16
103+ template <>
104+ struct DataTypeTraits <nv_bfloat16> {
105+ static constexpr bool is_supported = true ;
106+ static constexpr int mantissa_bits = 7 ;
107+ static constexpr int exponent_bias = 127 ;
108+
109+ __device__ static __forceinline__ float to_float (const nv_bfloat16 val) {
110+ return __bfloat162float (val);
111+ }
112+ };
113+
114+ __device__ static __forceinline__ e8m0_t calculate_e8m0_biased_scale (const float amax) {
115+ // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239
116+ const int32_t int_amax = *reinterpret_cast <const int32_t *>(&amax);
117+ const int32_t extracted_pow2 = ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111 ) - FP32_EXPONENT_BIAS;
118+
119+ // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244
120+ int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2;
121+
122+ // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256
123+ scale_unbiased = max (scale_unbiased, -E8M0_EXPONENT_BIAS);
124+ scale_unbiased = min (scale_unbiased, E8M0_EXPONENT_BIAS + 1 );
125+ int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS;
126+
127+ // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26
128+ const e8m0_t e8m0_biased_scale = *reinterpret_cast <e8m0_t *>(&scale_with_e8m0_bias);
129+ return e8m0_biased_scale;
130+ }
131+
132+ // Constants for MXFP8 kernel
63133constexpr size_t MXFP8_CHUNK_DIM_Y = 64 ;
64134constexpr size_t MXFP8_CHUNK_DIM_X = 64 ;
65135constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1 ;
@@ -343,6 +413,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
343413 (unsigned long long )rows, (unsigned long long )cols, (unsigned long long )scales_rowwise_stride_dim0, (unsigned long long )scales_rowwise_stride_dim1, (unsigned long long )scales_colwise_stride_dim0, (unsigned long long )scales_colwise_stride_dim1);
344414#endif
345415
416+ static_assert (DataTypeTraits<IType>::is_supported, " Input data type is not supported by this kernel." );
417+
346418 constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1 ;
347419 constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1 ;
348420
@@ -505,7 +577,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
505577 const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
506578
507579 // Load from shared memory into thread local registers.
508- float elt = static_cast < float > (in.data .elt [j]);
580+ float elt = DataTypeTraits<IType>:: to_float (in.data .elt [j]);
509581 in_compute[j] = elt;
510582
511583 // Update thread local amax.
@@ -564,7 +636,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
564636 const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
565637
566638 // Load from shared memory into thread local registers.
567- float elt = static_cast < float > (in_sh[buff][i][tid_colwise_X]);
639+ float elt = DataTypeTraits<IType>:: to_float (in_sh[buff][i][tid_colwise_X]);
568640 in_compute[i] = elt;
569641
570642 // Update thread local amax.
@@ -580,55 +652,28 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
580652 // ******* END TE original ***********
581653
582654 // ******* Updated implementation based on torchao to_mx() with ScaleCalculationMode=FLOOR **********
583- // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239
584- const int32_t int_amax = *reinterpret_cast <const int32_t *>(&amax);
585- const int32_t extracted_pow2 = ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111 ) - FP32_EXPONENT_BIAS;
586-
587- // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244
588- int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2;
589-
590- // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256
591- scale_unbiased = max (scale_unbiased, -E8M0_EXPONENT_BIAS);
592- scale_unbiased = min (scale_unbiased, E8M0_EXPONENT_BIAS + 1 );
593- int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS;
594-
595- // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26
596- const e8m0_t e8m0_biased_scale = *reinterpret_cast <e8m0_t *>(&scale_with_e8m0_bias);
655+ const e8m0_t e8m0_biased_scale = calculate_e8m0_biased_scale (amax);
597656
598657 // Calculate scale offsets and write scaling factor.
599658 const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
600659 const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
601660 const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
602661
603- // Debug logging
604- #if defined(DEBUG)
605- printf (" tid_colwise_X=%llu, scales_colwise_stride_dim0=%d, global_scales_offset_Y=%llu, global_scales_offset_X=%llu, scale_idx=%llu, amax=%d, extracted_pow_2=%d, scale_unbiased=%d, scale_with_e8m0_bias=%d, e8m0_biased_scale=%d, col_out_of_bounds=%d\n " ,
606- (unsigned long long )tid_colwise_X,
607- (unsigned long long )global_scales_offset_Y,
608- (unsigned long long )global_scales_offset_X,
609- (unsigned long long )scale_idx,
610- (int )(amax),
611- extracted_pow2,
612- scale_unbiased,
613- scale_with_e8m0_bias,
614- e8m0_biased_scale,
615- col_out_of_bounds);
616- #endif
617-
618662 // Write scales to global memory.
619663 // I had to add this bounds check because the original code was having threads from the second `iter` overwrite values from the first.
620664 const bool row_out_of_bounds = (row_base >= rows);
621665 if (!row_out_of_bounds && !col_out_of_bounds) {
622666 scales_colwise[scale_idx] = e8m0_biased_scale;
623667 }
624668
669+ // Apply scales to do value conversion.
625670 // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L275C1-L277C30
626671 int32_t exponent_as_int32 = static_cast <int32_t >(e8m0_biased_scale);
627672 int32_t float_bits = exponent_as_int32 << FP32_MANTISSA_BITS;
628673 float scale_fp32 = *reinterpret_cast <float *>(&float_bits);
629674
630675 // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L286
631- const float F32_MIN_NORMAL = exp2f (-F32_EXP_BIAS + 1 );
676+ const float F32_MIN_NORMAL = exp2f (-FP32_EXPONENT_BIAS + 1 );
632677 scale_fp32 = max (scale_fp32, F32_MIN_NORMAL);
633678
634679 // Use scales to perform value conversion.
@@ -743,35 +788,42 @@ public:
743788 printf (" grid.x=%d, grid.y=%d, block.x=%d, block.y=%d\n " , grid.x , grid.y , block.x , block.y );
744789#endif
745790
791+
746792 // Create TMA descriptors
747793 alignas (64 ) CUtensorMap tensor_map_input{};
748794 alignas (64 ) CUtensorMap tensor_map_output_rowwise{};
749795 alignas (64 ) CUtensorMap tensor_map_output_colwise{};
796+ int32_t input_bits_per_elem = get_dtype_bits (input_dtype);
797+ int32_t output_bits_per_elem = get_dtype_bits (output_dtype);
798+
799+ #if defined(DEBUG)
800+ printf (" input_bits_per_elem=%d, output_bits_per_elem=%d\n " , input_bits_per_elem, output_bits_per_elem);
801+ #endif
750802
751803 create_2D_tensor_map (tensor_map_input,
752804 const_cast <void *>(input),
753805 input_dtype,
754806 rows, cols,
755807 MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
756- cols, // input stride along dim0
757- 32 ); // bits per elem in input
808+ cols, // input stride along dim0
809+ input_bits_per_elem); // bits per elem in input
758810
759811 if (output_rowwise) {
760812 create_2D_tensor_map (tensor_map_output_rowwise, output_rowwise,
761813 output_dtype,
762814 rows, cols,
763815 MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
764- cols, // input stride along dim0
765- 8 ); // bits per elem in output fp8e4m3
816+ cols, // input stride along dim0
817+ output_bits_per_elem); // bits per elem in output fp8e4m3
766818 }
767819
768820 if (output_colwise) {
769821 create_2D_tensor_map (tensor_map_output_colwise, output_colwise,
770822 output_dtype,
771823 rows, cols,
772824 MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
773- cols, // input stride along dim0
774- 8 ); // bits per elem in output fp8e4m3
825+ cols, // input stride along dim0
826+ output_bits_per_elem); // bits per elem in output fp8e4m3
775827 }
776828
777829// Launch kernel based on input/output types and scaling dimensions
@@ -807,6 +859,14 @@ public:
807859 } else if (scale_dim_x == 1 && scale_dim_y == 32 ) {
808860 LAUNCH_KERNEL (float , fp8e4m3, 32 , 1 );
809861 }
862+ } else if (input_dtype == DType::kBFloat16 ) {
863+ if (scale_dim_x == 32 && scale_dim_y == 32 ) {
864+ LAUNCH_KERNEL (bfloat16, fp8e4m3, 32 , 32 );
865+ } else if (scale_dim_x == 32 && scale_dim_y == 1 ) {
866+ LAUNCH_KERNEL (bfloat16, fp8e4m3, 1 , 32 );
867+ } else if (scale_dim_x == 1 && scale_dim_y == 32 ) {
868+ LAUNCH_KERNEL (bfloat16, fp8e4m3, 32 , 1 );
869+ }
810870 } else {
811871 printf (" unsupported input dtype, must be float32\n " );
812872 exit (1 );
0 commit comments