File tree Expand file tree Collapse file tree 3 files changed +15
-0
lines changed
Expand file tree Collapse file tree 3 files changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -503,6 +503,7 @@ TVM_DLL const Op &initialize_descriptor();
503503 * This op is used to represent a descriptor start address setting operation in
504504 * tilelang.
505505 */
506+
506507TVM_DLL const Op &increase_descriptor_offset ();
507508/* !
508509 * \brief tilelang intrinsic for element-wise atomic addition.
Original file line number Diff line number Diff line change @@ -133,6 +133,10 @@ TL_DEVICE void fence_proxy_async() {
133133 asm volatile (" fence.proxy.async.shared::cta;" : :);
134134}
135135
136+ TL_DEVICE void fence_barrier_init () {
137+ asm volatile (" fence.mbarrier_init.release.cluster;" : :);
138+ }
139+
136140// Indicate arrival of warp issuing TMA_STORE
137141TL_DEVICE void tma_store_arrive () {
138142 asm volatile (" cp.async.bulk.commit_group;" );
Original file line number Diff line number Diff line change @@ -83,6 +83,16 @@ class LowerHopperIntrin : public StmtExprMutator {
8383 stmts.size () > 1 ? SeqStmt (stmts) : stmts[0 ]);
8484 stmt_seq.push_back (stmt_);
8585 if (!init_mbarrier_calls_.empty ()) {
86+ // Note from FlashAttention:
87+ // Helps with visibility of barrier init operations across warps /
88+ // cta / cluster Available as a separate function so as to batch
89+ // inits across barriers and fence once Note : It must be composed
90+ // with an appropriate sync instruction with the right scope to
91+ // ensure visibility eg. __syncthreads() or a cluster_arrive() +
92+ // cluster_wait()
93+ Stmt mem_fence = Evaluate (Call (
94+ DataType::Handle (), tvm::tl::ptx_fence_barrier_init (), {}));
95+ stmt_seq.push_back (mem_fence);
8696 Stmt mem_sync =
8797 Evaluate (Call (DataType::Handle (), builtin::tvm_storage_sync (),
8898 {StringImm (" shared" )}));
You can’t perform that action at this time.
0 commit comments