Skip to content

Commit b9db119

Browse files
authored
[Refactor] Enhance MergeSharedMemoryAllocations Pass for Improved Liveness Analysis and Scope Management (#508)
* Introduced a new `StmtAttr` structure to track the scope level of statements, enhancing the liveness-analysis process.
* Updated the `UpdateStmtAttr` function to manage statement attributes effectively during memory-allocation visits.
* Modified the `VisitStmt_` methods to use the new scope-level tracking, ensuring accurate memory-access patterns.
* Refactored the `LivenessAnalysis` and `PlanMemory` functions to incorporate statement attributes, improving the handling of gen and kill points in memory management.
* Added a new helper function `allow_warp_specialized` in `phase.py` to conditionally enable warp specialization based on the pass context and target, addressing potential bugs in the `MergeSharedMemoryAllocations` pass.
* Enhanced the `OptimizeForTarget` function to conditionally apply the `MergeSharedMemoryAllocations` pass based on warp-specialization settings, improving the robustness of memory-allocation strategies.
1 parent 0fb8da7 commit b9db119

File tree

2 files changed

+228
-19
lines changed

2 files changed

+228
-19
lines changed

src/transform/merge_shared_memory_allocations.cc

Lines changed: 205 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,19 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
119119
const AllocateNode *alloc{nullptr};
120120
};
121121

122+
struct StmtAttr {
123+
// the level in the scope stack
124+
size_t level{0};
125+
};
126+
127+
void UpdateStmtAttr(const Object *stmt, size_t level) {
128+
if (stmt_attrs_.find(stmt) == stmt_attrs_.end()) {
129+
stmt_attrs_[stmt] = StmtAttr{level};
130+
} else {
131+
stmt_attrs_[stmt].level = level;
132+
}
133+
}
134+
122135
void VisitStmt_(const AllocateNode *op) final {
123136
size_t level = scope_.size();
124137
const VarNode *buf = op->buffer_var.get();
@@ -137,13 +150,14 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
137150
if (it != alloc_info_.end() && it->second.alloc) {
138151
ICHECK_LT(it->second.level, scope_.size());
139152
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
140-
scope_[it->second.level].touched.push_back(buf);
153+
scope_[scope_.size() - 1].touched.push_back(buf);
141154
}
142155
}
143156
StmtEntry e = scope_.back();
144157
scope_.pop_back();
145158
if (e.touched.size() != 0) {
146159
e.stmt = op;
160+
UpdateStmtAttr(op, scope_level_);
147161
linear_seq_.push_back(e);
148162
}
149163
}
@@ -156,6 +170,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
156170
scope_.pop_back();
157171
if (e.touched.size() != 0) {
158172
e.stmt = op;
173+
UpdateStmtAttr(op, scope_level_);
159174
linear_seq_.push_back(e);
160175
}
161176
}
@@ -169,7 +184,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
169184
ICHECK_LT(it->second.level, scope_.size())
170185
<< "Load memory in places other than store.";
171186
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
172-
scope_[it->second.level].touched.push_back(buf);
187+
scope_[scope_.size() - 1].touched.push_back(buf);
173188
}
174189
}
175190
}
@@ -180,7 +195,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
180195
if (it != alloc_info_.end() && it->second.alloc) {
181196
ICHECK_LT(it->second.level, scope_.size());
182197
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
183-
scope_[it->second.level].touched.push_back(buf);
198+
scope_[scope_.size() - 1].touched.push_back(buf);
184199
}
185200
}
186201
}
@@ -189,6 +204,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
189204
scope_.push_back(StmtEntry());
190205
StmtEntry e;
191206
e.stmt = op;
207+
UpdateStmtAttr(op, scope_level_);
192208
int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
193209
// before scope.
194210
linear_seq_.push_back(e);
@@ -226,7 +242,15 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
226242

227243
void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }
228244

229-
void VisitStmt_(const ForNode *op) final { VisitNewScope(op); }
245+
void VisitStmt_(const ForNode *op) final {
246+
if (op->body->IsInstance<SeqStmtNode>()) {
247+
scope_level_++;
248+
VisitNewScope(op);
249+
scope_level_--;
250+
} else {
251+
VisitNewScope(op);
252+
}
253+
}
230254

231255
void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); }
232256

@@ -236,6 +260,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
236260
std::vector<StmtEntry> linear_seq_;
237261
// The storage scope of each buffer
238262
std::unordered_map<const VarNode *, AllocEntry> alloc_info_;
263+
// The attribute of each statement
264+
std::unordered_map<const Object *, StmtAttr> stmt_attrs_;
239265

240266
private:
241267
// Wrapper function to determine if the shared memory allocation for a
@@ -251,6 +277,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
251277
bool in_thread_env_{false};
252278
// The scope stack.
253279
std::vector<StmtEntry> scope_;
280+
// The size of the scope.
281+
size_t scope_level_{0};
254282
};
255283

