Skip to content

Commit ae9b706

Browse files
authored
[Feature] Add ptx_cp_async_barrier_noinc intrinsic and related functionality (#809)
- Introduced a new intrinsic `ptx_cp_async_barrier_noinc` for handling the `cp.async.mbarrier.arrive.noinc` operation in TileLang. - Updated the CUDA code generation to support the new barrier operation. - Added a corresponding function in the TileLang Python API for ease of use. - Enhanced the barrier handling in CUDA templates to include the new no-increment operation, improving synchronization capabilities in parallel execution contexts.
1 parent 5e52952 commit ae9b706

File tree

8 files changed

+143
-89
lines changed

8 files changed

+143
-89
lines changed

src/op/builtin.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
9090
.set_attr<TCallEffectKind>("TCallEffectKind",
9191
Integer(CallEffectKind::kOpaque));
9292

93+
// Register the TileLang builtin op ptx_cp_async_barrier_noinc: it takes one
// input and its call-effect kind is CallEffectKind::kOpaque.
TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc)
94+
.set_num_inputs(1)
95+
.set_attr<TCallEffectKind>("TCallEffectKind",
96+
Integer(CallEffectKind::kOpaque));
97+
9398
TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
9499
.set_num_inputs(0)
95100
.set_attr<TCallEffectKind>("TCallEffectKind",

src/op/builtin.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ TVM_DLL const Op &ptx_ldmatrix();
177177
*/
178178
TVM_DLL const Op &ptx_stmatrix();
179179

180+
/*!
181+
* \brief TVM intrinsic for the PTX async-copy barrier arrive that uses
182+
* cp.async.mbarrier.arrive.noinc
183+
*
184+
* This op represents a PTX async-copy barrier operation in TileLang.
185+
*/
186+
TVM_DLL const Op &ptx_cp_async_barrier_noinc();
187+
180188
/*!
181189
* \brief Pack two b16 value into a b32 value
182190
*

src/target/codegen_cuda.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
10661066
}
10671067
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
10681068
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
1069+
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
1070+
print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc");
10691071
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
10701072
ICHECK_EQ(op->args.size(), 2);
10711073
this->PrintIndent();

src/tl_templates/cuda/barrier.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,22 @@ TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) {
113113
: "r"(smem_int_mbar));
114114
}
115115

116+
template <typename BarrierType = uint64_t>
117+
TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) {
118+
uint32_t smem_int_mbar;
119+
if constexpr (std::is_pointer_v<BarrierType>) {
120+
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
121+
} else {
122+
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
123+
}
124+
asm volatile("{\n\t"
125+
"cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
126+
"}"
127+
:
128+
: "r"(smem_int_mbar));
129+
cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
130+
}
131+
116132
// Emits the PTX instruction `fence.proxy.async.shared::cta`.
TL_DEVICE void fence_proxy_async() {
117133
asm volatile("fence.proxy.async.shared::cta;" : :);
118134
}

src/transform/annotate_warp_group_reg_alloc.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,11 @@
22
* \file annotate_warp_group_reg_alloc.cc
33
* \brief Annotate warp group reg alloc for warp specialization
44
*/
5-
#include <tvm/tir/op.h>
6-
#include <tvm/tir/stmt_functor.h>
7-
#include <tvm/tir/transform.h>
85

6+
#include "warp_specialized_rewriter.h"
97
#include <unordered_set>
10-
#include <utility>
118
#include <vector>
129

13-
#include "../op/builtin.h"
14-
#include "tir/transforms/ir_utils.h"
15-
1610
namespace tvm {
1711
namespace tl {
1812

@@ -57,6 +51,11 @@ class SetMaxNRegCollector : public StmtExprVisitor {
5751
class SetMaxNRegInjector : public StmtExprMutator {
5852
public:
5953
static PrimFunc Inject(PrimFunc f) {
54+
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
55+
if (warp_specialized) {
56+
// Should handle set_max_nreg when using hand-written warp specialized
57+
return f;
58+
}
6059
auto T = SetMaxNRegInjector();
6160
T.nreg_ = SetMaxNRegCollector::Collect(f);
6261
f.CopyOnWrite()->body = T(f->body);

src/transform/warp_specialized_rewriter.cc

Lines changed: 1 addition & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,7 @@
33
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
44
*/
55

6-
#include "arith/ir_visitor_with_analyzer.h"
7-
#include "tir/analysis/var_use_def_analysis.h"
8-
#include <tvm/ffi/reflection/registry.h>
9-
#include <tvm/tir/analysis.h>
10-
#include <tvm/tir/builtin.h>
11-
#include <tvm/tir/op.h>
12-
#include <tvm/tir/stmt_functor.h>
13-
#include <tvm/tir/transform.h>
14-
15-
#include <utility>
16-
17-
#include "../op/builtin.h"
18-
#include "./common/collector.h"
19-
#include "runtime/thread_storage_scope.h"
20-
#include "tir/transforms/ir_utils.h"
6+
#include "warp_specialized_rewriter.h"
217

228
namespace tvm {
239
namespace tl {
@@ -1284,73 +1270,6 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12841270
bool disable_shuffle_elect_ = false;
12851271
};
12861272

1287-
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
1288-
public:
1289-
// return true means this aws will be disabled
1290-
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
1291-
WarpSpecializedDetector detector;
1292-
detector.VisitStmt(stmt);
1293-
if (detector.has_warp_specialization_) {
1294-
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
1295-
"specialization is manually enabled";
1296-
return true;
1297-
}
1298-
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
1299-
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
1300-
"and mbarrier are both present";
1301-
return true;
1302-
}
1303-
return false;
1304-
}
1305-
1306-
WarpSpecializedDetector() {
1307-
has_tma_op_ = false;
1308-
has_mbarrier_op_ = false;
1309-
has_warp_specialization_ = false;
1310-
}
1311-
1312-
private:
1313-
void VisitStmt_(const EvaluateNode *op) final {
1314-
if (const CallNode *call = op->value.as<CallNode>()) {
1315-
if (call->op.same_as(create_list_of_mbarrier()) ||
1316-
call->op.same_as(mbarrier_wait_parity()) ||
1317-
call->op.same_as(builtin::ptx_arrive_barrier()) ||
1318-
call->op.same_as(builtin::ptx_cp_async_barrier())) {
1319-
has_mbarrier_op_ = true;
1320-
}
1321-
}
1322-
IRVisitorWithAnalyzer::VisitStmt_(op);
1323-
}
1324-
1325-
void VisitExpr_(const CallNode *op) final {
1326-
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
1327-
op->op.same_as(set_max_nreg())) {
1328-
has_tma_op_ = true;
1329-
}
1330-
IRVisitorWithAnalyzer::VisitExpr_(op);
1331-
}
1332-
1333-
void VisitStmt_(const AttrStmtNode *op) final {
1334-
if (op->attr_key == "warp_specialize" &&
1335-
op->value.as<IntImmNode>()->value == 1) {
1336-
has_warp_specialization_ = true;
1337-
}
1338-
if (op->attr_key == tir::attr::thread_extent) {
1339-
IterVar iv = Downcast<IterVar>(op->node);
1340-
if (iv->thread_tag == "threadIdx.x") {
1341-
ICHECK(iv->dom->extent.as<IntImmNode>());
1342-
thread_var_ = iv;
1343-
}
1344-
}
1345-
IRVisitorWithAnalyzer::VisitStmt_(op);
1346-
}
1347-
1348-
bool has_tma_op_{false};
1349-
IterVar thread_var_;
1350-
bool has_mbarrier_op_{false};
1351-
bool has_warp_specialization_{false};
1352-
};
1353-
13541273
using namespace tir::transform;
13551274

13561275
tvm::transform::Pass WarpSpecialized() {
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*!
2+
* \file warp_specialized_rewriter.h
3+
* \brief tools for warp-specialized-related analysis and transformation
4+
*/
5+
6+
#pragma once
7+
8+
#include "arith/ir_visitor_with_analyzer.h"
9+
#include "tir/analysis/var_use_def_analysis.h"
10+
#include <tvm/ffi/reflection/registry.h>
11+
#include <tvm/tir/analysis.h>
12+
#include <tvm/tir/builtin.h>
13+
#include <tvm/tir/op.h>
14+
#include <tvm/tir/stmt_functor.h>
15+
#include <tvm/tir/transform.h>
16+
17+
#include <utility>
18+
19+
#include "../op/builtin.h"
20+
#include "./common/collector.h"
21+
#include "runtime/thread_storage_scope.h"
22+
#include "tir/transforms/ir_utils.h"
23+
24+
namespace tvm {
25+
namespace tl {
26+
27+
using namespace tir;
28+
using namespace runtime;
29+
using arith::IRVisitorWithAnalyzer;
30+
31+
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
32+
public:
33+
// return true means this aws will be disabled
34+
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
35+
WarpSpecializedDetector detector;
36+
detector.VisitStmt(stmt);
37+
if (detector.has_warp_specialization_) {
38+
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
39+
"specialization is manually enabled";
40+
return true;
41+
}
42+
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
43+
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
44+
"and mbarrier are both present";
45+
return true;
46+
}
47+
return false;
48+
}
49+
50+
WarpSpecializedDetector() {
51+
has_tma_op_ = false;
52+
has_mbarrier_op_ = false;
53+
has_warp_specialization_ = false;
54+
}
55+
56+
private:
57+
void VisitStmt_(const EvaluateNode *op) final {
58+
if (const CallNode *call = op->value.as<CallNode>()) {
59+
if (call->op.same_as(create_list_of_mbarrier()) ||
60+
call->op.same_as(mbarrier_wait_parity()) ||
61+
call->op.same_as(builtin::ptx_arrive_barrier()) ||
62+
call->op.same_as(builtin::ptx_cp_async_barrier())) {
63+
has_mbarrier_op_ = true;
64+
}
65+
}
66+
IRVisitorWithAnalyzer::VisitStmt_(op);
67+
}
68+
69+
void VisitExpr_(const CallNode *op) final {
70+
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
71+
op->op.same_as(set_max_nreg())) {
72+
has_tma_op_ = true;
73+
}
74+
IRVisitorWithAnalyzer::VisitExpr_(op);
75+
}
76+
77+
void VisitStmt_(const AttrStmtNode *op) final {
78+
if (op->attr_key == "warp_specialize" &&
79+
op->value.as<IntImmNode>()->value == 1) {
80+
has_warp_specialization_ = true;
81+
}
82+
if (op->attr_key == tir::attr::thread_extent) {
83+
IterVar iv = Downcast<IterVar>(op->node);
84+
if (iv->thread_tag == "threadIdx.x") {
85+
ICHECK(iv->dom->extent.as<IntImmNode>());
86+
thread_var_ = iv;
87+
}
88+
}
89+
IRVisitorWithAnalyzer::VisitStmt_(op);
90+
}
91+
92+
bool has_tma_op_{false};
93+
IterVar thread_var_;
94+
bool has_mbarrier_op_{false};
95+
bool has_warp_specialization_{false};
96+
};
97+
98+
} // namespace tl
99+
} // namespace tvm

tilelang/language/builtin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,9 @@ def sync_grid():
350350
"""Synchronize all threads in a grid.
351351
"""
352352
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
353+
354+
355+
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Emit the ``tl.ptx_cp_async_barrier_noinc`` intrinsic
    (cp.async.mbarrier.arrive.noinc).

    Args:
        barrier_id: Value forwarded unchanged as the intrinsic's only argument.

    Returns:
        The ``tir.call_intrin`` result, built with dtype ``"handle"``.
    """
    noinc_op = tir.op.Op.get("tl.ptx_cp_async_barrier_noinc")
    return tir.call_intrin("handle", noinc_op, barrier_id)

0 commit comments

Comments
 (0)