Skip to content

Commit 3e12bc5

Browse files
committed
Improve memory access safety and T.assume handling
1 parent bccb648 commit 3e12bc5

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

3rdparty/tvm

Submodule tvm updated from f4affc7 to 18a30cd

src/transform/legalize_safe_memory_access.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
109109
PrimExpr index = indices[i];
110110
PrimExpr shape_dim = buffer->shape[i];
111111

112-
bool has_variable = false;
112+
bool is_index_constant = true;
113113
PostOrderVisit(index, [&](const ObjectRef &obj) {
114114
if (const VarNode *v = obj.as<VarNode>()) {
115-
has_variable = true;
115+
is_index_constant = false;
116+
}
117+
if (const BufferLoadNode *v = obj.as<BufferLoadNode>()) {
118+
is_index_constant = false;
116119
}
117120
});
118-
if (!has_variable) {
121+
if (is_index_constant) {
119122
// If index is a constant, we can skip the check
120123
continue;
121124
}
@@ -145,18 +148,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
145148
bool recursively_collect_conds_;
146149
};
147150

148-
class SafeMemorysRewriter : public StmtExprMutator {
149-
arith::Analyzer *analyzer_;
150-
151+
class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
151152
public:
152153
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_safe_value_map,
153154
arith::Analyzer *analyzer)
154-
: annotated_safe_value_map_(std::move(annotated_safe_value_map)),
155-
analyzer_(analyzer) {}
155+
: arith::IRMutatorWithAnalyzer(analyzer),
156+
annotated_safe_value_map_(std::move(annotated_safe_value_map)) {}
156157

157158
private:
158159
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
159-
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
160+
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
160161

161162
// For Load/Store, we only check the current node, not its children.
162163
// Since rewriter will recursively visit children.
@@ -181,7 +182,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
181182

182183
Stmt VisitStmt_(const BufferStoreNode *op) final {
183184
// Check if the buffer is in global scope
184-
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
185+
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
185186

186187
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
187188
checker(store);
@@ -226,7 +227,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
226227
// directly applying the boundary constraints of all parameters to the
227228
// statement. While not entirely precise, it addresses most common scenarios.
228229
Stmt VisitStmt_(const EvaluateNode *op) final {
229-
auto evaluate = Downcast<Evaluate>(op);
230+
auto evaluate = Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
230231

231232
if (const CallNode *call_op = op->value.as<CallNode>()) {
232233
auto call = Downcast<Call>(op->value);

src/transform/simplify.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
465465
return std::move(store);
466466
}
467467

468+
Stmt VisitStmt_(const AttrStmtNode* op) override {
469+
if (op->attr_key == "tl.assume") {
470+
PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
471+
auto n = CopyOnWrite(op);
472+
n->node = std::move(condition);
473+
return Parent::VisitStmt_(n.get());
474+
}
475+
return Parent::VisitStmt_(op);
476+
}
477+
468478
private:
469479
bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
470480
if (lhs.size() != rhs.size()) {

0 commit comments

Comments
 (0)