diff --git a/torchao/experimental/mxfp8_cpp/cast_kernels.cuh b/torchao/experimental/mxfp8_cpp/cast_kernels.cuh new file mode 100644 index 0000000..c1a06db --- /dev/null +++ b/torchao/experimental/mxfp8_cpp/cast_kernels.cuh @@ -0,0 +1,1468 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_kernels.cuh + * \brief CUDA kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ + +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = + MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = + MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) cast_mxfp8_2D_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) + return; + } + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + 
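+  // Scale-tensor geometry: with SCALE_DIM_X = 32 the rowwise path emits one E8M0
+  // byte per 1x32 element block (64 x 2 scales per 64x64 chunk); with
+  // SCALE_DIM_Y = 32 the colwise path emits one byte per 32x1 block
+  // (2 x 64 scales per chunk).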
SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = + blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = + blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + const int dbias_rowwise_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_rowwise[i].clear(); + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_colwise[i] = 0; + } + } + } + + // The destination shared memory buffer of a bulk tensor operation should be + // 128 e8m0_t aligned + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. 
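+// Pipeline note: with MXFP8_BUFFERS_NUM = 2 and MXFP8_PREFETCH_BUFFERS_NUM = 1,
+// iteration `iter` consumes shared buffer `iter % 2` while the TMA load for
+// iteration `iter + 1` is already in flight into the other buffer; mbar[iter]
+// is the arrival barrier each consumer waits on before reading that buffer.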
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int dbias_rowwise_offset_X = + dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + const int dbias_colwise_offset_X = + dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2( + &in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], + &tensor_map_act_input, chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2( + &in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, + shmem_buff_size, &mbar[next_iter], + is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec act_in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + 
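+          // Rowwise decomposition: each thread loads ELEMS_PER_THREAD = 16
+          // contiguous elements, so for SCALE_DIM_X = 32 two adjacent threads
+          // (THREADS_PER_SCALE_X_ROWWISE) cover one scaling block and share their
+          // amax via subwarp_reduce_max_broadcast.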
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + } + } + in_compute[j] = elt; + + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = + subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = float_to_e8m0( + subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Only single thread writes the computed scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = + static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + float in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = + static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = + float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + + global_scales_offset_X; + 
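+        // The stored byte is an E8M0 exponent (bias 127 per the OCP MX spec):
+        // float_to_e8m0(amax * max_norm_rcp) selects a power-of-two decode scale
+        // near amax / 448 for E4M3, and exp2f_rcp yields its reciprocal, which the
+        // elements are multiplied by before the FP8 cast below.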
scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = + chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + partial_dbias_rowwise[c].store_to( + &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1] + [tid_rowwise_X]); + } + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + Vec other_row_dbias; + const int dbias_rowwise_offset_X = + dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; + const int dbias_offset = + dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + + const int left_bound = dbias_rowwise_offset_X; + const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; ++i) { + other_row_dbias.load_from( + &shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + partial_dbias_rowwise[c].data.elt[j] += + other_row_dbias.data.elt[j]; + } + } + + // Vectorized store when all elements are inside the boundaries + if (right_bound < cols) { + partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + // Element-by-element store when some elements cross the boundaries + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise[c].store_to_elts( + &dbias_workspace[dbias_offset], 0, in_bound_elts_count); + } + } + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + const int dbias_colwise_offset_X = + dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; + const int dbias_offset = + dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (!col_out_of_bounds) { + 
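+          // Each block writes its per-column partial sums into one row of the
+          // (gridDim.y x cols) float workspace; reduce_dbias_kernel later sums
+          // those rows to produce the final dbias vector.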
dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max( + block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = + FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, + const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const int dbias_offset_Y = blockIdx.y + tid_Y; + const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const int dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128-byte aligned + __shared__ alignas(128) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. 
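+// Unlike the MXFP8 kernel above, this kernel implements delayed tensor scaling:
+// every element is multiplied by the single precomputed *scale_ptr, the block amax
+// is reduced and atomically merged into amax_ptr, and 1/scale is written back to
+// scale_inv_ptr.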
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, + is_master_thread); + + int parity = 0; + + const int chunk_offset_Y = block_offset_Y; + const int chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2( + &in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], + &tensor_map_act_input, chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const int buff = iter % FP8_BUFFERS_NUM; + const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const int next_buff = next_iter % FP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2( + &in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = + static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = + static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = + static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
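+    // Store path: only the master thread issues the bulk tensor copy from shared
+    // to global memory, groups it with cp_async_bulk_commit_group(), and then waits
+    // until older async-groups have finished reading shared memory so the buffer
+    // can safely be refilled by the next TMA load.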
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const int dbias_offset_X = my_column; + const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128-byte aligned + __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. 
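+// Each block of this 1D kernel covers ELEMS_PER_BLOCK = 128 * 128 = 16384 elements
+// (CHUNKS_PER_BLOCK chunks of THREADS_PER_BLOCK elements), streamed through two
+// shared-memory buffers over ITERATIONS = 4 TMA iterations; the host-side launcher
+// requires N to be a multiple of ELEMS_PER_BLOCK.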
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), + is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % SHMEM_BUFFERS; + const int it_offset = iter * SHMEM_DIM; + + const int next_iter = iter + 1; + const int next_buff = next_iter % SHMEM_BUFFERS; + const int next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, + transaction_size_IN, &(mbar[next_iter]), + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
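+      // (<1> tolerates one outstanding bulk async-group, namely the store just
+      // committed, so this effectively waits for the previous iteration's store
+      // to finish reading its shared buffer before that buffer is reused.)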
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, + const float *const dbias_partial, const int rows, + const int cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, + const size_t cols, cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = + DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, + cols); +} + +template +static void cast_fp8_1D(const Tensor &input, Tensor *output, + cudaStream_t stream) { + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = + reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = + reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel + <<>>(input_ptr, output_ptr, amax_ptr, + scale_inv_ptr, scale_ptr, + N);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + checkCuDriverContext(stream); + + const size_t rows = 
input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, + "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, + "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = + IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = + reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + FP8_SHMEM_DIM_Y, FP8_SHMEM_DIM_X, cols, 0, + typeToNumBits(input.data.dtype)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, + cols, FP8_SHMEM_DIM_Y, FP8_SHMEM_DIM_X, cols, + 0, typeToNumBits(input.data.dtype)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, + FP8_SHMEM_DIM_Y, FP8_SHMEM_DIM_X, cols, 0, + typeToNumBits(output->data.dtype)); + + cast_fp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, + stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + checkCuDriverContext(stream); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 
32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const size_t scale_stride_rowwise = + use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) + : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling + ? reinterpret_cast(output->columnwise_scale_inv.dptr) + : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), + "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, + "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = + IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + cols, 0, typeToNumBits(input.dtype())); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, + rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X, cols, 0, + typeToNumBits(input.dtype())); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map( + tensor_map_output_rowwise, output->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + typeToNumBits(output->dtype())); + } + + if (use_colwise_scaling) { + create_2D_tensor_map( + tensor_map_output_colwise, output->columnwise_data, + rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, + 0, typeToNumBits(output->dtype())); + } + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, + tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, + dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { + +using Empty = transformer_engine::Empty; + +__device__ inline float identity(float value, const Empty &) { return value; } + +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, + const DequantizeParam 
¶m) { + return value * (*(param.scale_inv)); +} + +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, + Tensor *output, cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = + (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, + stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, + const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = + (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, + stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace { + +static bool is_full_tile_1D_tensor(const Tensor *const t) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + return isFullTile; +} + +bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr int TMA_bytes = 16; + const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + return cols % alignment_requirement == 0; +} + +} // namespace + +// Supported by the Arch >= 10.0 +template +void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, + const Tensor *noop, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DBIAS && !IS_DACT) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, TMA_gmem_alignment)) { + // Aligned AND FP8 + cast_fp8_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, + stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (dimensions_supported_by_TMA(output) && + is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, TMA_gmem_alignment) && + is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + // Aligned AND FP8 (+dAct) + cast_fp8_2D(input, act_input, output, + dbias, 
workspace, stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, + output, stream); + } + } else { + cast_fp8_2D(input, act_input, output, + dbias, workspace, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize( + input, act_input, noop, output, dbias, workspace, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + + to_string(output->scaling_mode) + "."); + } +} + +// Supported by the Arch < 10.0 +template +void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, + const Tensor *noop, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? + NVTE_ERROR("Not implemented scaling mode or fusion: " + + to_string(output->scaling_mode) + + " on GPU with compute capability < 10.0."); + } + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, + stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, + output, stream); + } + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + + to_string(output->scaling_mode) + "."); + } +} + +template +void fp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), + "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, + "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), + "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, + "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + fp8_quantize_arch_ge_100( + input, act_input, noop, output, dbias, workspace, stream); + } else { + // Supported by the Arch < 10.0 + fp8_quantize_arch_l_100( + input, act_input, noop, output, dbias, workspace, stream); + } +} + +namespace detail { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = convertNVTETensorCheck(grad); + activation_input_tensor = convertNVTETensor(input); + } else { + // forward = input is activation input + input_tensor = convertNVTETensorCheck(input); + activation_input_tensor = nullptr; + } + auto output_tensor = convertNVTETensorCheck(output); + auto dbias_tensor = convertNVTETensor(dbias); + auto workspace_tensor = convertNVTETensor(workspace); + + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = + quant_config_cpp ? 
quant_config_cpp->noop_tensor : nullptr; + const auto noop_tensor = + noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); + + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK( + output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, + dbias_tensor, workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, + dbias_tensor, workspace_tensor, stream); + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for " + "NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = + quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + force_pow_2_scales, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for " + "NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = + quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = + quant_config_cpp + ? quant_config_cpp->float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT + : false; + rowwise_option = rowwise_compact + ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = + quant_config_cpp + ? quant_config_cpp->float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT + : false; + columnwise_option = + columnwise_compact + ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/torchao/experimental/mxfp8_cpp/mxfp8_cuda.cu b/torchao/experimental/mxfp8_cpp/mxfp8_cuda.cu index f9d9fd6..f4ab6fe 100644 --- a/torchao/experimental/mxfp8_cpp/mxfp8_cuda.cu +++ b/torchao/experimental/mxfp8_cpp/mxfp8_cuda.cu @@ -1 +1,83 @@ -// .cu bridge +// CUDA bridge for MXFP8 quantization + +#include "mxfp8_quantize.cuh" +#include +#include +#include + +namespace mxfp8 { + +// Convert PyTorch scalar type to our DType enum +DType get_input_dtype(const torch::Tensor &t) { + switch (t.scalar_type()) { + case torch::kFloat32: + return DType::kFloat32; + case torch::kFloat16: + return DType::kFloat16; + case torch::kBFloat16: + return DType::kBFloat16; + case torch::kUInt8: + return DType::kByte; + default: + TORCH_CHECK(false, "Unsupported input tensor dtype: ", t.scalar_type()); + } +} + +// Convert FP8 format string to DType enum +DType get_output_dtype(const std::string &fp8_format) { + if (fp8_format.compare("e4m3") == 0) { + return DType::kFloat8E4M3; + } else { + TORCH_CHECK(false, "Unsupported FP8 format: ", fp8_format, + ". Only 'e4m3' is supported."); + } +} + +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_colwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, int64_t scale_dim_x, + int64_t scale_dim_y, const std::string &fp8_format) { + + // Get tensor properties + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + + // Get data pointers + const void *input_ptr = input.data_ptr(); + void *output_rowwise_ptr = + output_rowwise.numel() > 0 ? output_rowwise.data_ptr() : nullptr; + void *output_colwise_ptr = + output_colwise.numel() > 0 ? output_colwise.data_ptr() : nullptr; + e8m0_t *scales_rowwise_ptr = + scales_rowwise.numel() > 0 + ? reinterpret_cast(scales_rowwise.data_ptr()) + : nullptr; + e8m0_t *scales_colwise_ptr = + scales_colwise.numel() > 0 + ? 
reinterpret_cast(scales_colwise.data_ptr()) + : nullptr; + + // Get CUDA stream + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Call the quantization kernel + MXFP8Quantizer::quantize(input_ptr, output_rowwise_ptr, output_colwise_ptr, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, + get_input_dtype(input), get_output_dtype(fp8_format), + scale_dim_x, scale_dim_y, stream); + + // Synchronize the stream to ensure kernel completion + // cudaError_t sync_error = cudaStreamSynchronize(stream); + // TORCH_CHECK(sync_error == cudaSuccess, "CUDA stream synchronization failed: + // ", + // cudaGetErrorString(sync_error)); + + // Check for CUDA errors + cudaError_t error = cudaGetLastError(); + TORCH_CHECK(error == cudaSuccess, + "CUDA kernel failed: ", cudaGetErrorString(error)); +} + +} // namespace mxfp8 diff --git a/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp b/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp index 77c9432..ce70b39 100644 --- a/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp +++ b/torchao/experimental/mxfp8_cpp/mxfp8_extension.cpp @@ -1 +1,134 @@ // PyBind wrapping for the mxfp8 extension +#include +#include +#include +#include +#include + +namespace mxfp8 { + +// Forward declarations +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_columnwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, int64_t scale_dim_x, + int64_t scale_dim_y, const std::string &fp8_format); + +// Helper for tensor validation +void check_cuda_tensor(const torch::Tensor &t, const char *name) { + TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); +} + +// Helper to validate FP8 format +void validate_fp8_format(const std::string &fp8_format) { + TORCH_CHECK(fp8_format.compare("e4m3") == 0, + "fp8_format must be 'e4m3', got: ", fp8_format); +} + +// Helper to validate scale dimensions +void validate_scale_dimensions(int64_t scale_dim_x, int64_t scale_dim_y) { + TORCH_CHECK(scale_dim_x == 1 || scale_dim_x == 32, + "scale_dim_x must be 1 or 32, got: ", scale_dim_x); + TORCH_CHECK(scale_dim_y == 1 || scale_dim_y == 32, + "scale_dim_y must be 1 or 32, got: ", scale_dim_y); +} + +// Main quantization function +std::tuple +mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, + int64_t scale_dim_x, int64_t scale_dim_y, + const std::string &fp8_format) { + + // Validate inputs + check_cuda_tensor(input, "input"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(input.scalar_type() == torch::kFloat32 || + input.scalar_type() == torch::kFloat16 || + input.scalar_type() == torch::kBFloat16, + "Input must be float32, float16, or bfloat16"); + TORCH_CHECK(rowwise || colwise, + "At least one of rowwise or colwise must be true"); + + validate_scale_dimensions(scale_dim_x, scale_dim_y); + validate_fp8_format(fp8_format); + + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + + c10::cuda::CUDAGuard device_guard(input.device()); + + // Create tensor options + const auto options_fp8 = torch::TensorOptions() + .dtype(torch::kUInt8) // FP8 stored as uint8 + .device(input.device()); + + const auto options_scale = torch::TensorOptions() + .dtype(torch::kUInt8) // E8M0 stored as uint8 + .device(input.device()); + + // Allocate output tensors + torch::Tensor output_rowwise, output_colwise; + torch::Tensor scales_rowwise, scales_colwise; + + if (rowwise) { + output_rowwise = torch::empty({rows, 
cols}, options_fp8); + const int64_t scale_cols = (cols + scale_dim_x - 1) / scale_dim_x; + scales_rowwise = torch::empty({rows, scale_cols}, options_scale); + } else { + output_rowwise = torch::empty({0}, options_fp8); + scales_rowwise = torch::empty({0}, options_scale); + } + + if (colwise) { + output_colwise = torch::empty({rows, cols}, options_fp8); + const int64_t scale_rows = (rows + scale_dim_y - 1) / scale_dim_y; + scales_colwise = torch::empty({scale_rows, cols}, options_scale); + } else { + output_colwise = torch::empty({0}, options_fp8); + scales_colwise = torch::empty({0}, options_scale); + } + + // Call CUDA kernel + mxfp8_quantize_cuda(input, output_rowwise, output_colwise, scales_rowwise, + scales_colwise, rowwise ? scale_dim_x : 1, + colwise ? scale_dim_y : 1, fp8_format); + + return std::make_tuple(output_rowwise, output_colwise, scales_rowwise, + scales_colwise); +} + +// Get scale tensor for given input shape - ceiling division +torch::Tensor get_scale_shape(torch::IntArrayRef input_shape, + int64_t block_size, bool transpose) { + + TORCH_CHECK(input_shape.size() == 2, "Input shape must be 2D"); + + int64_t rows = input_shape[0]; + int64_t cols = input_shape[1]; + + if (transpose) { + int64_t scale_rows = (rows + block_size - 1) / block_size; + return torch::empty({scale_rows, cols}, torch::kUInt8); + } else { + int64_t scale_cols = (cols + block_size - 1) / block_size; + return torch::empty({rows, scale_cols}, torch::kUInt8); + } +} + +} // namespace mxfp8 + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "MXFP8 Quantization PyTorch Extension"; + + m.def("quantize", &mxfp8::mxfp8_quantize, "MXFP8 quantization", + py::arg("input"), py::arg("rowwise") = true, py::arg("colwise") = false, + py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32, + py::arg("fp8_format") = "e4m3"); + + m.def("get_scale_shape", &mxfp8::get_scale_shape, + "Get shape of scale tensor for given input shape", + py::arg("input_shape"), py::arg("block_size") = 32, + py::arg("transpose") = false); +} diff --git a/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh b/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh index 6dd6d81..ca6cb25 100644 --- a/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh +++ b/torchao/experimental/mxfp8_cpp/mxfp8_quantize.cuh @@ -1 +1,737 @@ -// TODO - main functiona +// MFXP8 quantization kernel and utilities +// Credit - this is a derivative work from TransformerEngine +// https://github.com/NVIDIA/TransformerEngine +// License - Apache 2.0, +// https://github.com/NVIDIA/TransformerEngine/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// Use official CUDA PTX library +#include "ptx.cuh" +#include +#include + +#define MIN_CUDA_SM 1000 // SM90 = 900, SM100 = 1000 + +// Check if we're compiling for supported architecture +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < MIN_CUDA_SM) +#warning \ + "MXFP8 quantization requires SM90+ (Hopper) or SM100+ (Blackwell) architecture. Kernel will be disabled for this architecture." 
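+// Note: the check uses MIN_CUDA_SM = 1000, so despite the message above mentioning
+// SM90+, any architecture below SM100 (Blackwell) triggers this warning and has the
+// kernels disabled.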
+#endif + +// Architecture detection for native FP8 support +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 +#define HAS_NATIVE_FP8_CONVERSION 1 +#else +#define HAS_NATIVE_FP8_CONVERSION 0 +#endif + +enum class DType { + kByte, + kFloat32, + kFloat16, + kBFloat16, + kFloat8E4M3, + kFloat8E5M2 +}; + +// E8M0 type for MXFP8 scaling factors +using e8m0_t = uint8_t; + +// Software FP8 conversion for older architectures +__device__ __forceinline__ uint8_t software_float_to_fp8e4m3(float val) { + // Handle special cases + if (__isnanf(val)) + return 0x7F; // NaN + if (__isinff(val)) + return val > 0 ? 0x7E : 0xFE; // +/-Inf -> +/-Max + + // Clamp to FP8 range before conversion + constexpr float max_fp8_val = 448.0f; + val = fminf(fmaxf(val, -max_fp8_val), max_fp8_val); + + uint32_t bits = __float_as_uint(val); + uint32_t sign = (bits >> 31) & 0x1; + int32_t exp = ((bits >> 23) & 0xFF) - 127; // Remove bias + uint32_t mantissa = bits & 0x7FFFFF; + + // E4M3 has 4-bit exponent, 3-bit mantissa + // Bias is 7 for E4M3 + exp += 7; + + // Clamp exponent to valid range [0, 15] + if (exp <= 0) { + // Underflow to zero + return sign << 7; + } else if (exp >= 15) { + // Overflow to max finite value + return (sign << 7) | 0x7E; + } + + // Round mantissa to 3 bits + uint32_t rounded_mantissa = (mantissa + (1 << 19)) >> 20; // Round to nearest + if (rounded_mantissa >= 8) { + // Mantissa overflow, increment exponent + exp++; + rounded_mantissa = 0; + if (exp >= 15) { + return (sign << 7) | 0x7E; // Max finite value + } + } + + return (sign << 7) | (exp << 3) | rounded_mantissa; +} + +// Perf: Hardware FP8 conversion for Blackwell +template +__device__ __forceinline__ FP8Type convert_to_fp8(float val); + +// FP8 types - 1 byte each exactly +struct __align__(1) fp8e4m3 { + uint8_t data; + + fp8e4m3() = default; + + __device__ __forceinline__ fp8e4m3(float val) { + // Use software conversion for all architectures + data = software_float_to_fp8e4m3(val); + } + + __device__ __forceinline__ fp8e4m3 &operator=(float val) { + // Use software conversion for all architectures + data = software_float_to_fp8e4m3(val); + return *this; + } +}; + +// Template specialization for fp8e4m3 conversion +template <> +__device__ __forceinline__ fp8e4m3 convert_to_fp8(float val) { + return fp8e4m3(val); +} + +// TODO - Vectorized FP8 conversion for Blackwell? 
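+// For reference when checking the conversions above and below: an E4M3 byte
+// (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits) decodes as
+//   normal    (exp > 0):  (-1)^sign * (1 + mantissa/8) * 2^(exp - 7)
+//   subnormal (exp == 0): (-1)^sign * (mantissa/8) * 2^(-6)
+// with 0x7F/0xFF reserved for NaN and 0x7E = 448.0f the largest finite value,
+// which is why software_float_to_fp8e4m3() clamps its input to +/-448 and
+// saturates overflow to 0x7E.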
+__device__ inline void convert_float4_to_fp8x4(float4 in, fp8e4m3 *out, + float scale) { +#if HAS_NATIVE_FP8_CONVERSION + // TODO: Check if Blackwell supports vectorized FP8 conversion + // For now, just use scalar conversion + out[0] = fp8e4m3(in.x * scale); + out[1] = fp8e4m3(in.y * scale); + out[2] = fp8e4m3(in.z * scale); + out[3] = fp8e4m3(in.w * scale); +#else + // Scalar conversion for older architectures + out[0] = fp8e4m3(in.x * scale); + out[1] = fp8e4m3(in.y * scale); + out[2] = fp8e4m3(in.z * scale); + out[3] = fp8e4m3(in.w * scale); +#endif +} + +// Constants for MXFP8 +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = + MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; +constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; + +constexpr size_t THREADS_PER_WARP = 32; // lol + +// Utility macros +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) + +// Vector type for loading/storing multiple elements +template struct Vec { + union { + T elt[N]; + } data; + + __device__ inline void clear() { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = T(0); + } + } + + __device__ inline void load_from(const T *ptr) { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = ptr[i]; + } + } + + __device__ inline void store_to(T *ptr) const { +#pragma unroll + for (int i = 0; i < N; ++i) { + ptr[i] = data.elt[i]; + } + } +}; + +// Math utilities +// fast reciprocal of 2, raised to a power (1/2^x), no floating point ops +__device__ __forceinline__ float exp2f_rcp(const e8m0_t biased_exp) { + constexpr uint32_t exponent_nbits = 8; + constexpr uint32_t exponent_bias = 127; + constexpr uint32_t biased_exponent_offset = + exponent_bias - (1 << (exponent_nbits - 1)); + + uint32_t biased_exponent = + static_cast(biased_exp) + biased_exponent_offset; + uint32_t reversed_biased_exponent = ~biased_exponent + 1; + uint32_t scale_bits = (reversed_biased_exponent & 0xFF) << 23; + + return __int_as_float(scale_bits); +} + +// conversion from float to scale +__device__ __forceinline__ e8m0_t float_to_e8m0(const float val) { + constexpr uint32_t mantissa_nbits = 23; + constexpr uint32_t exponent_nbits = 8; + constexpr uint32_t exponent_bias = 127; + constexpr uint32_t biased_exponent_offset = + exponent_bias - (1 << (exponent_nbits - 1)); + + uint32_t bits = __float_as_uint(val); + uint32_t biased_exponent = + ((bits >> mantissa_nbits) & ((1 << exponent_nbits) - 1)); + uint32_t e8m0_biased_exponent = biased_exponent - biased_exponent_offset; + + return static_cast(e8m0_biased_exponent); +} + +// Quantization limits +template struct Quantized_Limits { + static 
constexpr float max_norm = 448.0f; // For E4M3 + static constexpr float max_norm_rcp = 1.0f / max_norm; +}; + +// Warp reduction utilities +template +__device__ float subwarp_reduce_max_broadcast(float val) { +#pragma unroll + for (int mask = SUBWARP_WIDTH / 2; mask > 0; mask /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask, SUBWARP_WIDTH)); + } + return val; +} + +template __device__ float reduce_max(float val, int warp_id) { + static __shared__ float shared[NUM_WARPS]; + + val = __shfl_sync(0xffffffff, val, 0); + + if (threadIdx.x % THREADS_PER_WARP == 0) { + shared[warp_id] = val; + } + __syncthreads(); + + if (warp_id == 0) { + val = (threadIdx.x < NUM_WARPS) ? shared[threadIdx.x] : 0.0f; + +#pragma unroll + for (int mask = NUM_WARPS / 2; mask > 0; mask /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask)); + } + } + + return __shfl_sync(0xffffffff, val, 0); +} + +// Atomic max for float +__device__ inline void atomicMaxFloat(float *addr, float value) { + unsigned int *address_as_uint = (unsigned int *)addr; + unsigned int old = *address_as_uint, assumed; + do { + assumed = old; + old = atomicCAS(address_as_uint, assumed, + __float_as_uint(fmaxf(value, __uint_as_float(assumed)))); + } while (assumed != old); +} + +// TMA descriptor creation +inline CUtensorMapDataType get_dtype_for_tma(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case DType::kFloat16: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + case DType::kBFloat16: + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kByte: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + default: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } +} + +inline void create_2D_tensor_map(CUtensorMap &tensorMap, void *data_ptr, + DType dtype, size_t rows, size_t cols, + uint32_t shmem_y, uint32_t shmem_x) { + // Get function pointer to cuTensorMapEncodeTiled + static void *driver_ptr = nullptr; + if (!driver_ptr) { + cudaDriverEntryPointQueryResult result; + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, + cudaEnableDefault, &result); + } + auto cuTensorMapEncodeTiled = + reinterpret_cast(driver_ptr); + + constexpr uint32_t rank = 2; + uint64_t size[rank] = {cols, rows}; + uint64_t stride[rank - 1] = {cols}; // bytes + uint32_t boxSize[rank] = {shmem_x, shmem_y}; + uint32_t elemStride[rank] = {1, 1}; + + cuTensorMapEncodeTiled( + &tensorMap, get_dtype_for_tma(dtype), rank, data_ptr, size, stride, + boxSize, elemStride, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); +} + +// Helper functions for TMA operations +__device__ inline void copy_2d_to_shared(void *smem, + const CUtensorMap *tensor_map, + uint32_t x, uint32_t y, + size_t smem_size, uint64_t *mbar, + bool is_master) { +#if __CUDA_ARCH__ >= 1000 + if (is_master) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(smem), + reinterpret_cast(tensor_map), x, y, mbar); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
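+    // (The mbarrier phase completes only once every thread in the block has
+    // arrived AND the TMA engine has delivered `smem_size` bytes; the
+    // expect_tx count set below is what the hardware subtracts from as the
+    // bulk copy lands in shared memory.)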
+ ptx::mbarrier_arrive_expect_tx(mbar, smem_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(mbar); + } +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +// MXFP8 quantization kernel +//////////////////////////////////////////////////////////////////////////////// + +// Main MXFP8 quantization kernel (with TMA) +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + mxfp8_quantize_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= MIN_CUDA_SM) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = + blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = + blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + + // Shared memory buffers + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + + // Initialize barriers + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; + +// Process chunks +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / 
MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +// Prefetch initial data +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + +// Process iterations +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Prefetch next iteration data + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + // Row-wise scaling + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = + (chunk_offset_X + shmem_offset_x + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + in_compute[j] = elt; + + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = + subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = float_to_e8m0( + subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Write scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + + tid_rowwise_X / 
THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + + // Use hardware-accelerated conversion when available +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { +#if HAS_NATIVE_FP8_CONVERSION + // Hardware conversion on Blackwell + out_c.data.elt[j] = + static_cast(in_compute[j] * block_scale_inverse); +#else + // Software conversion for older architectures + out_c.data.elt[j] = + convert_to_fp8(in_compute[j] * block_scale_inverse); +#endif + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + // Column-wise scaling + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols); + float in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + in_compute[i] = elt; + + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = + float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + + // Use hardware-accelerated conversion when available +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { +#if HAS_NATIVE_FP8_CONVERSION + // Hardware conversion on Blackwell + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); +#else + // Software conversion for older architectures + out_colwise_sh[buff][i][tid_colwise_X] = + convert_to_fp8(in_compute[i] * block_scale_inverse); +#endif + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = + chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
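+      // (cp_async_bulk_wait_group_read<N>() issues cp.async.bulk.wait_group.read N,
+      // which blocks until at most N previously committed bulk async-groups are
+      // still reading shared memory, so the staging buffer can safely be
+      // overwritten by the next prefetch; see the <0>/<1>/<2>/<4>
+      // instantiations in ptx.cuh.)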
+ ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + destroy_barriers(mbar, is_master_thread); +#endif +} + +// Simple wrapper class for MXFP8 quantization +class MXFP8Quantizer { +public: + // Quantize a tensor using MXFP8 + // input: pointer to input data + // output_rowwise: pointer to row-wise quantized output (can be nullptr) + // output_colwise: pointer to column-wise quantized output (can be nullptr) + // scales_rowwise: pointer to row-wise scaling factors (required if + // output_rowwise is not null) scales_colwise: pointer to column-wise scaling + // factors (required if output_colwise is not null) rows, cols: tensor + // dimensions input_dtype: data type of input output_dtype: FP8 output type + // (fp8e4m3 or fp8e5m2) scale_dim_x: block size for row-wise scaling + // (typically 32) scale_dim_y: block size for column-wise scaling (typically + // 32) + static void quantize(const void *input, void *output_rowwise, + void *output_colwise, e8m0_t *scales_rowwise, + e8m0_t *scales_colwise, size_t rows, size_t cols, + DType input_dtype, DType output_dtype, + size_t scale_dim_x = 32, size_t scale_dim_y = 32, + cudaStream_t stream = 0) { + + // Check parameters + assert(scale_dim_x == 1 || scale_dim_x == 32); + assert(scale_dim_y == 1 || scale_dim_y == 32); + assert(output_rowwise != nullptr || output_colwise != nullptr); + + if (output_rowwise) + assert(scales_rowwise != nullptr); + if (output_colwise) + assert(scales_colwise != nullptr); + + // Calculate grid dimensions + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + // Calculate scale strides + size_t scale_stride_rowwise = output_rowwise ? DIVUP(cols, scale_dim_x) : 1; + size_t scale_stride_colwise = output_colwise ? 
cols : 1; + + // Create TMA descriptors + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, const_cast(input), + input_dtype, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X); + + if (output_rowwise) { + create_2D_tensor_map(tensor_map_output_rowwise, output_rowwise, + output_dtype, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X); + } + + if (output_colwise) { + create_2D_tensor_map(tensor_map_output_colwise, output_colwise, + output_dtype, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X); + } + +// Launch kernel based on input/output types and scaling dimensions +// Only compile kernel launches for SM90+ +#if defined(__CUDACC__) && \ + (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM) + + // Use TMA and mbarrier instructions +#define LAUNCH_KERNEL(IType, OType, SCALE_Y, SCALE_X) \ + mxfp8_quantize_kernel \ + <<>>( \ + tensor_map_input, tensor_map_output_rowwise, \ + tensor_map_output_colwise, scales_rowwise, scales_colwise, rows, \ + cols, scale_stride_rowwise, scale_stride_colwise) +#endif + + // Dispatch based on types + if (input_dtype == DType::kFloat32) { + if (output_dtype == DType::kFloat8E4M3) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 32); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(float, fp8e4m3, 1, 32); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 1); + } + } + } + +#undef LAUNCH_KERNEL + } +}; diff --git a/torchao/experimental/mxfp8_cpp/ptx.cuh b/torchao/experimental/mxfp8_cpp/ptx.cuh new file mode 100644 index 0000000..101a073 --- /dev/null +++ b/torchao/experimental/mxfp8_cpp/ptx.cuh @@ -0,0 +1,318 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! 
\file ptx.cuh + * \brief BW PTX + */ + +#ifndef TRANSFORMER_ENGINE_PTX_CUH_ +#define TRANSFORMER_ENGINE_PTX_CUH_ + +#include +#include + +// namespace transformer_engine { +namespace ptx { + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, + const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void +mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void +cp_async_bulk_tensor_1d_global_to_shared(uint64_t *dst_shmem, + const uint64_t *src_global_ptr, + const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile("cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"( + dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, + const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 
1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"( + dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +} + +__device__ __forceinline__ bool +mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile("{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, + const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global( + uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"( + dst_global_ptr), + "r"(src_shmem_ptr), "r"(size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, " + "{%1, %2}], [%3];" ::"l"(tensor_map_ptr), + "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group 0;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} + +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { + asm volatile("cp.async.bulk.wait_group.read 1;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { + asm volatile("cp.async.bulk.wait_group.read 2;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { + asm volatile("cp.async.bulk.wait_group.read 4;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// Proxy fence 
(bi-directional): +__device__ __forceinline__ void fence_proxy_async() { + asm volatile("fence.proxy.async;"); +} + +__device__ __forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + +} // namespace ptx + +namespace { + +template +__forceinline__ __device__ void +initialize_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initialize barrier. All `blockDim.x * blockDim.y` threads in block + // participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, + uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void +copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx2( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2, + const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx3( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2, + void *dst3, const void *src3, const size_t chunk_X3, const size_t chunk_Y3, + const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst3), + reinterpret_cast(src3), chunk_X3, chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace +//} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PTX_CUH_ diff --git a/torchao/experimental/mxfp8_cpp/setup.py b/torchao/experimental/mxfp8_cpp/setup.py index 2837f0f..02e5787 100644 --- a/torchao/experimental/mxfp8_cpp/setup.py +++ b/torchao/experimental/mxfp8_cpp/setup.py @@ -1,41 +1,87 @@ """ setup.py - Build configuration for MXFP8 PyTorch extension + +This extension requires NVIDIA BlACKWELL architecture (SM100+) or newer. 
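+The include/library paths and -gencode flags below assume a CUDA 12.8 toolkit
+installed under /usr/local/cuda-12.8; adjust them to match your local toolkit.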
+
 """
-import os
+import sys
 
-# Get CUDA compute capability
 import torch
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    cuda_arch = f"{capability[0]}{capability[1]}"
-else:
-    # Default to Hopper if CUDA not available during build
-    cuda_arch = "90"
+# MXFP8 requires the Blackwell (SM100+) architecture - enforce this requirement
+REQUIRED_CUDA_ARCHITECTURES = ["100"]  # Blackwell architecture
+MIN_CUDA_MAJOR = 10  # Minimum compute capability major version (Blackwell = 10.0)
+
+
+def validate_cuda_capability():
+    """Validate that the system has compatible CUDA capability."""
+    if not torch.cuda.is_available():
+        print("WARNING: CUDA not available during build. Proceeding with SM100 target.")
+        print("Make sure you have a Blackwell (SM100+) GPU for runtime.")
+        return False
+
+    try:
+        capability = torch.cuda.get_device_capability()
+        major, minor = capability
+
+        if major < MIN_CUDA_MAJOR:
+            print(
+                "ERROR: MXFP8 extension requires the NVIDIA Blackwell architecture (SM100+)."
+            )
+            print(f"Current GPU: SM{major}{minor}")
+            print("Supported GPUs: Blackwell-class devices (SM100) and newer")
+            print(
+                "Please use a compatible GPU or remove this extension from your build."
+            )
+            return False
 
-# Set up compilation flags
+        print(f"Detected compatible GPU: SM{major}{minor}")
+        return True
+
+    except Exception as e:
+        print(f"WARNING: Could not detect GPU capability: {e}")
+        print("Proceeding with SM100 target. Ensure you have a Blackwell GPU for runtime.")
+        return True
+
+
+# Validate CUDA capability before proceeding
+if not validate_cuda_capability():
+    sys.exit(1)
+
+# Set up compilation flags for the Blackwell (SM100) architecture
 nvcc_flags = [
-    f"-arch=sm_{cuda_arch}",
     "-std=c++17",
     "-O3",
     "-lineinfo",
     "--use_fast_math",
-    "-gencode",
-    f"arch=compute_{cuda_arch},code=sm_{cuda_arch}",
+    "--generate-line-info",
+    # Target Blackwell (SM100); the Hopper (SM90a) gencode line is left disabled
+    # "-gencode=arch=compute_90a,code=sm_90a",
+    "-gencode=arch=compute_100,code=sm_100",
+    # Override PyTorch's default architecture list
+    "-DTORCH_CUDA_ARCH_LIST=10.0",
+    # Enable experimental constexpr/lambda features
+    "--expt-relaxed-constexpr",
+    "--expt-extended-lambda",
+    # Add flags to handle device/host function issues
+    "--extended-lambda",
+    "--default-stream=per-thread",
+    # Suppress warnings that cause compilation failures
+    "-Xcudafe",
+    "--diag_suppress=esa_on_defaulted_function_ignored",
+    "-Xcudafe",
+    "--diag_suppress=field_without_dll_interface",
+    "-Xcudafe",
+    "--diag_suppress=base_class_has_different_dll_interface",
+    "-Xcudafe",
+    "--diag_suppress=code_is_unreachable",
+    # Minimal PTX options
+    "--ptxas-options=--warn-on-spills",
 ]
 
-# For newer GPUs (Hopper+), enable additional optimizations
-if int(cuda_arch) >= 90:
-    nvcc_flags.extend(
-        [
-            "--ptxas-options=-v",
-            "--generate-line-info",
-        ]
-    )
-
 setup(
     name="mxfp8_cuda",
     ext_modules=[
@@ -47,12 +93,16 @@
         ],
         include_dirs=[
             ".",  # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
+            "/usr/local/cuda-12.8/include",  # CUDA 12.8 headers
+        ],
+        library_dirs=[
+            "/usr/local/cuda-12.8/lib64",  # CUDA 12.8 libraries
         ],
         extra_compile_args={
             "cxx": ["-std=c++17", "-O3"],
             "nvcc": nvcc_flags,
        },
-        extra_link_args=["-lcuda"],
+        extra_link_args=["-lcuda", "-lcudart"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
diff --git a/torchao/experimental/mxfp8_cpp/test_cpp_extension.py b/torchao/experimental/mxfp8_cpp/test_cpp_extension.py new file mode 100644 index 0000000..2be5208 --- /dev/null +++ b/torchao/experimental/mxfp8_cpp/test_cpp_extension.py @@ -0,0 +1,206 @@ +import time + +import numpy as np +import torch + +try: + import mxfp8_cuda +except ImportError: + print("MXFP8 extension not found or ready. Aborting...") + exit(1) + + +def print_tensor_info(tensor, name): + """Print information about a tensor""" + print(f"{name}:") + print(f" Shape: {tensor.shape}") + print(f" Dtype: {tensor.dtype}") + print(f" Device: {tensor.device}") + print(f" Size in bytes: {tensor.numel() * tensor.element_size()}") + + +def test_basic_quantization(): + """Test basic MXFP8 quantization functionality""" + print("=== Testing Basic MXFP8 Quantization ===") + + # Create test input + rows, cols = 512, 512 + input_tensor = torch.randn(rows, cols, device="cuda", dtype=torch.float32) + + # Test rowwise quantization + print("\n1. Row-wise quantization only:") + output_rowwise, _, scales_rowwise, _ = mxfp8_cuda.quantize( + input_tensor, + rowwise=True, + colwise=False, + scale_dim_x=32, + scale_dim_y=1, + fp8_format="e4m3", + ) + + print_tensor_info(output_rowwise, "Output (row-wise)") + print_tensor_info(scales_rowwise, "Scales (row-wise)") + # print(f"{scales_rowwise[0]=}") + # print(f"{output_rowwise[0,0]=}") + + # Test colwise quantization + print("\n2. Column-wise quantization only:") + _, output_colwise, _, scales_colwise = mxfp8_cuda.quantize( + input_tensor, + rowwise=False, + colwise=True, + scale_dim_x=1, + scale_dim_y=32, + fp8_format="e4m3", + ) + + print_tensor_info(output_colwise, "Output (column-wise)") + print_tensor_info(scales_colwise, "Scales (column-wise)") + + # Test both rowwise and colwise + print("\n3. 
Both row-wise and column-wise quantization:") + output_rowwise, output_colwise, scales_rowwise, scales_colwise = ( + mxfp8_cuda.quantize( + input_tensor, + rowwise=True, + colwise=True, + scale_dim_x=32, + scale_dim_y=32, + fp8_format="e4m3", + ) + ) + torch.cuda.synchronize() + + print_tensor_info(output_rowwise, "Output (row-wise)") + print_tensor_info(output_colwise, "Output (column-wise)") + print_tensor_info(scales_rowwise, "Scales (row-wise)") + print_tensor_info(scales_colwise, "Scales (column-wise)") + # print(f"Maximum absolute value: {amax.item():.6f}") + + +def test_numerical_accuracy(): + """Test numerical accuracy of quantization""" + print("\n=== Numerical Accuracy Testing ===") + + # Create input with known values + input_tensor = torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0] * 8, # 32 elements + [0.5, 1.5, 2.5, 3.5] * 8, + [-1.0, -2.0, -3.0, -4.0] * 8, + [-0.5, -1.5, -2.5, -3.5] * 8, + ], + device="cuda", + dtype=torch.float32, + ) + + # Quantize with block size 32 (so each row is one block) + output_rowwise, _, scales_rowwise, _ = mxfp8_cuda.quantize( + input_tensor, + rowwise=True, + colwise=False, + scale_dim_x=32, + scale_dim_y=1, + ) + + print(f"Input shape: {input_tensor.shape}") + print(f"Output shape: {output_rowwise.shape}") + print(f"Scales shape: {scales_rowwise.shape}") + + # Check that scales capture the max values per row + input_cpu = input_tensor.cpu().numpy() + for i in range(input_tensor.shape[0]): + row_max = np.abs(input_cpu[i]).max() + print(f"Row {i}: max abs value = {row_max:.3f}") + + +def test_different_block_sizes(): + """Test different block sizes""" + print("\n=== Testing Different Block Sizes ===") + + input_tensor = torch.randn(128, 128, device="cuda") + + # Test 1x32 blocks (row-wise with 32-element blocks) + print("\n1x32 blocks:") + output_rowwise, _, scales_rowwise, _ = mxfp8_cuda.quantize( + input_tensor, rowwise=True, colwise=False, scale_dim_x=32, scale_dim_y=1 + ) + print(f"Scales shape for 1x32 blocks: {scales_rowwise.shape}") + + # Test 32x1 blocks (column-wise with 32-element blocks) + print("\n32x1 blocks:") + _, output_colwise, _, scales_colwise = mxfp8_cuda.quantize( + input_tensor, rowwise=False, colwise=True, scale_dim_x=1, scale_dim_y=32 + ) + print(f"Scales shape for 32x1 blocks: {scales_colwise.shape}") + + +def test_performance(): + """Test performance of MXFP8 quantization""" + print("\n=== Performance Testing ===") + + sizes = [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + (4096, 4096), + (8192, 8192), + ] + + for rows, cols in sizes: + input_tensor = torch.randn(rows, cols, device="cuda") + + # Warmup + for _ in range(10): + _ = mxfp8_cuda.quantize(input_tensor, rowwise=True, colwise=True) + + torch.cuda.synchronize() + + # Time the operation + num_iterations = 100 + start_time = time.time() + + for _ in range(num_iterations): + _ = mxfp8_cuda.quantize(input_tensor, rowwise=True, colwise=True) + + torch.cuda.synchronize() + end_time = time.time() + + avg_time = (end_time - start_time) / num_iterations * 1000 # ms + throughput = (rows * cols * 4) / (avg_time / 1000) / 1e9 # GB/s + + print(f"\nSize: {rows}x{cols}") + print(f"Average time: {avg_time:.3f} ms") + print(f"Throughput: {throughput:.2f} GB/s") + + +def main(): + """Main test function""" + # Check CUDA availability + if not torch.cuda.is_available(): + print("CUDA is not available. 
Exiting.") + return + + # Check compute capability + capability = torch.cuda.get_device_capability() + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"Compute Capability: {capability[0]}.{capability[1]}") + + if capability[0] < 9: + print("Warning: MXFP8 is optimized for compute capability 9.0+ (Hopper)") + print("The extension may not work correctly on older GPUs.") + + # Run tests + test_basic_quantization() + + test_different_block_sizes() + test_numerical_accuracy() + # test_error_handling() + test_performance() + + print("\n=== All tests completed ===") + + +if __name__ == "__main__": + main()