diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 04afb2b..030b471 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -96,7 +96,6 @@ def run( "dim1_mx", "dim1_mx_triton", "dim1_cuda", - "dim0_dim1_cuda", ) x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 @@ -213,12 +212,12 @@ def run( raise NotImplementedError("dim0_dim1_cuda not implemented yet") elif mode == "dim1_cuda": - _, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True) + bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True) + _, y_d1, _, s_d1 = bench_fn(x) for _ in range(2): - __ = mxfp8_cuda.quantize(x, rowwise=False, colwise=True) + __ = bench_fn(x) - bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True) time_us = benchmark_cuda_function_in_microseconds(bench_fn, x) assert y_d1.dtype == torch.float8_e4m3fn diff --git a/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp b/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp index 6102392..69e673a 100644 --- a/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp +++ b/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp @@ -86,7 +86,13 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, if (colwise) { const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y; output_colwise = torch::empty({rows, cols}, options_fp8); - scales_colwise = torch::empty({cols, num_row_blocks, 1}, options_scale); + // Need scales_colwise to be this shape so the 'col' dim stride is 1, + // so that for colwise scaling we can avoid uncoalesced writes to global memory. + // This is because each of the 32 threads in a warp will be computing + // a scale for a different column of 32 input data values, then each writing + // that scale to global memory - so the stride along this `col` dim should be 1 + // so writes can be coalesced into a single transaction.
+ scales_colwise = torch::empty({num_row_blocks, cols, 1}, options_scale); } else { output_colwise = torch::empty({0}, options_fp8); scales_colwise = torch::empty({0}, options_scale); } @@ -100,7 +106,10 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, colwise ? scale_dim_y : 1, // scale_dim_y fp8_format); - + // Need to transpose scales so `cols` is leading dim, to match torchao reference impl. + if (colwise) { scales_colwise = scales_colwise.transpose(0, 1); } return std::make_tuple(output_rowwise, output_colwise, scales_rowwise, scales_colwise); } diff --git a/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh b/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh index 0b1020f..9f839d8 100644 --- a/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh +++ b/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh @@ -657,7 +657,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // Calculate scale offsets and write scaling factor. const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0; + // const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0; + const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim0 + global_scales_offset_X * scales_colwise_stride_dim1; // Write scales to global memory. // I had to add this bounds check because the original code was having threads from the second `iter` overwrite values from the first.
@@ -680,7 +681,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) #pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289 - // float data_lp = in_compute[i] / scale_fp32; float data_lp = __fdiv_rn(in_compute[i], scale_fp32); #if defined(DEBUG)