[Kernel] Add GPU architecture guards to the CUTLASS w8a8 kernels to reduce binary size #5157

Merged: 8 commits, Jun 5, 2024
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu: 70 additions & 35 deletions

@@ -48,9 +48,44 @@ using namespace cute;
 
 namespace {
 
-template <typename Arch, typename ElementAB_, typename ElementD_,
-          typename TileShape, typename WarpShape, typename InstructionShape,
-          int32_t MainLoopStages>
+// Wrappers for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm75_to_sm80 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm80_to_sm89 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm89_to_sm90 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
 struct cutlass_2x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
@@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+      ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
           ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
           ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
           float, cutlass::layout::RowMajor, 4,
@@ -112,7 +147,7 @@
           cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
           MainLoopStages, Operator,
           1 /* epilogue stages */
-          >::GemmKernel;
+          >::GemmKernel>;
   // clang-format on
 
   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
@@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
  }
 }
 
@@ -263,32 +298,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   }
 }
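
The guards above work because __CUDA_ARCH__ is defined only while nvcc compiles device code, and in each device compilation pass it carries the architecture being targeted. A minimal sketch of the same static-invoke guard applied to a standalone toy kernel (ToyKernel and guarded_entry are illustrative names, not part of this PR):

#include <cstdio>
#include <utility>

// Stand-in for a CUTLASS-style kernel with a static device-side entry point.
struct ToyKernel {
  __device__ static void invoke(int x) { printf("ToyKernel ran: %d\n", x); }
};

// Guard wrapper: the forwarded call survives only in device passes targeting
// [SM80, SM89). Every other pass compiles an empty body, so no machine code
// for ToyKernel lands in those architectures' cubins.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  __device__ static void invoke(Args&&... args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

// Host code never sees __CUDA_ARCH__, so it can still name and launch the
// guarded kernel unconditionally; the body is simply empty on other targets.
__global__ void guarded_entry(int x) {
  enable_sm80_to_sm89<ToyKernel>::invoke(x);
}

int main() {
  guarded_entry<<<1, 1>>>(42);
  cudaDeviceSynchronize();  // prints only when built for and run on an SM8x GPU
  return 0;
}

When the extension is built with several -gencode targets, each device pass re-evaluates the #if, so only the SM8x passes emit real code for the wrapped kernel while the host-side dispatch logic stays the same for every GPU.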

csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu: 17 additions & 2 deletions

@@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
+// A wrapper for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
+
 template <typename ElementAB_, typename ElementD_, typename TileShape,
           typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule>
@@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
           KernelSchedule>::CollectiveOp;
   // clang-format on
 
-  using KernelType = cutlass::gemm::kernel::GemmUniversal<
+  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>;
+      cutlass::gemm::PersistentScheduler>>;
 
   struct GemmKernel : public KernelType {};
 };
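
CUTLASS 3.x kernels are device-side functors rather than classes with a static invoke, which is why this guard overrides operator() instead. A rough sketch of how a functor wrapped this way gets used from a kernel entry point (ToyFunctor and run_guarded are illustrative names assumed for the example, not vLLM or CUTLASS symbols):

#include <utility>

// Stand-in for a CUTLASS 3.x style kernel functor with a nested Params type.
struct ToyFunctor {
  struct Params { int n; };
  __device__ void operator()(Params const& p) {
    (void)p;  // real per-CTA work would go here
  }
};

// Guard: the wrapped functor's body is compiled only into SM90+ device
// passes; on older architectures operator() collapses to an empty function.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  __device__ void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

using GuardedToy = enable_sm90_or_later<ToyFunctor>;

// The launch path is unchanged: nested types like Params are inherited from
// ToyFunctor through the guard, so callers are unaware of the wrapper.
__global__ void run_guarded(GuardedToy::Params params) {
  GuardedToy op;
  op(params);
}

Because the guard publicly derives from the kernel type, the nested types and static members that the surrounding CUTLASS machinery expects remain visible through the wrapper; only the device-side body is conditionally compiled.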