
Commit 79e359b

add bfloat16 support for colwise scaling
stack-info: PR: #10, branch: danielvegamyhre/stack/4
Parent: cb95ba5

File tree: 3 files changed (+107 −66 lines)


benchmarks/mx_formats/cast_bench.py

Lines changed: 1 addition & 21 deletions
@@ -210,29 +210,9 @@ def run(
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
     elif mode == "dim0_dim1_cuda":
-        x = x.to(torch.float32)
-        y_d0, y_d1, s_d0, s_d1 = mxfp8_cuda.quantize(x, rowwise=True, colwise=True)
-
-        for _ in range(2):
-            __ = mxfp8_cuda.quantize(x, rowwise=True, colwise=True)
-
-        bench_fn = partial(mxfp8_cuda.quantize, rowwise=True, colwise=True)
-        time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)
-
-        assert y_d0.dtype == torch.float8_e4m3fn
-        assert s_d0.dtype == torch.float8_e8m0fnu
-        assert y_d1.dtype == torch.float8_e4m3fn
-        assert s_d1.dtype == torch.float8_e8m0fnu
-
-        bytes_r = x.numel() * bytes_per_el_fp32
-        bytes_w = (
-            sum(t.numel() for t in [y_d0, y_d1, s_d0, s_d1]) * bytes_per_el_fp8
-        )
-        bytes_rw = bytes_r + bytes_w
-        bps = bytes_rw / (time_us / 1e6)
+        raise NotImplementedError("dim0_dim1_cuda not implemented yet")
 
     elif mode == "dim1_cuda":
-        x = x.to(torch.float32)
         _, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
 
         for _ in range(2):
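
With the x.to(torch.float32) cast removed, the dim1_cuda branch now benchmarks the kernel on whatever dtype the benchmark tensor was created with, so a bfloat16 run exercises the new code path directly. A minimal standalone sketch of that flow (illustrative only; it assumes the compiled mxfp8_cuda extension and the benchmark_cuda_function_in_microseconds helper are in scope exactly as cast_bench.py uses them):

```python
from functools import partial

import torch

# Illustrative repro of the "dim1_cuda" benchmark branch. `mxfp8_cuda` and
# `benchmark_cuda_function_in_microseconds` are assumed to be importable as in cast_bench.py.
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Warmup, mirroring the benchmark loop, then time only the colwise (dim1) cast.
for _ in range(2):
    _ = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True)
time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)

# Memory traffic: bf16 input read, plus fp8 data and e8m0 scales written (1 byte each).
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
bytes_rw = x.numel() * x.element_size() + y_d1.numel() + s_d1.numel()
print("GB/s:", bytes_rw / (time_us / 1e6) / 1e9)
```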

test/prototype/mx_formats/test_kernels.py

Lines changed: 3 additions & 2 deletions
@@ -481,9 +481,10 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (32,64,2048))
 @pytest.mark.parametrize("K", (32,64,2048))
-def test_cuda_mx_dim1_randn(M, K):
+@pytest.mark.parametrize("input_dtype", (torch.float32,torch.bfloat16))
+def test_cuda_mx_dim1_randn(M, K, input_dtype):
     # Use disinct incrementing values from 0 to M*K-1 to make debugging easier.
-    x = torch.arange(0, M*K, dtype=torch.float32, device="cuda").reshape(M, K).contiguous()
+    x = torch.arange(0, M*K, dtype=input_dtype, device="cuda").reshape(M, K).contiguous()
 
     y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     _, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
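
The new input_dtype axis runs the same colwise (dim1) check for both float32 and bfloat16 inputs. A standalone sketch of a single parametrized case (assuming triton_to_mxfp8_dim1_reference and the compiled mxfp8_cuda extension are importable as in test_kernels.py):

```python
import torch

# One hypothetical parametrization of the test, written out by hand.
M, K, input_dtype = 32, 2048, torch.bfloat16
x = torch.arange(0, M * K, dtype=input_dtype, device="cuda").reshape(M, K).contiguous()

# Reference colwise MX cast vs. the CUDA kernel under test.
y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)

# MXFP8 outputs: fp8e4m3 data with one e8m0 scale per 32-element block.
assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.float8_e8m0fnu
# The test then compares y_d1 / s_d1 against the Triton reference outputs.
```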

torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh

Lines changed: 103 additions & 43 deletions
@@ -48,18 +48,88 @@ enum class DType {
 
 // Data types
 using e8m0_t = uint8_t;
-using bf16 = nv_bfloat16;
+using bfloat16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 
-// Constants for dtype conversion
+constexpr size_t get_dtype_bits(DType dtype) {
+  switch (dtype) {
+    case DType::kFloat32:
+      return 32;
+    case DType::kBFloat16:
+      return 16;
+    case DType::kFloat8E4M3:
+      return 8;
+    default:
+      // TODO: something smarter than this
+      return 0;
+  }
+}
+
+// FP32 constants
 constexpr int32_t FP32_MANTISSA_BITS = 23;
 constexpr int32_t FP32_EXPONENT_BIAS = 127;
+
+// BF16 constants
+constexpr int32_t BF16_MANTISSA_BITS = 7;
+constexpr int32_t BF16_EXPONENT_BIAS = 127;
+
+// FP8E4M3 constants
 constexpr int32_t F8E4M3_MAX_POW2 = 8;
-constexpr int32_t E8M0_EXPONENT_BIAS= 127;
-constexpr int32_t F32_EXP_BIAS = 127;
 constexpr float F8E4M3_MAX = 448.0;
 
-// Constants for MXFP8
+// FP8E8M0 constants
+constexpr int32_t E8M0_EXPONENT_BIAS= 127;
+
+
+// 1. Base template (for unsupported types)
+template <typename T>
+struct DataTypeTraits {
+  static constexpr bool is_supported = false;
+};
+
+// 2. Specialization for float32
+template <>
+struct DataTypeTraits<float> {
+  static constexpr bool is_supported = true;
+  static constexpr int mantissa_bits = 23;
+  static constexpr int exponent_bias = 127;
+
+  __device__ static __forceinline__ float to_float(const float val) {
+    return val;
+  }
+};
+
+// 3. Specialization for bfloat16
+template <>
+struct DataTypeTraits<nv_bfloat16> {
+  static constexpr bool is_supported = true;
+  static constexpr int mantissa_bits = 7;
+  static constexpr int exponent_bias = 127;
+
+  __device__ static __forceinline__ float to_float(const nv_bfloat16 val) {
+    return __bfloat162float(val);
+  }
+};
+
+__device__ static __forceinline__ e8m0_t calculate_e8m0_biased_scale(const float amax) {
+  // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239
+  const int32_t int_amax = *reinterpret_cast<const int32_t*>(&amax);
+  const int32_t extracted_pow2 = ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111) - FP32_EXPONENT_BIAS;
+
+  // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244
+  int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2;
+
+  // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256
+  scale_unbiased = max(scale_unbiased, -E8M0_EXPONENT_BIAS);
+  scale_unbiased = min(scale_unbiased, E8M0_EXPONENT_BIAS + 1);
+  int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS;
+
+  // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26
+  const e8m0_t e8m0_biased_scale = *reinterpret_cast<e8m0_t*>(&scale_with_e8m0_bias);
+  return e8m0_biased_scale;
+}
+
+// Constants for MXFP8 kernel
 constexpr size_t MXFP8_CHUNK_DIM_Y = 64;
 constexpr size_t MXFP8_CHUNK_DIM_X = 64;
 constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1;
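
calculate_e8m0_biased_scale factors the FLOOR-mode scale computation out of the colwise loop so the float32 and bfloat16 kernel instantiations can share it: the block amax (already reduced in fp32) has its exponent field extracted, F8E4M3_MAX_POW2 is subtracted to leave headroom for the fp8e4m3 range, and the result is clamped and re-biased into e8m0. A Python sketch of the same bit manipulation, for illustration only:

```python
import struct

FP32_MANTISSA_BITS = 23
FP32_EXPONENT_BIAS = 127
F8E4M3_MAX_POW2 = 8
E8M0_EXPONENT_BIAS = 127

def e8m0_biased_scale(amax: float) -> int:
    # Reinterpret the fp32 amax as int32 and extract its biased exponent, i.e. floor(log2(amax)).
    int_amax = struct.unpack("<i", struct.pack("<f", amax))[0]
    extracted_pow2 = ((int_amax >> FP32_MANTISSA_BITS) & 0xFF) - FP32_EXPONENT_BIAS
    # FLOOR mode: reserve the fp8e4m3 max power of two, then clamp into the e8m0 range and re-bias.
    scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2
    scale_unbiased = max(scale_unbiased, -E8M0_EXPONENT_BIAS)
    scale_unbiased = min(scale_unbiased, E8M0_EXPONENT_BIAS + 1)
    return scale_unbiased + E8M0_EXPONENT_BIAS

# Example: amax = 300.0 lies in [2^8, 2^9), so the unbiased scale is 8 - 8 = 0
# and the biased e8m0 value is 127 (a scale of 2^0 = 1.0).
assert e8m0_biased_scale(300.0) == 127
```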
@@ -343,6 +413,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
          (unsigned long long)rows, (unsigned long long)cols, (unsigned long long)scales_rowwise_stride_dim0, (unsigned long long)scales_rowwise_stride_dim1, (unsigned long long)scales_colwise_stride_dim0, (unsigned long long)scales_colwise_stride_dim1);
 #endif
 
+  static_assert(DataTypeTraits<IType>::is_supported, "Input data type is not supported by this kernel.");
+
   constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
   constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
 
@@ -505,7 +577,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
         const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
 
         // Load from shared memory into thread local registers.
-        float elt = static_cast<float>(in.data.elt[j]);
+        float elt = DataTypeTraits<IType>::to_float(in.data.elt[j]);
         in_compute[j] = elt;
 
         // Update thread local amax.
@@ -564,7 +636,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
         const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
 
         // Load from shared memory into thread local registers.
-        float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]);
+        float elt = DataTypeTraits<IType>::to_float(in_sh[buff][i][tid_colwise_X]);
         in_compute[i] = elt;
 
         // Update thread local amax.
@@ -580,55 +652,28 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       // ******* END TE original ***********
 
       // ******* Updated implementation based on torchao to_mx() with ScaleCalculationMode=FLOOR **********
-      // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239
-      const int32_t int_amax = *reinterpret_cast<const int32_t*>(&amax);
-      const int32_t extracted_pow2 = ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111) - FP32_EXPONENT_BIAS;
-
-      // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244
-      int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2;
-
-      // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256
-      scale_unbiased = max(scale_unbiased, -E8M0_EXPONENT_BIAS);
-      scale_unbiased = min(scale_unbiased, E8M0_EXPONENT_BIAS + 1);
-      int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS;
-
-      // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26
-      const e8m0_t e8m0_biased_scale = *reinterpret_cast<e8m0_t*>(&scale_with_e8m0_bias);
+      const e8m0_t e8m0_biased_scale = calculate_e8m0_biased_scale(amax);
 
       // Calculate scale offsets and write scaling factor.
       const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
       const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
       const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
 
-      // Debug logging
-      #if defined(DEBUG)
-      printf("tid_colwise_X=%llu, scales_colwise_stride_dim0=%d, global_scales_offset_Y=%llu, global_scales_offset_X=%llu, scale_idx=%llu, amax=%d, extracted_pow_2=%d, scale_unbiased=%d, scale_with_e8m0_bias=%d, e8m0_biased_scale=%d, col_out_of_bounds=%d\n",
-        (unsigned long long)tid_colwise_X,
-        (unsigned long long)global_scales_offset_Y,
-        (unsigned long long)global_scales_offset_X,
-        (unsigned long long)scale_idx,
-        (int)(amax),
-        extracted_pow2,
-        scale_unbiased,
-        scale_with_e8m0_bias,
-        e8m0_biased_scale,
-        col_out_of_bounds);
-      #endif
-
       // Write scales to global memory.
       // I had to add this bounds check because the original code was having threads from the second `iter` overwrite values from the first.
       const bool row_out_of_bounds = (row_base >= rows);
       if (!row_out_of_bounds && !col_out_of_bounds) {
         scales_colwise[scale_idx] = e8m0_biased_scale;
       }
 
+      // Apply scales to do value conversion.
       // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L275C1-L277C30
       int32_t exponent_as_int32 = static_cast<int32_t>(e8m0_biased_scale);
       int32_t float_bits = exponent_as_int32 << FP32_MANTISSA_BITS;
       float scale_fp32 = *reinterpret_cast<float*>(&float_bits);
 
       // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L286
-      const float F32_MIN_NORMAL = exp2f(-F32_EXP_BIAS + 1);
+      const float F32_MIN_NORMAL = exp2f(-FP32_EXPONENT_BIAS + 1);
       scale_fp32 = max(scale_fp32, F32_MIN_NORMAL);
 
       // Use scales to perform value conversion.
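
The colwise branch then turns the biased e8m0 exponent back into an fp32 multiplier by writing it straight into the exponent field of a float and clamping to the smallest normal value (F32_MIN_NORMAL), mirroring the torchao reference linked above. Continuing the illustrative Python sketch:

```python
import struct

FP32_MANTISSA_BITS = 23
FP32_EXPONENT_BIAS = 127

def e8m0_to_fp32_scale(e8m0_biased_scale: int) -> float:
    # Place the e8m0 value into the exponent field of an fp32 (sign = 0, mantissa = 0).
    float_bits = e8m0_biased_scale << FP32_MANTISSA_BITS
    scale_fp32 = struct.unpack("<f", struct.pack("<I", float_bits))[0]
    # Clamp to the smallest normal fp32, as the kernel does with F32_MIN_NORMAL.
    return max(scale_fp32, 2.0 ** (-FP32_EXPONENT_BIAS + 1))

# 127 -> 2^(127 - 127) = 1.0; per the torchao reference, block values are divided by
# this scale before being cast to fp8e4m3.
assert e8m0_to_fp32_scale(127) == 1.0
```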
@@ -743,35 +788,42 @@ public:
     printf("grid.x=%d, grid.y=%d, block.x=%d, block.y=%d\n", grid.x, grid.y, block.x, block.y);
 #endif
 
+
     // Create TMA descriptors
     alignas(64) CUtensorMap tensor_map_input{};
     alignas(64) CUtensorMap tensor_map_output_rowwise{};
     alignas(64) CUtensorMap tensor_map_output_colwise{};
+    int32_t input_bits_per_elem = get_dtype_bits(input_dtype);
+    int32_t output_bits_per_elem = get_dtype_bits(output_dtype);
+
+#if defined(DEBUG)
+    printf("input_bits_per_elem=%d, output_bits_per_elem=%d\n", input_bits_per_elem, output_bits_per_elem);
+#endif
 
     create_2D_tensor_map(tensor_map_input,
                          const_cast<void *>(input),
                          input_dtype,
                          rows, cols,
                          MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
-                         cols, // input stride along dim0
-                         32); // bits per elem in input
+                         cols, // input stride along dim0
+                         input_bits_per_elem); // bits per elem in input
 
     if (output_rowwise) {
       create_2D_tensor_map(tensor_map_output_rowwise, output_rowwise,
                            output_dtype,
                            rows, cols,
                            MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
-                           cols, // input stride along dim0
-                           8); // bits per elem in output fp8e4m3
+                           cols, // input stride along dim0
+                           output_bits_per_elem); // bits per elem in output fp8e4m3
     }
 
     if (output_colwise) {
       create_2D_tensor_map(tensor_map_output_colwise, output_colwise,
                            output_dtype,
                            rows, cols,
                            MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X,
-                           cols, // input stride along dim0
-                           8); // bits per elem in output fp8e4m3
+                           cols, // input stride along dim0
+                           output_bits_per_elem); // bits per elem in output fp8e4m3
     }
 
     // Launch kernel based on input/output types and scaling dimensions
@@ -807,6 +859,14 @@ public:
     } else if (scale_dim_x == 1 && scale_dim_y == 32) {
       LAUNCH_KERNEL(float, fp8e4m3, 32, 1);
     }
+  } else if (input_dtype == DType::kBFloat16) {
+    if (scale_dim_x == 32 && scale_dim_y == 32) {
+      LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 32);
+    } else if (scale_dim_x == 32 && scale_dim_y == 1) {
+      LAUNCH_KERNEL(bfloat16, fp8e4m3, 1, 32);
+    } else if (scale_dim_x == 1 && scale_dim_y == 32) {
+      LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1);
+    }
   } else {
     printf("unsupported input dtype, must be float32\n");
     exit(1);
