diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
index 23d0587bbdc5d..bf04bb400790f 100644
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ b/csrc/quantization/cutlass_w8a8/common.hpp
@@ -17,3 +17,11 @@
 inline uint32_t next_pow_2(uint32_t const num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+  int max_shared_mem_per_block_opt_in = 0;
+  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                         device);
+  return max_shared_mem_per_block_opt_in;
+}
+
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index 740b9fb64a754..38a20a1727d18 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -250,12 +250,39 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(status);
 }
 
+template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
+void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                                  torch::Tensor const& b,
+                                  EpilogueArgs&&... args) {
+  // In some cases, the GPU isn't able to accommodate the
+  // shared memory requirements of the Gemm. In such cases, use
+  // the FallbackGemm instead.
+  static const int max_shared_mem_per_block_opt_in =
+      get_cuda_max_shared_memory_per_block_opt_in(0);
+
+  size_t const gemm_shared_mem_size =
+      sizeof(typename Gemm::KernelType::SharedStorage);
+  size_t const fallback_gemm_shared_mem_size =
+      sizeof(typename FallbackGemm::KernelType::SharedStorage);
+
+  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
+    return cutlass_gemm_caller<Gemm>(out, a, b,
+                                     std::forward<EpilogueArgs>(args)...);
+  } else {
+    TORCH_CHECK(fallback_gemm_shared_mem_size <=
+                max_shared_mem_per_block_opt_in);
+    return cutlass_gemm_caller<FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
 template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
 struct sm80_config_default {
   // This config is used in 2 cases,
   // - M in (128, inf)
   // - M in (64, 128] and N >= 8192
+  // Shared Memory required by this Gemm - 81920 bytes
   static_assert(std::is_same<InType, int8_t>());
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -271,6 +298,7 @@ struct sm80_config_M64 {
   // This config is used in 2 cases,
   // - M in (32, 64]
   // - M in (64, 128] and N < 8192
+  // Shared Memory required by this Gemm - 122880 bytes
   static_assert(std::is_same<InType, int8_t>());
   using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -284,6 +312,7 @@ template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
 struct sm80_config_M32 {
   // M in (16, 32]
+  // Shared Memory required by this Gemm - 61440 bytes
   static_assert(std::is_same<InType, int8_t>());
   using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
   using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
@@ -297,6 +326,7 @@ template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
 struct sm80_config_M16 {
   // M in [1, 16]
+  // Shared Memory required by this Gemm - 51200 bytes
   static_assert(std::is_same<InType, int8_t>());
   using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
   using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
@@ -331,35 +361,45 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
   using Cutlass2xGemmM16 =
       typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
 
+  // Due to shared memory requirements, some Gemms may fail to run on some
+  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
+  // in such cases.
+  // sm80_config_M16 has the least shared-memory requirement. However,
+  // based on some profiling, we select sm80_config_M32 as a better alternative
+  // performance wise.
+  using FallbackGemm =
+      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
+
   uint32_t const m = a.size(0);
   uint32_t const mp2 =
       std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
   if (mp2 <= 16) {
     // M in [1, 16]
-    return cutlass_gemm_caller<Cutlass2xGemmM16>(
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 32) {
     // M in (16, 32]
-    return cutlass_gemm_caller<Cutlass2xGemmM32>(
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 64) {
     // M in (32, 64]
-    return cutlass_gemm_caller<Cutlass2xGemmM64>(
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 128) {
     // M in (64, 128]
     uint32_t const n = out.size(1);
     bool const small_n = n < 8192;
     if (small_n) {
-      return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
+      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
+                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
     } else {
-      return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
+      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
     }
   } else {
     // M in (128, inf)
-    return cutlass_gemm_caller<Cutlass2xGemmDefault>(
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   }
 }
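For reference, the decision `fallback_cutlass_gemm_caller` makes can be reproduced with a small standalone program. The sketch below is illustrative only and not part of this diff: it queries `cudaDevAttrMaxSharedMemoryPerBlockOptin` the same way `get_cuda_max_shared_memory_per_block_opt_in` does, then compares the result against the per-config shared-memory figures quoted in the comments above. The config names and byte counts come from those comments; the file name and everything else in the snippet are assumptions made for the example.

```cpp
// Standalone sketch (hypothetical file smem_check.cu, build with
// `nvcc smem_check.cu -o smem_check`): report which of the sm80 tile configs
// fit within this GPU's opt-in shared-memory-per-block limit.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  // The dispatch code queries device 0, so this sketch does too.
  int const device = 0;
  int max_shared_mem_opt_in = 0;
  cudaError_t const err = cudaDeviceGetAttribute(
      &max_shared_mem_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  if (err != cudaSuccess) {
    std::printf("cudaDeviceGetAttribute failed: %s\n",
                cudaGetErrorString(err));
    return 1;
  }

  // Shared-memory requirements quoted in the diff comments above.
  struct Config {
    char const* name;
    size_t bytes;
  };
  Config const configs[] = {
      {"sm80_config_default", 81920},
      {"sm80_config_M64", 122880},
      {"sm80_config_M32 (the FallbackGemm)", 61440},
      {"sm80_config_M16", 51200},
  };

  std::printf("max shared memory per block (opt-in): %d bytes\n",
              max_shared_mem_opt_in);
  for (Config const& c : configs) {
    bool const fits = c.bytes <= static_cast<size_t>(max_shared_mem_opt_in);
    std::printf("  %-36s %6zu bytes -> %s\n", c.name, c.bytes,
                fits ? "fits" : "would need the fallback");
  }
  return 0;
}
```

On sm_80 datacenter parts the opt-in limit is roughly 163 KB, so all four configs fit; on GPUs with a smaller opt-in limit (for example, roughly 99 KB on sm_86/sm_89), the 122880-byte sm80_config_M64 config exceeds the limit, which is presumably the situation the FallbackGemm path is meant to cover.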