Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Sundar Rabindranath committed Jun 14, 2024
1 parent 2748e67 commit 67409d3
Showing 1 changed file with 98 additions and 28 deletions.
126 changes: 98 additions & 28 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType, int32_t M, bool IsSmallN>
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, int32_t M,
bool IsSmallN> // IsSmallN is true if N < 8192
struct sm90_int8_config {
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
Expand All @@ -287,12 +289,14 @@ struct sm90_int8_config {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType, bool IsSmallN>
struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
bool IsSmallN>
struct sm90_int8_config<InType, OutType, Epilogue, 128, IsSmallN> {
// Specialization for M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
Expand All @@ -301,47 +305,51 @@ struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType, bool IsSmallN>
struct sm90_int8_config<InType, OutType, 64, IsSmallN> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
bool IsSmallN>
struct sm90_int8_config<InType, OutType, Epilogue, 64, IsSmallN> {
// Specialization for M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
struct sm90_int8_config<InType, OutType, 32, false> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config<InType, OutType, Epilogue, 32, false> {
// Specialization for M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
struct sm90_int8_config<InType, OutType, 32, true> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config<InType, OutType, Epilogue, 32, true> {
// Specialization for M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

} // namespace
Expand All @@ -357,9 +365,9 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

static const int32_t MDimDontCare = 0;

using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue,
MDimDontCare>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
Expand All @@ -384,6 +392,70 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);

static const int32_t MDimDontCare = 0;
static const bool NDimDontCare = false;

// Same config for Large N and Small N
using Cutlass3xGemmDefault =
typename sm90_int8_config<InType, OutType, Epilogue, MDimDontCare,
NDimDontCare>::Cutlass3xGemm;
// Same config for Large N and Small N
using Cutlass3xGemmM128 =
typename sm90_int8_config<InType, OutType, Epilogue, 128,
NDimDontCare>::Cutlass3xGemm;
// Same config for Large N and Small N
using Cutlass3xGemmM64 =
typename sm90_int8_config<InType, OutType, Epilogue, 64,
NDimDontCare>::Cutlass3xGemm;
// Different configs for Large N and Small N
using Cutlass3xGemmM32LargeN =
typename sm90_int8_config<InType, OutType, Epilogue, 32,
false>::Cutlass3xGemm;
using Cutlass3xGemmM32SmallN =
typename sm90_int8_config<InType, OutType, Epilogue, 32,
true>::Cutlass3xGemm;

uint32_t const n = a.size(1);
bool const is_small_n = n < 8192;

uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2

if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_gemm_caller<Cutlass3xGemmM32SmallN>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass3xGemmM32LargeN>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}

void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand All @@ -395,15 +467,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_3x_gemm<
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
ScaledEpilogue>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);

return cutlass_gemm_caller<
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>>(
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
} else {
Expand Down

0 comments on commit 67409d3

Please sign in to comment.