256284
/*!
@@ -279,8 +307,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
279307
bool verbose = false) {
280308
SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
281309
finder(stmt);
282-
this->LivenessAnalysis(finder.linear_seq_);
283-
this->PlanMemory(finder.linear_seq_);
310+
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
311+
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
284312
}
285313

286314
private:
@@ -491,6 +519,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
491519
}
492520

493521
using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
522+
using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;
494523
struct StorageEntry {
495524
// The constant size of the buffer in bits, only used if it is constant
496525
uint64_t const_nbits{0};
@@ -515,7 +544,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
515544
* \brief Liveness analysis to find gen and kill point of each variable.
516545
* \param seq the linear pattern of storage access
517546
*/
518-
void LivenessAnalysis(const std::vector<StmtEntry> &seq) {
547+
void LivenessAnalysis(
548+
const std::vector<StmtEntry> &seq,
549+
const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
519550
// find kill point, do a reverse linear scan.
520551
std::unordered_set<const VarNode *> touched;
521552
for (size_t i = seq.size(); i != 0; --i) {
@@ -543,17 +574,174 @@ class SharedMemoryRewriter : public StmtExprMutator {
543574
}
544575

545576
if (verbose_) {
546-
LOG(DEBUG) << "Liveness Analysis Results for "
577+
std::vector<const Object *> stmt_keys;
578+
for (const auto &stmt_entry : seq) {
579+
auto stmt = stmt_entry.stmt;
580+
if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) ==
581+
stmt_keys.end()) {
582+
stmt_keys.push_back(stmt);
583+
}
584+
}
585+
LOG(DEBUG) << "Before reorder kill points, Liveness Analysis Results for "
547586
<< (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
548-
for (const auto &pair : event_map_) {
549-
const Object *stmt_obj = pair.first;
550-
const EventEntry &entry = pair.second;
587+
for (const auto &stmt_key : stmt_keys) {
588+
auto it = event_map_.find(stmt_key);
589+
if (it == event_map_.end())
590+
continue;
591+
592+
const EventEntry &entry = it->second;
593+
if (entry.gen.empty() && entry.kill.empty())
594+
continue;
595+
ICHECK(stmt_attrs.count(stmt_key))
596+
<< "stmt_key = " << stmt_key->GetTypeKey();
597+
auto level = stmt_attrs.at(stmt_key).level;
598+
LOG(DEBUG) << " Statement: " << stmt_key->GetTypeKey()
599+
<< " (scope_level: " << level << ")";
600+
601+
std::stringstream gen_vars_ss;
602+
bool x_generated = false;
603+
for (const VarNode *var : entry.gen) {
604+
gen_vars_ss << var->name_hint << " ";
605+
if (var->name_hint == "x") {
606+
x_generated = true;
607+
}
608+
}
609+
if (!entry.gen.empty()) {
610+
std::string gen_log_msg = " GEN: " + gen_vars_ss.str();
611+
if (x_generated) {
612+
gen_log_msg += " <-- Buffer 'x' generated";
613+
}
614+
LOG(DEBUG) << gen_log_msg;
615+
}
616+
617+
std::stringstream kill_vars_ss;
618+
bool x_killed = false;
619+
for (const VarNode *var : entry.kill) {
620+
kill_vars_ss << var->name_hint << " ";
621+
if (var->name_hint == "x") {
622+
x_killed = true;
623+
}
624+
}
625+
if (!entry.kill.empty()) {
626+
std::string kill_log_msg = " KILL: " + kill_vars_ss.str();
627+
if (x_killed) {
628+
kill_log_msg += " <-- Buffer 'x' killed";
629+
}
630+
LOG(DEBUG) << kill_log_msg;
631+
}
632+
}
633+
LOG(DEBUG) << "End of Liveness Analysis Results.";
634+
}
635+
636+
// Reorder kill points:
637+
// For each buffer, if its kill statement is at a deeper scope level than
638+
// its gen statement, we need to move the kill point to the end of the gen
639+
// statement's scope level. This ensures proper memory deallocation at the
640+
// right scope boundary.
641+
std::vector<StmtEntry> gen_kill_seq;
642+
for (const auto &stmt_entry : seq) {
643+
// if has gen and kill, add to gen_kill_seq
644+
if (event_map_[stmt_entry.stmt].gen.size() > 0 ||
645+
event_map_[stmt_entry.stmt].kill.size() > 0) {
646+
gen_kill_seq.push_back(stmt_entry);
647+
}
648+
}
551649

552-
if (entry.gen.empty() && entry.kill.empty()) {
553-
continue; // Skip statements with no gen/kill events for brevity
650+
for (auto &event_pair : event_map_) {
651+
const Object *stmt = event_pair.first;
652+
EventEntry &event = event_pair.second;
653+
654+
// Skip if no kill points to process
655+
if (event.kill.empty())
656+
continue;
657+
658+
// Get scope level of current statement
659+
ICHECK(stmt_attrs.count(stmt));
660+
int kill_level = stmt_attrs.at(stmt).level;
661+
662+
std::unordered_set<const VarNode *> visited_buffers;
663+
664+
// For each killed buffer, find its gen statement and check scope levels
665+
for (auto it = event.kill.begin(); it != event.kill.end();) {
666+
const VarNode *buffer = *it;
667+
bool found_gen = false;
668+
int gen_level = 0;
669+
670+
// Find the gen statement for this buffer
671+
for (const auto &gen_pair : event_map_) {
672+
const auto &gen_event = gen_pair.second;
673+
if (std::find(gen_event.gen.begin(), gen_event.gen.end(), buffer) !=
674+
gen_event.gen.end()) {
675+
found_gen = true;
676+
gen_level = stmt_attrs.at(gen_pair.first).level;
677+
break;
678+
}
679+
}
680+
681+
if (found_gen && kill_level > gen_level) {
682+
if (visited_buffers.count(buffer)) {
683+
++it;
684+
continue;
685+
}
686+
// Need to move kill point - remove from current event
687+
it = event.kill.erase(it);
688+
689+
// Find the last statement at gen_level and add kill point there
690+
// Find the last statement at gen_level in the sequence
691+
const Object *last_stmt_at_level = nullptr;
692+
auto stmt_it = gen_kill_seq.begin();
693+
for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
694+
if (stmt_it->stmt == stmt) {
695+
break;
696+
}
697+
}
698+
// start from current statement and find the last statement at
699+
// gen_level
700+
701+
for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
702+
// Check if next statement has different level
703+
auto next_it = stmt_it + 1;
704+
if (next_it == gen_kill_seq.end() ||
705+
stmt_attrs.at(next_it->stmt).level == gen_level) {
706+
last_stmt_at_level = stmt_it->stmt;
707+
break;
708+
}
709+
}
710+
if (last_stmt_at_level) {
711+
event_map_[last_stmt_at_level].kill.push_back(buffer);
712+
visited_buffers.insert(buffer);
713+
}
714+
} else {
715+
++it;
554716
}
717+
}
718+
}
555719

556-
LOG(DEBUG) << " Statement: " << stmt_obj->GetTypeKey();
720+
std::vector<const Object *> stmt_keys;
721+
for (const auto &stmt_entry : seq) {
722+
auto stmt = stmt_entry.stmt;
723+
if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) ==
724+
stmt_keys.end()) {
725+
stmt_keys.push_back(stmt);
726+
}
727+
}
728+
729+
if (verbose_) {
730+
LOG(DEBUG) << "Liveness Analysis Results for "
731+
<< (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
732+
for (const auto &stmt_key : stmt_keys) {
733+
auto it = event_map_.find(stmt_key);
734+
if (it == event_map_.end())
735+
continue;
736+
737+
const EventEntry &entry = it->second;
738+
if (entry.gen.empty() && entry.kill.empty())
739+
continue;
740+
ICHECK(stmt_attrs.count(stmt_key))
741+
<< "stmt_key = " << stmt_key->GetTypeKey();
742+
auto level = stmt_attrs.at(stmt_key).level;
743+
LOG(DEBUG) << " Statement: " << stmt_key->GetTypeKey()
744+
<< " (scope_level: " << level << ")";
557745

558746
std::stringstream gen_vars_ss;
559747
bool x_generated = false;
@@ -596,7 +784,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
596784
* \param seq the linear pattern of storage access
597785
* \param alloc_info
598786
*/
599-
void PlanMemory(const std::vector<StmtEntry> &seq) {
787+
void
788+
PlanMemory(const std::vector<StmtEntry> &seq,
789+
const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
600790
std::unordered_set<const VarNode *> inplace_flag;
601791

602792
for (size_t i = 0; i < seq.size(); ++i) {

tilelang/engine/phase.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@
88
from typing import Optional
99

1010

11+
def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
12+
target: Optional[Target] = None) -> bool:
13+
# avoid circular import
14+
from tilelang.jit.adapter.utils import is_cuda_target
15+
16+
if pass_ctx is None:
17+
pass_ctx = tilelang.transform.get_pass_context()
18+
if not is_cuda_target(target):
19+
return False
20+
disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
21+
return not disable_warp_specialized
22+
23+
1124
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1225
target: Optional[Target] = None) -> bool:
1326
# avoid circular import
@@ -18,9 +31,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1831
if not is_cuda_target(target) or not have_tma(target):
1932
return False
2033
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
21-
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
22-
disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
23-
return not (disable_tma_lower and disable_warp_specialized)
34+
return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)
2435

2536

2637
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
@@ -130,7 +141,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
130141

131142
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
132143
mod = tir.transform.SplitHostDevice()(mod)
133-
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
144+
145+
if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
146+
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
147+
# when warp specialization is enabled, as different warp threads may access different
148+
# buffers, but the liveness analysis is hard because we need to do pipeline.
149+
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
150+
else:
151+
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
152+
134153
mod = tilelang.transform.MakePackedAPI()(mod)
135154
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
136155

0 commit comments

Comments
 (0)