Skip to content

Commit bd1c7b3

Browse files
authored
[Refactor] Use has_simt_copy to decide whether to insert set_max_nreg (#982)
1 parent 8f001e0 commit bd1c7b3

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

examples/deepseek_v32/fp8_lighting_indexer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,6 @@ def mqa_attn_return_logits_kernel(
136136
cu_k_s_min = T.alloc_local([1], index_dtype)
137137
cu_k_e_max = T.alloc_local([1], index_dtype)
138138

139-
T.no_set_max_nreg()
140-
141139
cu_k_s_min[0] = 2147483647
142140
cu_k_e_max[0] = -2147483648
143141

src/transform/annotate_warp_group_reg_alloc.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ class SetMaxNRegCollector : public StmtExprVisitor {
5959
bool warp_specialized_ = false;
6060
};
6161

62+
class SimtCopyDetector : public StmtExprVisitor {
63+
public:
64+
static bool Detect(const Stmt &stmt) {
65+
SimtCopyDetector detector;
66+
detector.VisitStmt(stmt);
67+
return detector.has_simt_copy_;
68+
}
69+
70+
private:
71+
void VisitStmt_(const BufferStoreNode *op) final {
72+
auto scope =
73+
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
74+
if (scope.to_string() != "global") {
75+
has_simt_copy_ = true;
76+
}
77+
StmtExprVisitor::VisitStmt_(op);
78+
}
79+
80+
bool has_simt_copy_{false};
81+
};
82+
6283
class SetMaxNRegInjector : public StmtExprMutator {
6384
public:
6485
static PrimFunc Inject(PrimFunc f) {
@@ -113,9 +134,7 @@ class SetMaxNRegInjector : public StmtExprMutator {
113134
auto dec_reg_stmt = Evaluate(0);
114135

115136
// Only inject if we have valid register hints and no SIMT copy
116-
// For now, we assume no SIMT copy detection is available here
117-
// TODO: Add SIMT copy detection if needed
118-
bool has_simt_copy = false; // Placeholder
137+
bool has_simt_copy = SimtCopyDetector::Detect(producer_body);
119138

120139
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
121140
auto inc_reg_num =

tilelang/engine/phase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
135135
mod = tilelang.transform.MultiVersionBuffer()(mod)
136136
mod = tilelang.transform.WarpSpecialized()(mod)
137137
mod = tilelang.transform.InjectTmaBarrier()(mod)
138-
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
139138
# if tma is not enabled, we can also do pipeline planning
140139
# to get better performance with async copy
141140
mod = tilelang.transform.PipelinePlanning()(mod)
@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
206205
# Inject PTX async copy must behind the thread sync pass
207206
# as ptx async copy won't be recognized as a valid buffer load
208207
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
208+
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
209+
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
209210
mod = tilelang.transform.MakePackedAPI()(mod)
210211
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
211212

0 commit comments

Comments
 (0)