Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] [Kernel] Add Cutlass2x fallback kernels #5744

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions csrc/quantization/cutlass_w8a8/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ inline uint32_t next_pow_2(uint32_t const num) {
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need the device parameter? Seems like you're not using it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yes. the 0 in

  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                        cudaDevAttrMaxSharedMemoryPerBlockOptin,
                        0);

is the device. Ill fix that! thanks.

int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device);
return max_shared_mem_per_block_opt_in;
}

52 changes: 46 additions & 6 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,39 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(status);
}

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static const int max_shared_mem_per_block_opt_in =
get_cuda_max_shared_memory_per_block_opt_in(0);

size_t const gemm_shared_mem_size =
sizeof(typename Gemm::KernelType::SharedStorage);
size_t const fallback_gemm_shared_mem_size =
sizeof(typename FallbackGemm::KernelType::SharedStorage);

if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}

template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
Expand All @@ -271,6 +298,7 @@ struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
Expand All @@ -284,6 +312,7 @@ template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
Expand All @@ -297,6 +326,7 @@ template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
Expand Down Expand Up @@ -331,35 +361,45 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
using Cutlass2xGemmM16 =
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using FallbackGemm =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return cutlass_gemm_caller<Cutlass2xGemmM16>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return cutlass_gemm_caller<Cutlass2xGemmM32>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return cutlass_gemm_caller<Cutlass2xGemmM64>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
uint32_t const n = out.size(1);
bool const small_n = n < 8192;
if (small_n) {
return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else {
// M in (128, inf)
return cutlass_gemm_caller<Cutlass2xGemmDefault>(
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
Expand Down
Loading