2 changes: 2 additions & 0 deletions torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp
@@ -58,6 +58,8 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,

 const int64_t rows = input.size(0);
 const int64_t cols = input.size(1);
+TORCH_CHECK((rows >= 32) && (rows % 32 == 0), "rows must be a multiple of 32");
+TORCH_CHECK((cols >= 32) && (cols % 32 == 0), "cols must be a multiple of 32");

 c10::cuda::CUDAGuard device_guard(input.device());

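The new checks reject any input that cannot be tiled into complete MX scaling blocks, since both the row-wise (1x32) and column-wise (32x1) layouts assume 32 elements per shared scale. A minimal host-side sketch of the same precondition, assuming the standard MX block size of 32; the helper name and shapes below are illustrative, not part of the extension:

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

constexpr int64_t kMXBlockSize = 32;  // elements that share one scale

// Hypothetical helper mirroring the TORCH_CHECKs added above.
void check_mxfp8_shape(int64_t rows, int64_t cols) {
  if (rows < kMXBlockSize || rows % kMXBlockSize != 0)
    throw std::invalid_argument("rows must be a multiple of 32");
  if (cols < kMXBlockSize || cols % kMXBlockSize != 0)
    throw std::invalid_argument("cols must be a multiple of 32");
}

int main() {
  check_mxfp8_shape(64, 96);      // OK: row-wise -> 64 x 3 scales, column-wise -> 2 x 96 scales
  std::printf("64x96 accepted\n");
  // check_mxfp8_shape(64, 100);  // would throw: 100 is not a multiple of 32
  return 0;
}
```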
9 changes: 5 additions & 4 deletions torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh
@@ -621,8 +621,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)

 // Column-wise scaling
 if constexpr (USE_COLWISE_SCALING) {
-// TODO(danvm): these bounds checks are different from reference since we are
-// not handling dbias, how can we verify if they are correct?
 const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols);

 float in_compute[SCALE_DIM_Y];
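The guard that survives the TODO removal is a plain tail check: a thread whose global column index (chunk offset plus thread index) falls past the last column does no work. A scalar sketch of that pattern, with illustrative names standing in for the kernel's chunk/thread indexing:

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for one chunk of column-wise work in the kernel.
void process_chunk(const std::vector<float>& src, std::vector<float>& dst,
                   int64_t cols, int64_t chunk_offset_X, int64_t threads_per_chunk) {
  for (int64_t tid_colwise_X = 0; tid_colwise_X < threads_per_chunk; ++tid_colwise_X) {
    const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols);
    if (col_out_of_bounds) continue;  // this "thread" maps past the tensor edge; skip it
    dst[chunk_offset_X + tid_colwise_X] = src[chunk_offset_X + tid_colwise_X];
  }
}

int main() {
  std::vector<float> src(100, 1.0f), dst(100, 0.0f);
  process_chunk(src, dst, /*cols=*/100, /*chunk_offset_X=*/96, /*threads_per_chunk=*/8);
  return 0;  // columns 96..99 are copied; indices 100..103 are skipped by the guard
}
```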
@@ -678,10 +676,13 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
 scale_fp32 = max(scale_fp32, F32_MIN_NORMAL);

 // Use scales to perform value conversion.
+// Do reciprocal once so we can do fast multiplies instead of division in the loop.
+const float inv_scale_fp32 = __fdiv_rn(1.0f, scale_fp32);
+
 #pragma unroll
 for (int i = 0; i < SCALE_DIM_Y; ++i) {
 // torchao ref: https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289
-float data_lp = __fdiv_rn(in_compute[i], scale_fp32);
+float data_lp = in_compute[i] * inv_scale_fp32;

 #if defined(DEBUG)
 printf("tid_colwise_X=%d, amax=%d, data_hp=%f, scale_fp32=%f, data_lp=%f\n", tid_colwise_X, amax, in_compute[i], scale_fp32, data_lp);
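The change hoists the division out of the per-element loop: one __fdiv_rn computes 1/scale_fp32, and each of the SCALE_DIM_Y values is then scaled with a multiply. Since MX scales are powers of two, the reciprocal should be exact and the results should match the divided path. A standalone plain-C++ sketch of the transformation (names and the power-of-two test scale are illustrative):

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>

constexpr int SCALE_DIM_Y = 32;  // elements that share one column-wise scale

// One division per block of 32 elements instead of one division per element.
void scale_block(const float (&in_compute)[SCALE_DIM_Y], float scale_fp32,
                 float (&data_lp)[SCALE_DIM_Y]) {
  scale_fp32 = std::max(scale_fp32, std::numeric_limits<float>::min());  // clamp, like F32_MIN_NORMAL
  const float inv_scale_fp32 = 1.0f / scale_fp32;  // hoisted: the only division
  for (int i = 0; i < SCALE_DIM_Y; ++i) {
    data_lp[i] = in_compute[i] * inv_scale_fp32;   // cheap multiply inside the loop
  }
}

int main() {
  float in[SCALE_DIM_Y], out[SCALE_DIM_Y];
  for (int i = 0; i < SCALE_DIM_Y; ++i) in[i] = static_cast<float>(i);
  scale_block(in, /*scale_fp32=*/4.0f, out);       // power-of-two scale, as in MX
  std::printf("out[31] = %f\n", out[31]);          // 31 / 4 = 7.75
  return 0;
}
```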
@@ -868,7 +869,7 @@ public:
 LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1);
 }
 } else {
-printf("unsupported input dtype, must be float32\n");
+printf("unsupported input dtype, must be float32 or bfloat16\n");
 exit(1);
 }
 #endif
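For context, the surrounding block is a two-way dtype dispatch: float32 and bfloat16 inputs pick the matching LAUNCH_KERNEL instantiation, and anything else hits the (now corrected) error message. A compact sketch of that dispatch shape, with placeholder types and a stubbed launch; all names here are illustrative:

```cpp
#include <cstdio>
#include <cstdlib>

enum class InputDType { kFloat32, kBFloat16, kFloat16 };

struct Float32  { static const char* name() { return "float32"; } };
struct BFloat16 { static const char* name() { return "bfloat16"; } };

// Stub standing in for the templated kernel launch (LAUNCH_KERNEL in the real code).
template <typename InT>
void launch_mxfp8_kernel() { std::printf("launching kernel for %s input\n", InT::name()); }

void dispatch_on_dtype(InputDType dtype) {
  if (dtype == InputDType::kFloat32) {
    launch_mxfp8_kernel<Float32>();
  } else if (dtype == InputDType::kBFloat16) {
    launch_mxfp8_kernel<BFloat16>();
  } else {
    std::printf("unsupported input dtype, must be float32 or bfloat16\n");
    std::exit(1);
  }
}

int main() {
  dispatch_on_dtype(InputDType::kBFloat16);
  return 0;
}
```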