diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
index 7af29caac..af051279b 100644
--- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
+++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
@@ -73,7 +73,7 @@ template
 __global__ void _dequantize_int4_kernel(
     const at::PackedTensorAccessor32 in,
     at::PackedTensorAccessor32 out,
-    at::optional> scales_and_zeros = c10::nullopt)
+    at::optional> scales_and_zeros = std::nullopt)
 {
   constexpr int32_t kNTileSize = 8;
 
@@ -85,16 +85,16 @@ __global__ void _dequantize_int4_kernel(
 
   // n dimension that this lane loads from
   auto n0 = nTile * kNTileSize + (t / 4);
- 
+
   // 8 k-tile values, 4 per m16n8k16 mma.sync operand B
   // int32_t ks[8];
   //Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values)
   int32_t ks[4];
 
-  // Store address base offset 
+  // Store address base offset
   auto pOut = &out[n0][0];
- 
-// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 
+
+// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2
 #pragma unroll
   for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
     //Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values
@@ -111,7 +111,7 @@ __global__ void _dequantize_int4_kernel(
 
     // inner k-tiles unpack two at a time
    int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2];
-    
+
    if constexpr(kDequant) {
      // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing");
      static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing");
@@ -119,7 +119,7 @@ __global__ void _dequantize_int4_kernel(
 
      // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16
      bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack);
-      
+
      // All b values within a 16x16 tile should fall within the same q group
      // Hence we load 1 scale and zero per loop
      int qgroup = ks[0] / groupSize;
@@ -132,7 +132,7 @@ __global__ void _dequantize_int4_kernel(
 #pragma unroll
      for (int i = 0; i < 4; i++) {
        reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
-      } 
+      }
    }
    else {
      static_assert(std::is_same::value, "Out must be int32_t when unpacking to int");
@@ -150,8 +150,8 @@ __global__ void _dequantize_int4_kernel(
 
 #pragma unroll
      for (int i = 0; i < 4; ++i) {
-        reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i]; 
-      } 
+        reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i];
+      }
    }
  }
 }
@@ -164,7 +164,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(
     const at::Tensor& packed_w,
     const at::Tensor& scales_and_zeros,
     int64_t group_size,
-    int64_t innerKTiles) 
+    int64_t innerKTiles)
 {
   constexpr int32_t kNTileSize = 8;
 
@@ -201,7 +201,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(
 
   auto stream = at::cuda::getCurrentCUDAStream();
   dim3 grid(kSuperTiles, nTiles);
- 
+
 #define RUN_DEQUANT(QGROUPSIZE) \
   do { \
     switch(innerKTiles) { \
@@ -259,7 +259,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(
 // input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
 at::Tensor _unpack_tensor_core_tiled_layout(
     const at::Tensor& packed_w,
-    int64_t innerKTiles) 
+    int64_t innerKTiles)
 {
   c10::cuda::CUDAGuard g(packed_w.device());
 
@@ -288,7 +288,7 @@ at::Tensor _unpack_tensor_core_tiled_layout(
 
   auto stream = at::cuda::getCurrentCUDAStream();
   dim3 grid(kSuperTiles, nTiles);
- 
+
   if (innerKTiles == 2) {
     _dequantize_int4_kernel<<>>(
        packed_w.packed_accessor32(),