@@ -24,32 +24,6 @@ namespace tl {
2424using namespace tir ;
2525using arith::IRMutatorWithAnalyzer;
2626
27- // Helper class to find leaf For nodes in a given IR
28- class LeafForFinder : public StmtVisitor {
29- public:
30- std::vector<For> leaf_for_nodes;
31-
32- private:
33- void VisitStmt_ (const ForNode *op) final {
34- has_child_for_ = false ;
35- bool parent_has_child_for = parent_has_child_for_;
36- parent_has_child_for_ = false ;
37-
38- StmtVisitor::VisitStmt (op->body );
39-
40- if (!has_child_for_) {
41- leaf_for_nodes.push_back (tvm::ffi::GetRef<For>(op));
42- }
43-
44- parent_has_child_for_ = parent_has_child_for;
45- parent_has_child_for_ = true ;
46- }
47-
48- private:
49- bool has_child_for_ = false ;
50- bool parent_has_child_for_ = false ;
51- };
52-
5327// GlobalMemChecker for a BufferLoad/BufferStore node:
5428// 1. Identify BufferLoad and BufferStore nodes.
5529// 2. Check if the buffer is in global scope.
@@ -109,13 +83,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
10983 PrimExpr index = indices[i];
11084 PrimExpr shape_dim = buffer->shape [i];
11185
112- bool has_variable = false ;
86+ bool is_index_constant = true ;
11387 PostOrderVisit (index, [&](const ObjectRef &obj) {
11488 if (const VarNode *v = obj.as <VarNode>()) {
115- has_variable = true ;
89+ is_index_constant = false ;
90+ }
91+ if (const BufferLoadNode *v = obj.as <BufferLoadNode>()) {
92+ is_index_constant = false ;
11693 }
11794 });
118- if (!has_variable ) {
95+ if (is_index_constant ) {
11996 // If index is a constant, we can skip the check
12097 continue ;
12198 }
@@ -145,18 +122,31 @@ struct GlobalMemChecker : public StmtExprVisitor {
145122 bool recursively_collect_conds_;
146123};
147124
148- class SafeMemorysRewriter : public StmtExprMutator {
149- arith::Analyzer *analyzer_;
150-
125+ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
151126public:
152- explicit SafeMemorysRewriter (Map<Buffer, PrimExpr> annotated_safe_value_map,
153- arith::Analyzer *analyzer)
154- : annotated_safe_value_map_(std::move(annotated_safe_value_map)),
155- analyzer_(analyzer) {}
127+ // Static method to substitute and transform the given PrimFunc
128+ static PrimFunc Substitute (PrimFunc f) {
129+ arith::Analyzer analyzer;
130+ // Create an instance of the legalizer with the analyzer
131+ SafeMemorysRewriter substituter (&analyzer);
132+ // Get a mutable copy of the function node
133+ PrimFuncNode *fptr = f.CopyOnWrite ();
134+ for (const auto &[_, buffer] : f->buffer_map ) {
135+ substituter.buffer_data_to_buffer_ .Set (buffer->data , buffer);
136+ }
137+ // Apply the legalizer to the function body
138+ fptr->body = substituter.VisitStmt (f->body );
139+ return f;
140+ }
156141
157142private:
143+ // Constructor initializing the base class with the analyzer
144+ SafeMemorysRewriter (arith::Analyzer *analyzer)
145+ : arith::IRMutatorWithAnalyzer(analyzer) {}
146+ // Constructor initializing the base class with the analyzer
147+
158148 PrimExpr VisitExpr_ (const BufferLoadNode *op) final {
159- auto load = Downcast<BufferLoad>(StmtExprMutator ::VisitExpr_ (op));
149+ auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer ::VisitExpr_ (op));
160150
161151 // For Load/Store, we only check the current node, not its children.
162152 // Since rewriter will recursively visit children.
@@ -181,7 +171,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
181171
182172 Stmt VisitStmt_ (const BufferStoreNode *op) final {
183173 // Check if the buffer is in global scope
184- auto store = Downcast<BufferStore>(StmtExprMutator ::VisitStmt_ (op));
174+ auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer ::VisitStmt_ (op));
185175
186176 GlobalMemChecker checker (analyzer_, /* recursively_collect_conds=*/ false );
187177 checker (store);
@@ -253,6 +243,25 @@ class SafeMemorysRewriter : public StmtExprMutator {
253243 return evaluate;
254244 }
255245
246+ Stmt VisitStmt_ (const BlockNode *op) final {
247+ for (auto buffer : op->alloc_buffers ) {
248+ buffer_data_to_buffer_.Set (buffer->data , buffer);
249+ }
250+ if (op->annotations .count (attr::kSafeValueMap )) {
251+ auto map = op->annotations .Get (attr::kSafeValueMap )
252+ ->as <Map<Var, PrimExpr>>()
253+ .value ();
254+ for (const auto &[var, safe_value] : map) {
255+ ICHECK (buffer_data_to_buffer_.count (var))
256+ << " buffer " << var << " is not found in the block "
257+ << buffer_data_to_buffer_;
258+ auto buffer = buffer_data_to_buffer_[var];
259+ annotated_safe_value_map_.Set (buffer, safe_value);
260+ }
261+ }
262+ return IRMutatorWithAnalyzer::VisitStmt_ (op);
263+ }
264+
256265 bool IsLocalBuffer (const Buffer &buffer) {
257266 String scope = buffer.scope ();
258267 return scope == " local" || scope == " local.fragment" ||
@@ -276,87 +285,6 @@ class SafeMemorysRewriter : public StmtExprMutator {
276285 return make_zero (buffer->dtype );
277286 }
278287
279- Map<Buffer, PrimExpr> annotated_safe_value_map_;
280- };
281-
282- // Class to legalize safe memory access by transforming them appropriately
283- class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
284- public:
285- // Static method to substitute and transform the given PrimFunc
286- static PrimFunc Substitute (PrimFunc f) {
287- arith::Analyzer analyzer;
288- // Create an instance of the legalizer with the analyzer
289- SafeMemoryLegalizer substituter (&analyzer);
290- // Get a mutable copy of the function node
291- PrimFuncNode *fptr = f.CopyOnWrite ();
292- for (const auto &[_, buffer] : f->buffer_map ) {
293- substituter.buffer_data_to_buffer_ .Set (buffer->data , buffer);
294- }
295- // Apply the legalizer to the function body
296- fptr->body = substituter.VisitStmt (f->body );
297- return f;
298- }
299-
300- private:
301- // Constructor initializing the base class with the analyzer
302- SafeMemoryLegalizer (arith::Analyzer *analyzer)
303- : arith::IRMutatorWithAnalyzer(analyzer) {}
304-
305- // Override the VisitStmt_ method to handle ForNode (loop statements)
306- Stmt VisitStmt_ (const ForNode *op) final {
307- // Visit and potentially modify the loop node
308- For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_ (op));
309- auto has_inner_loop = HasInnerLoop (for_node->body );
310- if (!has_inner_loop) {
311- SafeMemorysRewriter rewriter (annotated_safe_value_map_, analyzer_);
312- for_node.CopyOnWrite ()->body = rewriter (for_node->body );
313- // // Detect Buffer Load Node in the loop body, collect the indices and
314- // buffer size
315-
316- // // Run the checker on the loop body
317- // GlobalMemChecker checker(analyzer_);
318- // checker(for_node->body);
319- // Array<PrimExpr> conditions = checker.GetConditions();
320- // auto body = for_node->body;
321- // // Note that we might have duplicate conditions
322- // // Which will be optimized by simplify pass
323- // // Replace the loop body with the new body
324- // for (auto cond : conditions) {
325- // body = IfThenElse(cond, body);
326- // }
327- // for_node.CopyOnWrite()->body = body;
328- return std::move (for_node);
329- }
330-
331- // Visit a For Node
332- return IRMutatorWithAnalyzer::VisitStmt_ (op);
333- }
334-
335- Stmt VisitStmt_ (const BlockNode *op) final {
336- for (auto buffer : op->alloc_buffers ) {
337- buffer_data_to_buffer_.Set (buffer->data , buffer);
338- }
339- if (op->annotations .count (attr::kSafeValueMap )) {
340- auto map = op->annotations .Get (attr::kSafeValueMap )
341- ->as <Map<Var, PrimExpr>>()
342- .value ();
343- for (const auto &[var, safe_value] : map) {
344- ICHECK (buffer_data_to_buffer_.count (var))
345- << " buffer " << var << " is not found in the block "
346- << buffer_data_to_buffer_;
347- auto buffer = buffer_data_to_buffer_[var];
348- annotated_safe_value_map_.Set (buffer, safe_value);
349- }
350- }
351- return IRMutatorWithAnalyzer::VisitStmt_ (op);
352- }
353-
354- static bool HasInnerLoop (const Stmt &stmt) {
355- LeafForFinder finder;
356- finder (stmt);
357- return !finder.leaf_for_nodes .empty ();
358- }
359-
360288 Map<Var, Buffer> buffer_data_to_buffer_;
361289 Map<Buffer, PrimExpr> annotated_safe_value_map_;
362290};
@@ -371,7 +299,7 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
371299 if (disable_safe_memory_legalize) {
372300 return f;
373301 }
374- return SafeMemoryLegalizer ::Substitute (std::move (f));
302+ return SafeMemorysRewriter ::Substitute (std::move (f));
375303 };
376304 // Create and return a PrimFunc pass with the transformation function
377305 return CreatePrimFuncPass (pass_func, 0 , " tl.LegalizeSafeMemoryAccess" , {});
0 commit comments