Skip to content

Commit d2afb51

Browse files
authored
[Refactor] Introduce GemmInst for different targets handling (#688)
* [Enhancement] Refactor GEMM operations for improved warp partitioning and target instruction handling - Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture. - Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic. - Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture. - Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase. * bug fix * test fix * lint fix
1 parent 73bf834 commit d2afb51

File tree

4 files changed

+45
-36
lines changed

4 files changed

+45
-36
lines changed

src/op/gemm.cc

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
5858
}
5959
}
6060

61-
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
62-
bool maybe_hopper_wgmma) const {
61+
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
62+
int warp_size = TargetGetWarpSize(target);
63+
int num_warps = block_size / warp_size;
64+
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
65+
(num_warps % 4 == 0) && CheckWGMMA();
66+
if (allow_wgmma) {
67+
return GemmInst::kWGMMA;
68+
} else if (TargetIsCDNA(target)) {
69+
return GemmInst::kMFMA;
70+
} else if (TargetIsCuda(target)) {
71+
return GemmInst::kMMA;
72+
} else {
73+
ICHECK(0) << "Unsupported target for gemm: " << target->str();
74+
}
75+
}
76+
77+
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
78+
GemmInst gemm_inst,
79+
Target target) const {
80+
int num_warps = block_size / TargetGetWarpSize(target);
6381
int m_warp = 1, n_warp = 1;
6482
constexpr int kMPerWarp = 16; // Rows processed by a single warp
6583
constexpr int kNPerWarp = 8; // Columns processed by a single warp
66-
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
67-
(this->M >= 64) && (num_warps % 4 == 0);
84+
6885
ICHECK(this->M % kMPerWarp == 0)
6986
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
7087
ICHECK(this->N % kNPerWarp == 0)
7188
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
72-
if (allow_wgmma) {
89+
if (gemm_inst == GemmInst::kWGMMA) {
7390
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
7491

7592
constexpr int kGroup = 4; // Number of warps in a warp-group
@@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const {
268285
}
269286

270287
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
271-
int warp_size = 32;
272-
if (TargetIsCDNA(T.target)) {
273-
warp_size = 64;
274-
}
275288
auto block_size = *as_const_int(T.thread_bounds->extent);
276-
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
277-
(block_size / warp_size % 4 == 0) && CheckWGMMA();
278-
279-
auto [warp_m, warp_n] =
280-
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
289+
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
290+
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
281291

282292
std::stringstream ss;
283293
std::string op_name = "tl::gemm_ss";
@@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
295305
// for cdna gemm, we need to specify kPack
296306
ss << ", " << kPack;
297307
} else if (TargetIsHopper(T.target)) {
298-
ss << ", " << (maybe_wgmma ? "true" : "false");
308+
ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
299309
}
300310
if (wg_wait != 0) {
301311
ss << ", " << wg_wait;
@@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
321331
ICHECK(C.scope() == "local.fragment");
322332
auto thread_range = T.thread_bounds;
323333
auto block_size = *as_const_int(thread_range->extent);
334+
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
335+
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
336+
324337
if (TargetIsVolta(T.target)) {
325-
const int warp_size = 32;
326-
auto [warp_m, warp_n] =
327-
ComputeWarpPartition(block_size / warp_size, T.target);
328338
auto fragment =
329339
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
330340
results.Set(C, fragment->BindThreadRange(thread_range));
@@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
347357
*as_const_int(B->shape[dim_B - 1]),
348358
false, trans_B ? 2 : 1));
349359
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
350-
const int warp_size = 32;
351-
auto [warp_m, warp_n] =
352-
ComputeWarpPartition(block_size / warp_size, T.target);
353360
auto fragment =
354361
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
355362
results.Set(C, fragment->BindThreadRange(thread_range));
@@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
383390
ICHECK(0);
384391
}
385392
} else if (TargetIsHopper(T.target)) {
386-
const int warp_size = 32;
387-
bool maybe_wgmma =
388-
(this->M >= 64) && (block_size / warp_size % 4 == 0) && CheckWGMMA();
389-
auto [warp_m, warp_n] =
390-
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
391393
auto fragment =
392-
maybe_wgmma
394+
gemm_inst == GemmInst::kWGMMA
393395
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
394396
C->dtype.bits())
395397
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
@@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
401403
const int64_t continuity =
402404
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
403405
auto ABLayout =
404-
maybe_wgmma
406+
gemm_inst == GemmInst::kWGMMA
405407
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
406408
A->dtype.bits(), trans_A ? 1 : 2)
407409
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
@@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
419421
const int64_t continuity =
420422
trans_B ? mat_continuous : mat_continuous / warp_n;
421423
auto ABLayout =
422-
maybe_wgmma
424+
gemm_inst == GemmInst::kWGMMA
423425
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
424426
B->dtype.bits(), trans_B ? 2 : 1)
425427
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
@@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
429431
ICHECK(0) << "WGMMA only support B in shared.";
430432
}
431433
} else if (TargetIsCDNA(T.target)) {
432-
const int warp_size = 64;
433-
auto [warp_m, warp_n] =
434-
ComputeWarpPartition(block_size / warp_size, T.target);
435-
436434
auto fragment =
437435
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
438436
results.Set(C, fragment->BindThreadRange(thread_range));

src/op/gemm.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ class Gemm : public Operator {
2727
} policy;
2828

2929
private:
30-
std::pair<int, int>
31-
ComputeWarpPartition(int num_warps, Target target,
32-
bool maybe_hopper_wgmma = true) const;
30+
// Target GEMM instruction
31+
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
32+
GemmInst GetGemmInst(int block_size, Target target) const;
33+
34+
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
35+
Target target) const;
3336

3437
bool CheckWGMMA() const;
3538
Array<PrimExpr> call_args;

src/target/utils.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,12 @@ bool TargetHasStmatrix(Target target) {
9797
return arch >= 90;
9898
}
9999

100+
int TargetGetWarpSize(Target target) {
101+
int res = 32;
102+
if (TargetIsCDNA(target))
103+
res = 64;
104+
return res;
105+
}
106+
100107
} // namespace tl
101108
} // namespace tvm

src/target/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ bool TargetIsCDNA(Target target);
2424
bool TargetHasAsyncCopy(Target target);
2525
bool TargetHasLdmatrix(Target target);
2626
bool TargetHasStmatrix(Target target);
27+
int TargetGetWarpSize(Target target);
2728

2829
} // namespace tl
2930
} // namespace tvm

0 commit comments

Comments
 (0)