c10::nullopt -> std::nullopt (#1032)
Differential Revision: D64835967

Pull Request resolved: #1151
r-barnes authored Oct 23, 2024
1 parent 4ef024c commit ad4c447
Showing 1 changed file with 14 additions and 14 deletions.
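
The substantive change is a one-token rename: the default value of the optional scales_and_zeros kernel argument is now spelled std::nullopt instead of c10::nullopt. The remaining paired -/+ lines in the diff below appear to differ only in trailing whitespace. As a minimal sketch of the rename pattern (illustrative names and a simplified signature, not code from this diff, and assuming a PyTorch version in which c10::optional is an alias for std::optional, so the two spellings are interchangeable):

#include <cstdint>
#include <optional>

// Before: the trailing optional parameter defaulted to the c10 spelling.
// void dequantize(std::optional<int64_t> group_size = c10::nullopt);

// After: the same default spelled with the standard token.
void dequantize(std::optional<int64_t> group_size = std::nullopt) {
  if (group_size.has_value()) {
    // per-group scales/zeros path
  } else {
    // no group size supplied; the caller relied on the default
  }
}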
@@ -73,7 +73,7 @@ template <typename Out_t, int InnerKTiles, int groupSize, bool kDequant = true>
__global__ void _dequantize_int4_kernel(
const at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> in,
at::PackedTensorAccessor32<Out_t, 2, at::RestrictPtrTraits> out,
-at::optional<const at::PackedTensorAccessor32<c10::BFloat16, 3, at::RestrictPtrTraits>> scales_and_zeros = c10::nullopt)
+at::optional<const at::PackedTensorAccessor32<c10::BFloat16, 3, at::RestrictPtrTraits>> scales_and_zeros = std::nullopt)
{

constexpr int32_t kNTileSize = 8;
@@ -85,16 +85,16 @@ __global__ void _dequantize_int4_kernel(

// n dimension that this lane loads from
auto n0 = nTile * kNTileSize + (t / 4);

// 8 k-tile values, 4 per m16n8k16 mma.sync operand B
// int32_t ks[8];
//Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values)
int32_t ks[4];

-// Store address base offset
+// Store address base offset
auto pOut = &out[n0][0];
-// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2

+// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2
#pragma unroll
for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
//Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values
@@ -111,15 +111,15 @@ __global__ void _dequantize_int4_kernel(

// inner k-tiles unpack two at a time
int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2];

if constexpr(kDequant) {
// static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing");
static_assert(std::is_same<Out_t, c10::BFloat16>::value, "Out must be BFloat16 when dequantizing");
// __nv_bfloat16 v[8];

// // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16
bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack);

// All b values within a 16x16 tile should fall within the same q group
// Hence we load 1 scale and zero per loop
int qgroup = ks[0] / groupSize;
@@ -132,7 +132,7 @@ __global__ void _dequantize_int4_kernel(
#pragma unroll
for (int i = 0; i < 4; i++) {
reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
-}
+}
}
else {
static_assert(std::is_same<Out_t, int32_t>::value, "Out must be int32_t when unpacking to int");
@@ -150,8 +150,8 @@ __global__ void _dequantize_int4_kernel(

#pragma unroll
for (int i = 0; i < 4; ++i) {
-reinterpret_cast<int2 *>(&pOut[ks[i]])[0] = v_i32x2[i];
-}
+reinterpret_cast<int2 *>(&pOut[ks[i]])[0] = v_i32x2[i];
+}
}
}
}
@@ -164,7 +164,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(
const at::Tensor& packed_w,
const at::Tensor& scales_and_zeros,
int64_t group_size,
-int64_t innerKTiles)
+int64_t innerKTiles)
{

constexpr int32_t kNTileSize = 8;
@@ -201,7 +201,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(

auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(kSuperTiles, nTiles);

#define RUN_DEQUANT(QGROUPSIZE) \
do { \
switch(innerKTiles) { \
@@ -259,7 +259,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout(
// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
at::Tensor _unpack_tensor_core_tiled_layout(
const at::Tensor& packed_w,
-int64_t innerKTiles)
+int64_t innerKTiles)
{

c10::cuda::CUDAGuard g(packed_w.device());
@@ -288,7 +288,7 @@ at::Tensor _unpack_tensor_core_tiled_layout(

auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(kSuperTiles, nTiles);

if (innerKTiles == 2) {
_dequantize_int4_kernel<int32_t, 2, 0, false><<<grid, kWarpSize, 0, stream>>>(
packed_w.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>(),
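
The kernel above chooses between its dequantize and raw-unpack paths at compile time through the kDequant template parameter, guarding each branch with a static_assert on the output type; the unpack instantiations pass kDequant = false and can rely on the new std::nullopt default for scales_and_zeros. A minimal sketch of that compile-time dispatch (simplified types and illustrative names, not code from this repository):

#include <cstdint>
#include <cstdio>
#include <optional>
#include <type_traits>

template <typename Out_t, bool kDequant>
void dequant_or_unpack(std::optional<float> scale = std::nullopt) {
  if constexpr (kDequant) {
    // Dequantize path: output must be a floating-point type.
    static_assert(std::is_floating_point<Out_t>::value,
                  "Out must be floating point when dequantizing");
    std::printf("dequantize with scale %f\n", static_cast<double>(scale.value_or(1.0f)));
  } else {
    // Unpack-only path: the std::nullopt default argument is used.
    static_assert(std::is_same<Out_t, int32_t>::value,
                  "Out must be int32_t when unpacking to int");
    std::printf("unpack raw ints\n");
  }
}

int main() {
  dequant_or_unpack<float, true>(0.5f);  // dequantize with an explicit scale
  dequant_or_unpack<int32_t, false>();   // unpack path relies on the default
  return 0;
}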

