[Kernel] Update Cutlass int8 kernel configs for SM90 #5514
Merged: comaniac merged 7 commits into vllm-project:main from neuralmagic:cutlass-h100-i8-configs on Jun 20, 2024
Changes from 4 commits
@@ -278,6 +278,80 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue, int32_t M,
          bool IsSmallN>  // IsSmallN is true if N < 8192
struct sm90_int8_config {
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          bool IsSmallN>
struct sm90_int8_config<InType, OutType, Epilogue, 128, IsSmallN> {
  // Specialization for M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          bool IsSmallN>
struct sm90_int8_config<InType, OutType, Epilogue, 64, IsSmallN> {
  // Specialization for M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config<InType, OutType, Epilogue, 32, false> {
  // Specialization for M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config<InType, OutType, Epilogue, 32, true> {
  // Specialization for M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

}  // namespace

template <typename InType, typename OutType,
@@ -290,8 +364,10 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
+  static const int32_t MDimDontCare = 0;
   using Cutlass3xGemmDefault =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
+      typename sm90_fp8_config<InType, OutType, Epilogue,
+                               MDimDontCare>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
       typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
@@ -316,6 +392,70 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
  }
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  static const int32_t MDimDontCare = 0;
  static const bool NDimDontCare = false;

  // Same config for Large N and Small N
  using Cutlass3xGemmDefault =
      typename sm90_int8_config<InType, OutType, Epilogue, MDimDontCare,
                                NDimDontCare>::Cutlass3xGemm;
  // Same config for Large N and Small N
  using Cutlass3xGemmM128 =
      typename sm90_int8_config<InType, OutType, Epilogue, 128,
                                NDimDontCare>::Cutlass3xGemm;
  // Same config for Large N and Small N
  using Cutlass3xGemmM64 =
      typename sm90_int8_config<InType, OutType, Epilogue, 64,
                                NDimDontCare>::Cutlass3xGemm;
  // Different configs for Large N and Small N
  using Cutlass3xGemmM32LargeN =
      typename sm90_int8_config<InType, OutType, Epilogue, 32,
                                false>::Cutlass3xGemm;
  using Cutlass3xGemmM32SmallN =
      typename sm90_int8_config<InType, OutType, Epilogue, 32,
                                true>::Cutlass3xGemm;

  uint32_t const n = a.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
Review comment on lines +421 to +422: This should be replaced with the utility function introduced by #5275?
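For context, here is a purely illustrative sketch of the kind of factoring that comment suggests; the helper name and signature below are assumptions for illustration, not necessarily the utility introduced by #5275:

// Illustration only: a hypothetical helper that folds the clamp and the
// power-of-two rounding into one call (std::bit_ceil is C++20; the actual
// utility from #5275 may look different).
#include <algorithm>
#include <bit>
#include <cstdint>

inline uint32_t next_pow_2_clamped(uint32_t m, uint32_t min_bucket) {
  // std::bit_ceil(m) is the smallest power of two >= m (and 1 for m == 0),
  // so the result is never below min_bucket.
  return std::max(min_bucket, std::bit_ceil(m));
}

// The two lines above could then read, hypothetically:
//   uint32_t const mp2 = next_pow_2_clamped(m, 32);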

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32SmallN>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32LargeN>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@@ -326,22 +466,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
-    using TileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule =
-        typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-    using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_3x_gemm<
-          int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                             ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_gemm_caller<
-          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
+                                             ScaledEpilogue>(
           out, a, b, a_scales, b_scales);
     }
   } else {
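To summarize the new dispatch rule without wading through the templates, here is a small standalone sketch (illustrative only, not code from this PR) of how cutlass_gemm_sm90_int8_dispatch maps a problem shape to a config: round M up to a power of two, clamp it to at least 32, pick the matching bucket, and split the smallest bucket on whether N < 8192. It assumes a C++20 std::bit_ceil in place of the project's next_pow_2 helper, and the function name here is invented for the sketch.

#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>

// Illustration only: returns the name of the config the dispatch above would
// select for a GEMM whose output has m rows and n columns.
const char* pick_sm90_int8_config(uint32_t m, uint32_t n) {
  bool const is_small_n = n < 8192;
  // Round m up to the next power of two, clamped to the smallest bucket (32).
  uint32_t const mp2 = std::max(32u, std::bit_ceil(m));
  if (mp2 <= 32) {
    // m in [1, 32]: the only bucket where the N dimension changes the config
    return is_small_n ? "Cutlass3xGemmM32SmallN (tile 64x64x256, cluster 1x8x1)"
                      : "Cutlass3xGemmM32LargeN (tile 64x128x256, cluster 1x4x1)";
  } else if (mp2 <= 64) {
    return "Cutlass3xGemmM64 (tile 64x64x256, cluster 1x1x1)";    // m in (32, 64]
  } else if (mp2 <= 128) {
    return "Cutlass3xGemmM128 (tile 64x128x128, cluster 2x1x1)";  // m in (64, 128]
  }
  return "Cutlass3xGemmDefault (tile 128x128x128, cluster 2x1x1)";  // m > 128
}

int main() {
  // A few sample shapes and the bucket they fall into.
  printf("m=16,  n=4096  -> %s\n", pick_sm90_int8_config(16, 4096));
  printf("m=16,  n=16384 -> %s\n", pick_sm90_int8_config(16, 16384));
  printf("m=100, n=4096  -> %s\n", pick_sm90_int8_config(100, 4096));
  printf("m=512, n=4096  -> %s\n", pick_sm90_int8_config(512, 4096));
  return 0;
}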
tlrmchlsmth: One high level comment about these PRs: I think the template specialization makes it a little less clear what's going on than just having a different function name for each case, especially because the template parameters aren't really functional -- the functions never actually look at IsSmallN and M. This is a little bit of a nit, however, as the dispatching in these kernels is quite a few levels deep at this point and I think we'll want to clean it up anyway in a separate PR after this and #5560 land.
Thanks @tlrmchlsmth. I agree. I'll take a stab at making it better 👍 (e.g. moving the cutlass_2x_gemm/cutlass_3x_gemm structs to a .cuh, and the [arch]_config structs and the corresponding dispatch function into a .cuh, then the .cu ...)
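Purely as an illustrative sketch of the direction discussed above (these names do not exist in the PR, the Epilogue parameter is dropped, and the CUTLASS pieces are replaced by placeholders), giving each case its own explicitly named struct instead of partially specializing on M/IsSmallN values that the structs never read might look roughly like this:

// Placeholder standing in for a concrete cutlass_3x_gemm instantiation;
// the Epilogue parameter and the tile/cluster/schedule choices from the PR
// are omitted for brevity.
template <typename InType, typename OutType>
struct PlaceholderGemm {};

template <typename InType, typename OutType>
struct sm90_int8_config_default {        // M > 128, any N
  using Cutlass3xGemm = PlaceholderGemm<InType, OutType>;
};

template <typename InType, typename OutType>
struct sm90_int8_config_m128 {           // M in (64, 128], any N
  using Cutlass3xGemm = PlaceholderGemm<InType, OutType>;
};

template <typename InType, typename OutType>
struct sm90_int8_config_m64 {            // M in (32, 64], any N
  using Cutlass3xGemm = PlaceholderGemm<InType, OutType>;
};

template <typename InType, typename OutType>
struct sm90_int8_config_m32_large_n {    // M in [1, 32], N >= 8192
  using Cutlass3xGemm = PlaceholderGemm<InType, OutType>;
};

template <typename InType, typename OutType>
struct sm90_int8_config_m32_small_n {    // M in [1, 32], N < 8192
  using Cutlass3xGemm = PlaceholderGemm<InType, OutType>;
};

// The dispatch function would then name the case it wants directly, e.g.:
//   using Cutlass3xGemmM64 =
//       typename sm90_int8_config_m64<InType, OutType>::Cutlass3xGemm;

The tile shapes, cluster shapes, and schedules from the PR's specializations would live inside each named struct unchanged; only the selection mechanism differs.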