Skip to content

Commit 17a6397

Browse files
authored
[Enhancement] Add missing fence_barrier_init primitive after mbarrier init (#1121)
* [Enhancement] Add missing primitive after mbarrier init * lint
1 parent 0dc50a5 commit 17a6397

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

src/op/builtin.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ TVM_DLL const Op &initialize_descriptor();
503503
* This op is used to represent a descriptor start address setting operation in
504504
* tilelang.
505505
*/
506+
506507
TVM_DLL const Op &increase_descriptor_offset();
507508
/*!
508509
* \brief tilelang intrinsic for element-wise atomic addition.

src/tl_templates/cuda/barrier.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ TL_DEVICE void fence_proxy_async() {
133133
asm volatile("fence.proxy.async.shared::cta;" : :);
134134
}
135135

136+
TL_DEVICE void fence_barrier_init() {
137+
asm volatile("fence.mbarrier_init.release.cluster;" : :);
138+
}
139+
136140
// Indicate arrival of warp issuing TMA_STORE
137141
TL_DEVICE void tma_store_arrive() {
138142
asm volatile("cp.async.bulk.commit_group;");

src/transform/lower_hopper_intrin.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class LowerHopperIntrin : public StmtExprMutator {
8383
stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
8484
stmt_seq.push_back(stmt_);
8585
if (!init_mbarrier_calls_.empty()) {
86+
// Note from FlashAttention:
87+
// Helps with visibility of barrier init operations across warps /
88+
// cta / cluster Available as a separate function so as to batch
89+
// inits across barriers and fence once Note : It must be composed
90+
// with an appropriate sync instruction with the right scope to
91+
// ensure visibility eg. __syncthreads() or a cluster_arrive() +
92+
// cluster_wait()
93+
Stmt mem_fence = Evaluate(Call(
94+
DataType::Handle(), tvm::tl::ptx_fence_barrier_init(), {}));
95+
stmt_seq.push_back(mem_fence);
8696
Stmt mem_sync =
8797
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
8898
{StringImm("shared")}));

0 commit comments

Comments
 (0)