Commit aef0a6b
[Language] Expose T.warpgroup_fence_operand for nvcc code motion (#986)
* remove debug print
* pipeline fix
* use the correct buffer access scope
* rs support
* wrap warpgroup_fence_operand
* fix
* fp8 dtype ptx enhance
* mma fix
* TCGEN05 Interface
* tcgen05 support
* rebase
* update
* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.
* lint fix
* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in the TIR IR module.
* wgmma fix

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
1 parent c85bb3a commit aef0a6b

43 files changed: +2674 −804 lines

3rdparty/tvm

Submodule tvm updated from 0f1ebab to 1815c3e

docs/compiler_internals/inject_fence_proxy.md

Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
 ### Timeline View
 
 ```
-generic initialize_descriptor → generic shared-store → async wgmma
+generic initialize_wgmma_descriptor → generic shared-store → async wgmma
         │                             │                      │
         └─ generic proxy              ┴─ generic proxy       ┴─ async proxy
                                       │ fence inserted here ↑
@@ -53,7 +53,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.ptx_wgmma_ss(
             "float16",
@@ -83,7 +83,7 @@ def kernel():
     with T.Kernel(1):
         desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
         smem = T.decl_buffer((128,), "float16", scope="shared")
-        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
         smem[0] = T.float16(0)
         T.fence_proxy_async()
         T.ptx_wgmma_ss(
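For readers who want to reproduce the before/after above by hand, here is a minimal sketch of driving the pass directly. The entry point `tilelang.transform.InjectFenceProxy` is an assumption inferred from this doc's title; check the installed package for the exact name.

```python
# Sketch: build the doc's kernel and run the fence-injection pass over it.
# Assumes tilelang.transform.InjectFenceProxy exists (name inferred from
# docs/compiler_internals/inject_fence_proxy.md, not verified here).
import tvm
import tilelang.language as T
import tilelang.transform


@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        # ... async wgmma consuming desc/smem goes here ...


mod = tvm.IRModule({"main": kernel})
mod = tilelang.transform.InjectFenceProxy()(mod)  # fence lands before the async op
print(mod)
```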

src/layout/layout.cc

Lines changed: 6 additions & 0 deletions

@@ -546,6 +546,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
             return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
                                           element_size, k_inner);
           })
+      .def("tl.make_tcgen05mma_swizzled_layout",
+           [](int stride, int mat_continuous, int continuity, int element_size,
+              bool k_inner) {
+             return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
+                                          element_size, k_inner);
+           })
       .def("tl.make_full_bank_swizzled_layout",
            [](int stride, int continuous, int element_size) {
              return makeFullBankSwizzleLayout(stride, continuous, element_size);
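Like its Hopper counterpart, the new layout helper is reachable from Python through TVM's global function registry. A hedged sketch (argument order per the C++ lambda above; the concrete tile values are illustrative only):

```python
# Sketch: calling the newly registered SM100 swizzled-layout helper via the
# TVM FFI registry. The example arguments (64x64 tile, 16-bit elements,
# k-inner) are illustrative, not prescribed by the commit.
import tvm

make_layout = tvm.get_global_func("tl.make_tcgen05mma_swizzled_layout")
layout = make_layout(64, 64, 64, 16, True)
#                    stride, mat_continuous, continuity, element_size, k_inner
print(layout)
```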

src/op/builtin.cc

Lines changed: 26 additions & 1 deletion

@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
+    .set_num_inputs(14)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
+    .set_num_inputs(13)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
     .set_num_inputs(2)
     .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -219,6 +229,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
+    .set_num_inputs(4)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(get_lane_idx)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -286,11 +301,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
+TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
     .set_num_inputs(5)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
+    .set_num_inputs(7)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
     .set_num_inputs(2)
     .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -311,5 +331,10 @@ TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 } // namespace tl
 } // namespace tvm
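These registrations surface the new ops as opaque TIR intrinsics. As a hedged sketch, a call to the one-input `tcgen05_mma_arrive` (num_inputs=1 above) can be constructed directly at TIR level; the `tl.` op prefix is assumed from the usual TIR_DEFINE_TL_BUILTIN registration convention, and the mbarrier handle is an illustrative placeholder:

```python
# Sketch: building a TIR call node for the newly registered one-input
# intrinsic. Op name "tl.tcgen05_mma_arrive" assumes the tl.* prefix;
# mbar_ptr stands in for a shared-memory mbarrier handle.
from tvm import tir

mbar_ptr = tir.Var("mbar_ptr", "handle")
arrive = tir.call_intrin("handle", "tl.tcgen05_mma_arrive", mbar_ptr)
```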

src/op/builtin.h

Lines changed: 38 additions & 6 deletions

@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
 /*!
  * \brief tvm intrinsics for ptx tensor core wgmma instructions.
  *
- * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
- * b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
- * A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
- * scale_out, bool scale_in_a, bool scale_in_b);
+ * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
+ * bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
+ * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
+ * bool scale_in_a, bool scale_in_b);
  */
 TVM_DLL const Op &ptx_wgmma_rs();
 
+/*!
+ * \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
+ */
+TVM_DLL const Op &ptx_tcgen05_mma_ss();
+
+/*!
+ * \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
+ */
+TVM_DLL const Op &ptx_tcgen05_mma_ts();
+
 /*!
  * \brief tvm intrinsics for initializing tensor memory
  *
@@ -361,6 +371,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
  */
 TVM_DLL const Op &warpgroup_wait();
 
+/*!
+ * \brief Fence accumulator operand registers for upcoming WGMMA operations
+ *
+ * warpgroup_fence_operand(dtype, ptr, offset, num_regs)
+ *
+ */
+TVM_DLL const Op &warpgroup_fence_operand();
+
 /*!
  * \brief Return the canonical lane index for the calling thread.
  *
@@ -494,7 +512,21 @@ TVM_DLL const Op &tl_shuffle_elect();
  * This op is used to represent a descriptor initialization operation in
  * tilelang.
  */
-TVM_DLL const Op &initialize_descriptor();
+TVM_DLL const Op &initialize_wgmma_descriptor();
+
+/*!
+ * \brief tilelang intrinsic for initializing a descriptor buffer for
+ * tcgen05 mma.
+ */
+TVM_DLL const Op &initialize_tcgen05_descriptor();
+
+/*!
+ * \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
+ *
+ * This op wraps the device-side arrive used to signal completion of MMA work
+ * to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
+ */
+TVM_DLL const Op &tcgen05_mma_arrive();
 
 /*!
  * \brief tilelang intrinsic for setting the start address of a descriptor
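The header comment above spells out the operand order for the new fence. A minimal sketch of the language-level call this commit exposes, assuming that order (dtype, ptr, offset, num_regs); the fragment shape and register count are illustrative:

```python
# Sketch: fencing a float32 accumulator fragment around a WGMMA so nvcc
# cannot move unrelated code across the region that owns those registers.
# Operand order follows the header comment; sizes are illustrative only.
import tilelang.language as T


@T.prim_func
def kernel():
    with T.Kernel(1, threads=128):
        C_local = T.decl_buffer((32,), "float32", scope="local.fragment")
        T.warpgroup_fence_operand("float32", C_local.data, 0, 32)
        # ... T.ptx_wgmma_ss(...) accumulating into C_local ...
        T.warpgroup_fence_operand("float32", C_local.data, 0, 32)
```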

src/op/gemm.cc

Lines changed: 6 additions & 66 deletions

@@ -12,77 +12,13 @@
 #include <tvm/tir/transform.h>
 
 #include "../target/utils.h"
+#include "tcgen5_meta.h"
 
 namespace tvm {
 namespace tl {
 
 using namespace tir;
 
-struct TCGEN5MMAMeta {
-  int atom_m, atom_n, atom_k;
-};
-
-// Return {is_success, meta}
-static inline std::pair<bool, TCGEN5MMAMeta>
-GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
-// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
-#define FAIL                                                                   \
-  return { false, TCGEN5MMAMeta{0, 0, 0} }
-#define SUCCESS(atom_m, atom_n, atom_k)                                        \
-  return {                                                                     \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k }                             \
-  }
-  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
-  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
-      (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 16 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 16);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 16);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 16);
-      FAIL;
-    } else {
-      FAIL;
-    }
-  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
-             (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 32 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 32);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 32);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 32);
-      FAIL;
-    } else {
-      FAIL;
-    }
-  }
-  FAIL;
-#undef FAIL
-#undef SUCCESS
-}
-
 /**
  * @brief Construct a Gemm operator from serialized TL arguments and a buffer
  * map.
@@ -186,6 +122,8 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const {
 GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
   bool allow_tcgen5mma = AllowTCGEN5MMA(target);
   bool allow_wgmma = AllowWGMMA(block_size, target);
+  LOG(INFO) << "allow_tcgen5mma: " << allow_tcgen5mma
+            << ", allow_wgmma: " << allow_wgmma;
   if (allow_tcgen5mma) {
     return GemmInst::kTCGEN5MMA;
   } else if (allow_wgmma) {
@@ -195,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
   } else if (TargetIsCuda(target)) {
     return GemmInst::kMMA;
   } else {
-    ICHECK(0) << "Unsupported target for gemm: " << target->str();
+    ICHECK(0) << "Unsupported target for gemm: " << target;
   }
 }
 
@@ -578,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
 
   if (A.scope() == "local.fragment") {
     ICHECK(B.scope() != "local.fragment");
+    ICHECK(!trans_A)
+        << "gemm_rs requires the A operand to be in non-transposed layout.";
     op_name = "tl::gemm_rs";
   } else if (B.scope() == "local.fragment") {
     op_name = "tl::gemm_sr";
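The commit moves `GetTCGEN5MMAMeta` out of gemm.cc into the shared tcgen5_meta.h header. For readability, here is a Python restatement of its atom-selection logic, transcribed from the deleted C++ above; a reference sketch, not an exported API:

```python
# Python restatement of the TCGEN5MMA shape/dtype gate removed above (now in
# tcgen5_meta.h). Returns (atom_m, atom_n, atom_k) on success, None otherwise.
def get_tcgen5_mma_meta(M, N, K, ab_dtype, c_dtype):
    if c_dtype != "float32":          # only f32 accumulation is supported
        return None
    if ab_dtype in ("float16", "bfloat16"):
        atom_k = 16
    elif ab_dtype in ("float8_e4m3fn", "float8_e5m2"):
        atom_k = 32
    else:
        return None
    if K % atom_k != 0:
        return None
    if M % 128 == 0:
        # Prefer the widest atom_n dividing N, from 256 down in steps of 16.
        for atom_n in range(256, 15, -16):
            if N % atom_n == 0:
                return (128, atom_n, atom_k)
    elif M % 64 == 0 or M % 32 == 0:
        atom_m = 64 if M % 64 == 0 else 32
        for atom_n in (256, 128, 64):  # ws_valid_atom_ns in the C++
            if N % atom_n == 0:
                return (atom_m, atom_n, atom_k)
    return None
```

For example, `get_tcgen5_mma_meta(128, 256, 64, "float16", "float32")` yields `(128, 256, 16)`, matching the `SUCCESS(128, atom_n, 16)` branch of the original.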

src/op/gemm_py.cc

Lines changed: 67 additions & 6 deletions

@@ -13,6 +13,8 @@
 
 #include "../support/ffi_aliases.h"
 #include "../target/utils.h"
+#include "tcgen5_meta.h"
+#include "tvm/ffi/string.h"
 
 namespace tvm {
 namespace tl {
@@ -49,7 +51,6 @@ using namespace tir;
  */
 GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
-
   node->Aptr = args[0];
   node->Bptr = args[1];
   node->Cptr = args[2];
@@ -76,6 +77,19 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
   if (args.size() > 15) {
     node->wg_wait = args[15].as<IntImm>().value()->value;
   }
+  if (args.size() > 16) {
+    node->mbarptr = args[16];
+  } else {
+    node->mbarptr = IntImm(DataType::UInt(32), 0);
+  }
+  if (args.size() > 18) {
+    node->C_coords = Array<PrimExpr>({args[17], args[18]});
+  } else if (args.size() > 17) {
+    node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
+  } else {
+    node->C_coords = Array<PrimExpr>(
+        {IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
+  }
   data_ = std::move(node);
 }
 
@@ -92,16 +106,37 @@ TileOperator GemmPyNode::Clone() const {
   return GemmPy(op);
 }
 
-GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
+bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
+  return TargetIsSm100(target) &&
+         ((A.scope() == "shared.dyn" || A.scope() == "shared" ||
+           A.scope() == "shared.tmem") &&
+          (B.scope() == "shared.dyn" || B.scope() == "shared") &&
+          C.scope() == "shared.tmem") &&
+         GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
+}
+
+bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
+
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
-  if (allow_wgmma) {
+  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
+         TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
+         CheckWGMMA();
+}
+
+GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
+  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
+  bool allow_wgmma = AllowWGMMA(block_size, target);
+  if (allow_tcgen5mma) {
+    return GemmInst::kTCGEN5MMA;
+  } else if (allow_wgmma) {
     return GemmInst::kWGMMA;
   } else if (TargetIsCDNA(target)) {
     return GemmInst::kMFMA;
-  } else if (TargetIsCuda(target)) {
+  } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
+             TargetIsTuring(target) || TargetIsHopper(target) ||
+             TargetIsSm100(target)) {
     return GemmInst::kMMA;
   } else {
     ICHECK(0) << "Unsupported target for gemm: " << target->str();
@@ -290,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       });
 }
 
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def(
+      "tl.get_tcgen5_mma_meta",
+      [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
+        auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
+        Array<Integer> result;
+        if (success) {
+          result.push_back(Integer(meta.atom_m));
+          result.push_back(Integer(meta.atom_n));
+          result.push_back(Integer(meta.atom_k));
+        }
+        return result;
+      });
+  refl::GlobalDef().def(
+      "tl.get_tcgen5_instr_desc",
+      [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
+         DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
+         int scale_in_b) {
+        uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
+                                           c_dtype, a_is_k_major, b_is_k_major,
+                                           scale_in_a, scale_in_b);
+        return Integer(static_cast<int64_t>(desc));
+      });
+}
+
 } // namespace tl
 } // namespace tvm
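The new reflection hooks can be exercised from Python through the global registry. A hedged sketch: names match the registrations above, and the return convention (empty array on failure, `[atom_m, atom_n, atom_k]` on success) follows the C++ lambda; the shape and dtype arguments are illustrative.

```python
# Sketch: querying the TCGEN5MMA atom metadata exposed by this commit.
# Assumes TVM's FFI converts the "float16"/"float32" strings to DataType.
import tvm

get_meta = tvm.get_global_func("tl.get_tcgen5_mma_meta")
meta = get_meta(128, 256, 64, "float16", "float32")
if len(meta) == 3:
    atom_m, atom_n, atom_k = (int(x) for x in meta)
    print(f"TCGEN5MMA atoms: {atom_m}x{atom_n}x{atom_k}")
else:
    print("shape/dtype combination not supported by TCGEN5MMA")
```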
