Merged
7 changes: 3 additions & 4 deletions benchmarks/mx_formats/cast_bench.py
@@ -96,7 +96,6 @@ def run(
"dim1_mx",
"dim1_mx_triton",
"dim1_cuda",
"dim0_dim1_cuda",
)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
@@ -213,12 +212,12 @@ def run(
raise NotImplementedError("dim0_dim1_cuda not implemented yet")

elif mode == "dim1_cuda":
-_, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
+bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True)
+_, y_d1, _, s_d1 = bench_fn(x)

for _ in range(2):
-__ = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
+__ = bench_fn(x)

-bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True)
time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)

assert y_d1.dtype == torch.float8_e4m3fn
13 changes: 11 additions & 2 deletions torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp
@@ -86,7 +86,13 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
if (colwise) {
const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
output_colwise = torch::empty({rows, cols}, options_fp8);
-scales_colwise = torch::empty({cols, num_row_blocks, 1}, options_scale);
+// scales_colwise needs this shape so the stride along the `col` dim is 1;
+// for colwise scaling this avoids uncoalesced writes to global memory.
+// Each of the 32 threads in a warp computes a scale for a different column
+// of 32 input values and then writes that scale to global memory, so with a
+// stride of 1 along the `col` dim the warp's writes can be coalesced into
+// a single transaction.
+scales_colwise = torch::empty({num_row_blocks, cols, 1}, options_scale);
} else {
output_colwise = torch::empty({0}, options_fp8);
scales_colwise = torch::empty({0}, options_scale);
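To illustrate the reasoning in the comment above, here is a minimal CUDA sketch (hypothetical kernel and names, not the code in this PR) of the access pattern: adjacent threads in a warp handle adjacent columns, so a stride of 1 along the `col` dim turns the warp's 32 scale writes into consecutive addresses the hardware can coalesce.

// Sketch only -- illustrates the write pattern, not this PR's actual kernel.
// scales has layout {num_row_blocks, cols}: stride along `col` is 1.
__global__ void scale_write_sketch(const float* __restrict__ in,
                                   float* __restrict__ scales,
                                   int rows, int cols) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x; // adjacent threads -> adjacent cols
  const int row_block = blockIdx.y;                      // one 32-row block per y index
  if (col >= cols) return;

  // Each thread reduces a column of 32 input values to one scale (amax here).
  float amax = 0.0f;
  const int r_end = min(rows, (row_block + 1) * 32);
  for (int r = row_block * 32; r < r_end; ++r) {
    amax = fmaxf(amax, fabsf(in[r * cols + col]));
  }

  // {num_row_blocks, cols} layout: the 32 threads of a warp write 32
  // consecutive floats -- one coalesced transaction.
  scales[row_block * cols + col] = amax;
  // The old {cols, num_row_blocks} layout would instead be
  // scales[col * num_row_blocks + row_block]: writes land num_row_blocks
  // elements apart, i.e. uncoalesced.
}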
@@ -100,7 +106,10 @@
colwise ? scale_dim_y : 1, // scale_dim_y
fp8_format);


+// Transpose the scales so `cols` is the leading dim, to match the torchao reference impl.
+if (colwise) {
+  scales_colwise = scales_colwise.transpose(0, 1);
+}
return std::make_tuple(output_rowwise, output_colwise, scales_rowwise,
scales_colwise);
}
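Worth noting: `transpose(0, 1)` in libtorch returns a view with swapped sizes and strides rather than copying, so the scales keep the write-friendly memory layout while exposing the `{cols, num_row_blocks, 1}` shape the reference expects. A small standalone sketch with assumed sizes (illustration only):

#include <torch/torch.h>

int main() {
  // Assumed sizes for illustration: num_row_blocks = 4, cols = 8.
  auto scales = torch::empty({4, 8, 1});
  auto view = scales.transpose(0, 1); // shape {8, 4, 1}; no data movement
  TORCH_CHECK(view.size(0) == 8 && view.size(1) == 4);
  TORCH_CHECK(view.stride(0) == 1);   // `cols` dim still has stride 1
  TORCH_CHECK(view.data_ptr() == scales.data_ptr()); // same storage
  return 0;
}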
4 changes: 2 additions & 2 deletions torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh
@@ -657,7 +657,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
// Calculate scale offsets and write scaling factor.
const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
-const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
+// const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
+const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim0 + global_scales_offset_X * scales_colwise_stride_dim1;

// Write scales to global memory.
// This bounds check is needed because the original code let threads from the second `iter` overwrite values written in the first.
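To make the corrected indexing concrete, a worked example with assumed sizes (not values from this PR): with `scales_colwise` allocated as `{num_row_blocks, cols, 1}` and row-major strides, `scales_colwise_stride_dim0 == cols` and `scales_colwise_stride_dim1 == 1`, so the flat index works out as follows.

// Assumed: num_row_blocks = 4, cols = 8 -> strides {8, 1, 1}.
// Y indexes the 32-row block, X the column; thread t of a warp has X = X0 + t.
//   scale_idx = Y * scales_colwise_stride_dim0 + X * scales_colwise_stride_dim1
//             = Y * 8 + X
// Consecutive threads get consecutive scale_idx values, so the writes coalesce.
// The old expression (Y * stride_dim1 + X * stride_dim0) matched the previous
// {cols, num_row_blocks, 1} allocation, where consecutive threads landed
// num_row_blocks elements apart and the warp's writes were scattered.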
@@ -680,7 +681,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
// torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289
-// float data_lp = in_compute[i] / scale_fp32;
float data_lp = __fdiv_rn(in_compute[i], scale_fp32);

#if defined(DEBUG)
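For context on the `__fdiv_rn` call kept above: it is the CUDA intrinsic for single-precision division in round-to-nearest-even mode, and unlike the plain `/` operator it is not subject to the approximate-division rewrite that `-use_fast_math` (via `-prec-div=false`) applies, which is presumably why it was chosen to match the reference's `in_compute[i] / scale_fp32` exactly. A minimal sketch of the distinction (hypothetical kernel, illustration only):

// Sketch only: contrasts the two division forms under fast-math compilation.
__global__ void div_sketch(const float* a, const float* b,
                           float* out_op, float* out_rn, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  out_op[i] = a[i] / b[i];           // may compile to approximate division under -use_fast_math
  out_rn[i] = __fdiv_rn(a[i], b[i]); // always IEEE 754 division, round-to-nearest-even
}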