Skip to content

Commit 3631dea

Browse files
Fix uncoalesced global accesses
stack-info: PR: #11, branch: danielvegamyhre/stack/5
1 parent 2db90aa commit 3631dea

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def run(
9696
"dim1_mx",
9797
"dim1_mx_triton",
9898
"dim1_cuda",
99-
"dim0_dim1_cuda",
10099
)
101100

102101
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
@@ -213,12 +212,12 @@ def run(
213212
raise NotImplementedError("dim0_dim1_cuda not implemented yet")
214213

215214
elif mode == "dim1_cuda":
216-
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
215+
bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True)
216+
_, y_d1, _, s_d1 = bench_fn(x)
217217

218218
for _ in range(2):
219-
__ = mxfp8_cuda.quantize(x, rowwise=False, colwise=True)
219+
__ = bench_fn(x)
220220

221-
bench_fn = partial(mxfp8_cuda.quantize, rowwise=False, colwise=True)
222221
time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)
223222

224223
assert y_d1.dtype == torch.float8_e4m3fn

torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,13 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
8686
if (colwise) {
8787
const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
8888
output_colwise = torch::empty({rows, cols}, options_fp8);
89-
scales_colwise = torch::empty({cols, num_row_blocks, 1}, options_scale);
89+
// Need scales_colwise to be this shape so the 'col' dim stride is 1,
90+
// so for colwise scaling we can avoid uncoalesced writes to global memory.
91+
// This is because each of the 32 threads in a warp will be computing
92+
// a scale for a different column of 32 input data values, then each writing
93+
// that scale to global memory - so the stride along this `col` dim should be 1
94+
// so writes can be coalesced into a single transaction.
95+
scales_colwise = torch::empty({num_row_blocks, cols, 1}, options_scale);
9096
} else {
9197
output_colwise = torch::empty({0}, options_fp8);
9298
scales_colwise = torch::empty({0}, options_scale);
@@ -100,7 +106,10 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
100106
colwise ? scale_dim_y : 1, // scale_dim_y
101107
fp8_format);
102108

103-
109+
// Need to transpose scales so `cols` is leading dim, to match torchao reference impl.
110+
if (colwise) {
111+
scales_colwise = scales_colwise.transpose(0, 1);
112+
}
104113
return std::make_tuple(output_rowwise, output_colwise, scales_rowwise,
105114
scales_colwise);
106115
}

torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
657657
// Calculate scale offsets and write scaling factor.
658658
const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
659659
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
660-
const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
660+
// const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim1 + global_scales_offset_X * scales_colwise_stride_dim0;
661+
const int scale_idx = global_scales_offset_Y * scales_colwise_stride_dim0 + global_scales_offset_X * scales_colwise_stride_dim1;
661662

662663
// Write scales to global memory.
663664
// I had to add this bounds check because the original code was having threads from the second `iter` overwrite values from the first.
@@ -680,7 +681,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
680681
#pragma unroll
681682
for (int i = 0; i < SCALE_DIM_Y; ++i) {
682683
// torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289
683-
// float data_lp = in_compute[i] / scale_fp32;
684684
float data_lp = __fdiv_rn(in_compute[i], scale_fp32);
685685

686686
#if defined(DEBUG)

0 commit comments

Comments
 (0)