Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
}
}

std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}

std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
int num_warps = block_size / TargetGetWarpSize(target);
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);

ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (allow_wgmma) {
if (gemm_inst == GemmInst::kWGMMA) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";

constexpr int kGroup = 4; // Number of warps in a warp-group
Expand Down Expand Up @@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const {
}

Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
warp_size = 64;
}
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0) && CheckWGMMA();

auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

std::stringstream ss;
std::string op_name = "tl::gemm_ss";
Expand All @@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
} else if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
Expand All @@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);

if (TargetIsVolta(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
Expand All @@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
Expand Down Expand Up @@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0);
}
} else if (TargetIsHopper(T.target)) {
const int warp_size = 32;
bool maybe_wgmma =
(this->M >= 64) && (block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
Expand All @@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
auto ABLayout =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
Expand All @@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
auto ABLayout =
maybe_wgmma
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
Expand All @@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0) << "WGMMA only support B in shared.";
}
} else if (TargetIsCDNA(T.target)) {
const int warp_size = 64;
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target);

auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
Expand Down
9 changes: 6 additions & 3 deletions src/op/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ class Gemm : public Operator {
} policy;

private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;

std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;

bool CheckWGMMA() const;
Array<PrimExpr> call_args;
Expand Down
7 changes: 7 additions & 0 deletions src/target/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,12 @@ bool TargetHasStmatrix(Target target) {
return arch >= 90;
}

int TargetGetWarpSize(Target target) {
int res = 32;
if (TargetIsCDNA(target))
res = 64;
return res;
}

} // namespace tl
} // namespace tvm
1 change: 1 addition & 0 deletions src/target/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);
int TargetGetWarpSize(Target target);

} // namespace tl
} // namespace tvm
Expand Down