@@ -109,13 +109,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
109109 PrimExpr index = indices[i];
110110 PrimExpr shape_dim = buffer->shape [i];
111111
112- bool has_variable = false ;
112+ bool is_index_constant = true ;
113113 PostOrderVisit (index, [&](const ObjectRef &obj) {
114114 if (const VarNode *v = obj.as <VarNode>()) {
115- has_variable = true ;
115+ is_index_constant = false ;
116+ }
117+ if (const BufferLoadNode *v = obj.as <BufferLoadNode>()) {
118+ is_index_constant = false ;
116119 }
117120 });
118- if (!has_variable ) {
121+ if (is_index_constant ) {
119122 // If index is a constant, we can skip the check
120123 continue ;
121124 }
@@ -145,18 +148,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
145148 bool recursively_collect_conds_;
146149};
147150
148- class SafeMemorysRewriter : public StmtExprMutator {
149- arith::Analyzer *analyzer_;
150-
151+ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
151152public:
152153 explicit SafeMemorysRewriter (Map<Buffer, PrimExpr> annotated_safe_value_map,
153154 arith::Analyzer *analyzer)
154- : annotated_safe_value_map_(std::move(annotated_safe_value_map) ),
155- analyzer_(analyzer ) {}
155+ : arith::IRMutatorWithAnalyzer(analyzer ),
156+ annotated_safe_value_map_(std::move(annotated_safe_value_map) ) {}
156157
157158private:
158159 PrimExpr VisitExpr_ (const BufferLoadNode *op) final {
159- auto load = Downcast<BufferLoad>(StmtExprMutator ::VisitExpr_ (op));
160+ auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer ::VisitExpr_ (op));
160161
161162 // For Load/Store, we only check the current node, not its children.
162163 // Since rewriter will recursively visit children.
@@ -181,7 +182,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
181182
182183 Stmt VisitStmt_ (const BufferStoreNode *op) final {
183184 // Check if the buffer is in global scope
184- auto store = Downcast<BufferStore>(StmtExprMutator ::VisitStmt_ (op));
185+ auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer ::VisitStmt_ (op));
185186
186187 GlobalMemChecker checker (analyzer_, /* recursively_collect_conds=*/ false );
187188 checker (store);
@@ -226,7 +227,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
226227 // directly applying the boundary constraints of all parameters to the
227228 // statement. While not entirely precise, it addresses most common scenarios.
228229 Stmt VisitStmt_ (const EvaluateNode *op) final {
229- auto evaluate = Downcast<Evaluate>(op );
230+ auto evaluate = Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_ (op) );
230231
231232 if (const CallNode *call_op = op->value .as <CallNode>()) {
232233 auto call = Downcast<Call>(op->value );
0 commit comments