Skip to content

Commit 3e90be7

Browse files
committed
tcgen05 support
1 parent 73fa0af commit 3e90be7

File tree

28 files changed

+1110
-176
lines changed

28 files changed

+1110
-176
lines changed

docs/compiler_internals/inject_fence_proxy.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
1717
### Timeline View
1818

1919
```
20-
generic initialize_descriptor → generic shared-store → async wgmma
20+
generic initialize_wgmma_descriptor → generic shared-store → async wgmma
2121
│ │ │
2222
└─ generic proxy ┴─ generic proxy ┴─ async proxy
2323
│ fence inserted here ↑
@@ -53,7 +53,7 @@ def kernel():
5353
with T.Kernel(1):
5454
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
5555
smem = T.decl_buffer((128,), "float16", scope="shared")
56-
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
56+
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
5757
smem[0] = T.float16(0)
5858
T.ptx_wgmma_ss(
5959
"float16",
@@ -83,7 +83,7 @@ def kernel():
8383
with T.Kernel(1):
8484
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
8585
smem = T.decl_buffer((128,), "float16", scope="shared")
86-
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
86+
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
8787
smem[0] = T.float16(0)
8888
T.fence_proxy_async()
8989
T.ptx_wgmma_ss(

src/layout/layout.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
535535
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
536536
element_size, k_inner);
537537
})
538+
.def("tl.make_tcgen05mma_swizzled_layout",
539+
[](int stride, int mat_continuous, int continuity, int element_size,
540+
bool k_inner) {
541+
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
542+
element_size, k_inner);
543+
})
538544
.def("tl.make_full_bank_swizzled_layout",
539545
[](int stride, int continuous, int element_size) {
540546
return makeFullBankSwizzleLayout(stride, continuous, element_size);

src/op/builtin.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
154154
.set_attr<TCallEffectKind>("TCallEffectKind",
155155
Integer(CallEffectKind::kOpaque));
156156

157+
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
158+
.set_num_inputs(13)
159+
.set_attr<TCallEffectKind>("TCallEffectKind",
160+
Integer(CallEffectKind::kOpaque));
161+
157162
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
158163
.set_num_inputs(2)
159164
.set_attr<TCallEffectKind>("TCallEffectKind",
@@ -270,11 +275,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
270275
.set_attr<TCallEffectKind>("TCallEffectKind",
271276
Integer(CallEffectKind::kPure));
272277

273-
TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
278+
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
274279
.set_num_inputs(5)
275280
.set_attr<TCallEffectKind>("TCallEffectKind",
276281
Integer(CallEffectKind::kOpaque));
277282

283+
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
284+
.set_num_inputs(7)
285+
.set_attr<TCallEffectKind>("TCallEffectKind",
286+
Integer(CallEffectKind::kOpaque));
287+
278288
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
279289
.set_num_inputs(2)
280290
.set_attr<TCallEffectKind>("TCallEffectKind",

src/op/builtin.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ TVM_DLL const Op &ptx_wgmma_ss();
246246
*/
247247
TVM_DLL const Op &ptx_wgmma_rs();
248248

249+
/*!
250+
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
251+
*/
252+
TVM_DLL const Op &ptx_tcgen05_mma_ss();
253+
249254
/*!
250255
* \brief tvm intrinsics for initializing tensor memory
251256
*
@@ -467,7 +472,13 @@ TVM_DLL const Op &tl_shuffle_elect();
467472
* This op is used to represent a descriptor initialization operation in
468473
* tilelang.
469474
*/
470-
TVM_DLL const Op &initialize_descriptor();
475+
TVM_DLL const Op &initialize_wgmma_descriptor();
476+
477+
/*!
478+
* \brief tilelang intrinsic for initializing a descriptor buffer for
479+
* tcgen05 mma.
480+
*/
481+
TVM_DLL const Op &initialize_tcgen05_descriptor();
471482

472483
/*!
473484
* \brief tilelang intrinsic for setting the start address of a descriptor

src/op/gemm.cc

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,79 +12,13 @@
1212
#include <tvm/tir/transform.h>
1313

1414
#include "../target/utils.h"
15+
#include "tcgen5_meta.h"
1516

1617
namespace tvm {
1718
namespace tl {
1819

1920
using namespace tir;
2021

21-
struct TCGEN5MMAMeta {
22-
int atom_m, atom_n, atom_k;
23-
};
24-
25-
// Return {is_success, meta}
26-
static inline std::pair<bool, TCGEN5MMAMeta>
27-
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
28-
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
29-
#define FAIL \
30-
return { \
31-
false, TCGEN5MMAMeta { 0, 0, 0 } \
32-
}
33-
#define SUCCESS(atom_m, atom_n, atom_k) \
34-
return { \
35-
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
36-
}
37-
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
38-
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
39-
(c_dtype.is_float() && c_dtype.bits() == 32)) {
40-
if (K % 16 != 0)
41-
FAIL;
42-
if (M % 128 == 0) {
43-
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
44-
if (N % atom_n == 0)
45-
SUCCESS(128, atom_n, 16);
46-
FAIL;
47-
} else if (M % 64 == 0) {
48-
for (int atom_n : ws_valid_atom_ns)
49-
if (N % atom_n == 0)
50-
SUCCESS(64, atom_n, 16);
51-
FAIL;
52-
} else if (M % 32 == 0) {
53-
for (int atom_n : ws_valid_atom_ns)
54-
if (N % atom_n == 0)
55-
SUCCESS(32, atom_n, 16);
56-
FAIL;
57-
} else {
58-
FAIL;
59-
}
60-
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
61-
(c_dtype.is_float() && c_dtype.bits() == 32)) {
62-
if (K % 32 != 0)
63-
FAIL;
64-
if (M % 128 == 0) {
65-
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
66-
if (N % atom_n == 0)
67-
SUCCESS(128, atom_n, 32);
68-
FAIL;
69-
} else if (M % 64 == 0) {
70-
for (int atom_n : ws_valid_atom_ns)
71-
if (N % atom_n == 0)
72-
SUCCESS(64, atom_n, 32);
73-
FAIL;
74-
} else if (M % 32 == 0) {
75-
for (int atom_n : ws_valid_atom_ns)
76-
if (N % atom_n == 0)
77-
SUCCESS(32, atom_n, 32);
78-
FAIL;
79-
} else {
80-
FAIL;
81-
}
82-
}
83-
FAIL;
84-
#undef FAIL
85-
#undef SUCCESS
86-
}
87-
8822
/**
8923
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
9024
* map.
@@ -199,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
199133
TargetIsSm100(target)) {
200134
return GemmInst::kMMA;
201135
} else {
202-
ICHECK(0) << "Unsupported target for gemm: " << target->str();
136+
ICHECK(0) << "Unsupported target for gemm: " << target;
203137
}
204138
}
205139

src/op/gemm_py.cc

Lines changed: 41 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,80 +13,13 @@
1313

1414
#include "../target/utils.h"
1515
#include "tvm/ffi/string.h"
16-
#include <vector>
16+
#include "tcgen5_meta.h"
1717

1818
namespace tvm {
1919
namespace tl {
2020

2121
using namespace tir;
2222

23-
struct TCGEN5MMAMeta {
24-
int atom_m, atom_n, atom_k;
25-
};
26-
27-
// Return {is_success, meta}
28-
static inline std::pair<bool, TCGEN5MMAMeta>
29-
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
30-
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
31-
#define FAIL \
32-
return { \
33-
false, TCGEN5MMAMeta { 0, 0, 0 } \
34-
}
35-
#define SUCCESS(atom_m, atom_n, atom_k) \
36-
return { \
37-
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
38-
}
39-
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
40-
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
41-
(c_dtype.is_float() && c_dtype.bits() == 32)) {
42-
if (K % 16 != 0)
43-
FAIL;
44-
if (M % 128 == 0) {
45-
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
46-
if (N % atom_n == 0)
47-
SUCCESS(128, atom_n, 16);
48-
FAIL;
49-
} else if (M % 64 == 0) {
50-
for (int atom_n : ws_valid_atom_ns)
51-
if (N % atom_n == 0)
52-
SUCCESS(64, atom_n, 16);
53-
FAIL;
54-
} else if (M % 32 == 0) {
55-
for (int atom_n : ws_valid_atom_ns)
56-
if (N % atom_n == 0)
57-
SUCCESS(32, atom_n, 16);
58-
FAIL;
59-
} else {
60-
FAIL;
61-
}
62-
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
63-
(c_dtype.is_float() && c_dtype.bits() == 32)) {
64-
if (K % 32 != 0)
65-
FAIL;
66-
if (M % 128 == 0) {
67-
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
68-
if (N % atom_n == 0)
69-
SUCCESS(128, atom_n, 32);
70-
FAIL;
71-
} else if (M % 64 == 0) {
72-
for (int atom_n : ws_valid_atom_ns)
73-
if (N % atom_n == 0)
74-
SUCCESS(64, atom_n, 32);
75-
FAIL;
76-
} else if (M % 32 == 0) {
77-
for (int atom_n : ws_valid_atom_ns)
78-
if (N % atom_n == 0)
79-
SUCCESS(32, atom_n, 32);
80-
FAIL;
81-
} else {
82-
FAIL;
83-
}
84-
}
85-
FAIL;
86-
#undef FAIL
87-
#undef SUCCESS
88-
}
89-
9023
/**
9124
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
9225
* map.
@@ -144,6 +77,20 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
14477
if (args.size() > 15) {
14578
node->wg_wait = args[15].as<IntImm>().value()->value;
14679
}
80+
if (args.size() > 16) {
81+
node->mbarptr = args[16];
82+
} else {
83+
node->mbarptr = IntImm(DataType::UInt(32), 0);
84+
}
85+
if (args.size() > 18) {
86+
node->C_coords = Array<PrimExpr>({args[17], args[18]});
87+
} else if (args.size() > 17) {
88+
node->C_coords =
89+
Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
90+
} else {
91+
node->C_coords = Array<PrimExpr>(
92+
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
93+
}
14794
data_ = std::move(node);
14895
}
14996

@@ -378,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK({
378325
});
379326
});
380327

328+
// Static-init registration of two global functions, "tl.get_tcgen5_mma_meta"
// and "tl.get_tcgen5_instr_desc", which wrap GetTCGEN5MMAMeta and
// GetTCGEN5InstrDesc respectively.
TVM_FFI_STATIC_INIT_BLOCK({
329+
namespace refl = tvm::ffi::reflection;
330+
// "tl.get_tcgen5_mma_meta": forwards (M, N, K, ab_dtype, c_dtype) to
// GetTCGEN5MMAMeta. Returns [atom_m, atom_n, atom_k] when the lookup
// succeeds and an empty array when it does not, so callers must check the
// array length rather than expect an error.
refl::GlobalDef().def(
331+
"tl.get_tcgen5_mma_meta",
332+
[](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
333+
auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
334+
Array<Integer> result;
335+
if (success) {
336+
result.push_back(Integer(meta.atom_m));
337+
result.push_back(Integer(meta.atom_n));
338+
result.push_back(Integer(meta.atom_k));
339+
}
340+
return result;
341+
});
342+
// "tl.get_tcgen5_instr_desc": forwards all nine arguments unchanged to
// GetTCGEN5InstrDesc and returns the resulting 32-bit descriptor.
refl::GlobalDef().def(
343+
"tl.get_tcgen5_instr_desc",
344+
[](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
345+
DataType c_dtype, bool a_is_k_major, bool b_is_k_major,
346+
int scale_in_a, int scale_in_b) {
347+
uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
348+
c_dtype, a_is_k_major, b_is_k_major,
349+
scale_in_a, scale_in_b);
350+
// Widen the unsigned 32-bit value to int64_t before wrapping it in Integer.
return Integer(static_cast<int64_t>(desc));
351+
});
352+
});
353+
381354
} // namespace tl
382355
} // namespace tvm

src/op/gemm_py.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class GemmPyNode : public TileOperatorNode {
2929
int stride_A, stride_B;
3030
int offset_A, offset_B;
3131
PrimExpr clear_accum = const_false();
32+
PrimExpr mbarptr;
33+
Array<PrimExpr> C_coords;
3234
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
3335
// only will be enabled under cdna mfma instructions
3436
int kPack = 1;
@@ -57,6 +59,8 @@ class GemmPyNode : public TileOperatorNode {
5759
.def_ro("offset_A", &GemmPyNode::offset_A)
5860
.def_ro("offset_B", &GemmPyNode::offset_B)
5961
.def_ro("clear_accum", &GemmPyNode::clear_accum)
62+
.def_ro("mbarptr", &GemmPyNode::mbarptr)
63+
.def_ro("C_coords", &GemmPyNode::C_coords)
6064
.def_ro("kPack", &GemmPyNode::kPack)
6165
.def_ro("wg_wait", &GemmPyNode::wg_wait)
6266
.def_ro("policy", &GemmPyNode::policy);
@@ -73,6 +77,8 @@ class GemmPyNode : public TileOperatorNode {
7377
// offset_A must be compared with the other node's offset_A; offset_B has its
// own comparison on the following line.
equal(offset_A, other->offset_A) &&
7478
equal(offset_B, other->offset_B) &&
7579
equal(clear_accum, other->clear_accum) &&
80+
equal(mbarptr, other->mbarptr) &&
81+
equal(C_coords, other->C_coords) &&
7682
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
7783
equal(policy, other->policy);
7884
}
@@ -94,6 +100,8 @@ class GemmPyNode : public TileOperatorNode {
94100
hash_reduce(offset_A);
95101
hash_reduce(offset_B);
96102
hash_reduce(clear_accum);
103+
hash_reduce(mbarptr);
104+
hash_reduce(C_coords);
97105
hash_reduce(kPack);
98106
hash_reduce(wg_wait);
99107
hash_reduce(policy);

0 commit comments

Comments
 (0)