Skip to content

Commit 260c10b

Browse files
committed
[Refactor] Enhance TMA barrier validation and support for additional architectures (#463)
* Updated the TMA barrier validation in `inject_tma_barrier.cc` so that the error for a missing `create_list_of_mbarrier` is raised only when `barrier_id_to_range_` is non-empty.
* Refactored the architecture checks in `phase.py` to use a new constant, `SUPPORTED_TMA_ARCHS`, making the target-architecture validation logic easier to update and read.
1 parent d4ccd6d commit 260c10b

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/transform/inject_tma_barrier.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
#include <tvm/tir/transform.h>
3333

3434
#include "../op/builtin.h"
35+
#include "./common/attr.h"
36+
#include "./common/collector.h"
3537
#include "arith/ir_mutator_with_analyzer.h"
3638
#include "arith/ir_visitor_with_analyzer.h"
37-
#include "./common/collector.h"
38-
#include "./common/attr.h"
3939

4040
namespace tvm {
4141
namespace tl {
@@ -192,13 +192,15 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer {
192192
tma_op_to_barrier_id_.Set(tma_call, barrier_id);
193193
}
194194
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
195-
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
195+
auto extent =
196+
const_int_bound->max_value - const_int_bound->min_value + 1;
196197
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
197198
pending_tma_ops_.clear();
198199
} else if (call->op.same_as(builtin::ptx_wait_barrier())) {
199200
PrimExpr barrier_id = call->args[0];
200201
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
201-
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
202+
auto extent =
203+
const_int_bound->max_value - const_int_bound->min_value + 1;
202204
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
203205
}
204206
}
@@ -237,25 +239,25 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
237239
TmaBarrierCollector collector;
238240
collector(f->body);
239241
bool has_create_list_of_mbarrier = false;
240-
PostOrderVisit(f->body, [&](const ObjectRef& node) {
241-
if (const auto* call = node.as<CallNode>()) {
242+
PostOrderVisit(f->body, [&](const ObjectRef &node) {
243+
if (const auto *call = node.as<CallNode>()) {
242244
if (call->op.same_as(create_list_of_mbarrier())) {
243245
has_create_list_of_mbarrier = true;
244246
}
245247
}
246248
});
247249
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(),
248-
collector.barrier_id_to_range(), has_create_list_of_mbarrier);
250+
collector.barrier_id_to_range(),
251+
has_create_list_of_mbarrier);
249252
f.CopyOnWrite()->body = rewriter(f->body);
250253
return f;
251254
}
252255

253256
private:
254-
255-
256-
Stmt VisitStmt_(const BlockNode *op){
257+
Stmt VisitStmt_(const BlockNode *op) {
257258
auto block = GetRef<Block>(op);
258-
if (!has_create_list_of_mbarrier_ && op->name_hint == MainBlockName) {
259+
if (!has_create_list_of_mbarrier_ && barrier_id_to_range_.size() > 0 &&
260+
op->name_hint == MainBlockName) {
259261
ICHECK(false) << "Please declare create_list_of_mbarrier.";
260262
}
261263
return IRMutatorWithAnalyzer::VisitStmt_(op);

tilelang/engine/phase.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from tilelang.transform import PassContext
77
from typing import Optional
88

9+
SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}
10+
911

1012
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1113
target: Optional[Target] = None) -> bool:
1214
if pass_ctx is None:
1315
pass_ctx = tilelang.transform.get_pass_context()
14-
if target.arch not in {"sm_90", "sm_90a"}:
16+
if target.arch not in SUPPORTED_TMA_ARCHS:
1517
return False
1618
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
1719
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
@@ -20,7 +22,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
2022

2123

2224
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
23-
return target.arch in {"sm_90", "sm_90a"}
25+
return target.arch in SUPPORTED_TMA_ARCHS
2426

2527

2628
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:

0 commit comments

Comments
 (0)