Skip to content

Commit 470eb74

Browse files
Improve memory access safety and T.assume handling (#1292)
* Improve memory access safety and T.assume handling * Improve memory access safety and T.assume handling * bugfix * lint fix * bugfix * bugfix * refactor legalize safe memory access pass --------- Co-authored-by: Lei Wang <leiwang1999@outlook.com>
1 parent 0d101c1 commit 470eb74

File tree

2 files changed

+58
-120
lines changed

2 files changed

+58
-120
lines changed

src/transform/legalize_safe_memory_access.cc

Lines changed: 48 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,6 @@ namespace tl {
2424
using namespace tir;
2525
using arith::IRMutatorWithAnalyzer;
2626

27-
// Helper class to find leaf For nodes in a given IR
28-
class LeafForFinder : public StmtVisitor {
29-
public:
30-
std::vector<For> leaf_for_nodes;
31-
32-
private:
33-
void VisitStmt_(const ForNode *op) final {
34-
has_child_for_ = false;
35-
bool parent_has_child_for = parent_has_child_for_;
36-
parent_has_child_for_ = false;
37-
38-
StmtVisitor::VisitStmt(op->body);
39-
40-
if (!has_child_for_) {
41-
leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
42-
}
43-
44-
parent_has_child_for_ = parent_has_child_for;
45-
parent_has_child_for_ = true;
46-
}
47-
48-
private:
49-
bool has_child_for_ = false;
50-
bool parent_has_child_for_ = false;
51-
};
52-
5327
// GlobalMemChecker for a BufferLoad/BufferStore node:
5428
// 1. Identify BufferLoad and BufferStore nodes.
5529
// 2. Check if the buffer is in global scope.
@@ -109,13 +83,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
10983
PrimExpr index = indices[i];
11084
PrimExpr shape_dim = buffer->shape[i];
11185

112-
bool has_variable = false;
86+
bool is_index_constant = true;
11387
PostOrderVisit(index, [&](const ObjectRef &obj) {
11488
if (const VarNode *v = obj.as<VarNode>()) {
115-
has_variable = true;
89+
is_index_constant = false;
90+
}
91+
if (const BufferLoadNode *v = obj.as<BufferLoadNode>()) {
92+
is_index_constant = false;
11693
}
11794
});
118-
if (!has_variable) {
95+
if (is_index_constant) {
11996
// If index is a constant, we can skip the check
12097
continue;
12198
}
@@ -145,18 +122,31 @@ struct GlobalMemChecker : public StmtExprVisitor {
145122
bool recursively_collect_conds_;
146123
};
147124

148-
class SafeMemorysRewriter : public StmtExprMutator {
149-
arith::Analyzer *analyzer_;
150-
125+
class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
151126
public:
152-
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_safe_value_map,
153-
arith::Analyzer *analyzer)
154-
: annotated_safe_value_map_(std::move(annotated_safe_value_map)),
155-
analyzer_(analyzer) {}
127+
// Static method to substitute and transform the given PrimFunc
128+
static PrimFunc Substitute(PrimFunc f) {
129+
arith::Analyzer analyzer;
130+
// Create an instance of the legalizer with the analyzer
131+
SafeMemorysRewriter substituter(&analyzer);
132+
// Get a mutable copy of the function node
133+
PrimFuncNode *fptr = f.CopyOnWrite();
134+
for (const auto &[_, buffer] : f->buffer_map) {
135+
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
136+
}
137+
// Apply the legalizer to the function body
138+
fptr->body = substituter.VisitStmt(f->body);
139+
return f;
140+
}
156141

157142
private:
143+
// Constructor initializing the base class with the analyzer
144+
SafeMemorysRewriter(arith::Analyzer *analyzer)
145+
: arith::IRMutatorWithAnalyzer(analyzer) {}
146+
// Constructor initializing the base class with the analyzer
147+
158148
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
159-
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
149+
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
160150

161151
// For Load/Store, we only check the current node, not its children.
162152
// Since rewriter will recursively visit children.
@@ -181,7 +171,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
181171

182172
Stmt VisitStmt_(const BufferStoreNode *op) final {
183173
// Check if the buffer is in global scope
184-
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
174+
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
185175

186176
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
187177
checker(store);
@@ -253,6 +243,25 @@ class SafeMemorysRewriter : public StmtExprMutator {
253243
return evaluate;
254244
}
255245

246+
Stmt VisitStmt_(const BlockNode *op) final {
247+
for (auto buffer : op->alloc_buffers) {
248+
buffer_data_to_buffer_.Set(buffer->data, buffer);
249+
}
250+
if (op->annotations.count(attr::kSafeValueMap)) {
251+
auto map = op->annotations.Get(attr::kSafeValueMap)
252+
->as<Map<Var, PrimExpr>>()
253+
.value();
254+
for (const auto &[var, safe_value] : map) {
255+
ICHECK(buffer_data_to_buffer_.count(var))
256+
<< "buffer " << var << " is not found in the block "
257+
<< buffer_data_to_buffer_;
258+
auto buffer = buffer_data_to_buffer_[var];
259+
annotated_safe_value_map_.Set(buffer, safe_value);
260+
}
261+
}
262+
return IRMutatorWithAnalyzer::VisitStmt_(op);
263+
}
264+
256265
bool IsLocalBuffer(const Buffer &buffer) {
257266
String scope = buffer.scope();
258267
return scope == "local" || scope == "local.fragment" ||
@@ -276,87 +285,6 @@ class SafeMemorysRewriter : public StmtExprMutator {
276285
return make_zero(buffer->dtype);
277286
}
278287

279-
Map<Buffer, PrimExpr> annotated_safe_value_map_;
280-
};
281-
282-
// Class to legalize safe memory access by transforming them appropriately
283-
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
284-
public:
285-
// Static method to substitute and transform the given PrimFunc
286-
static PrimFunc Substitute(PrimFunc f) {
287-
arith::Analyzer analyzer;
288-
// Create an instance of the legalizer with the analyzer
289-
SafeMemoryLegalizer substituter(&analyzer);
290-
// Get a mutable copy of the function node
291-
PrimFuncNode *fptr = f.CopyOnWrite();
292-
for (const auto &[_, buffer] : f->buffer_map) {
293-
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
294-
}
295-
// Apply the legalizer to the function body
296-
fptr->body = substituter.VisitStmt(f->body);
297-
return f;
298-
}
299-
300-
private:
301-
// Constructor initializing the base class with the analyzer
302-
SafeMemoryLegalizer(arith::Analyzer *analyzer)
303-
: arith::IRMutatorWithAnalyzer(analyzer) {}
304-
305-
// Override the VisitStmt_ method to handle ForNode (loop statements)
306-
Stmt VisitStmt_(const ForNode *op) final {
307-
// Visit and potentially modify the loop node
308-
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
309-
auto has_inner_loop = HasInnerLoop(for_node->body);
310-
if (!has_inner_loop) {
311-
SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_);
312-
for_node.CopyOnWrite()->body = rewriter(for_node->body);
313-
// // Detect Buffer Load Node in the loop body, collect the indices and
314-
// buffer size
315-
316-
// // Run the checker on the loop body
317-
// GlobalMemChecker checker(analyzer_);
318-
// checker(for_node->body);
319-
// Array<PrimExpr> conditions = checker.GetConditions();
320-
// auto body = for_node->body;
321-
// // Note that we might have duplicate conditions
322-
// // Which will be optimized by simplify pass
323-
// // Replace the loop body with the new body
324-
// for (auto cond : conditions) {
325-
// body = IfThenElse(cond, body);
326-
// }
327-
// for_node.CopyOnWrite()->body = body;
328-
return std::move(for_node);
329-
}
330-
331-
// Visit a For Node
332-
return IRMutatorWithAnalyzer::VisitStmt_(op);
333-
}
334-
335-
Stmt VisitStmt_(const BlockNode *op) final {
336-
for (auto buffer : op->alloc_buffers) {
337-
buffer_data_to_buffer_.Set(buffer->data, buffer);
338-
}
339-
if (op->annotations.count(attr::kSafeValueMap)) {
340-
auto map = op->annotations.Get(attr::kSafeValueMap)
341-
->as<Map<Var, PrimExpr>>()
342-
.value();
343-
for (const auto &[var, safe_value] : map) {
344-
ICHECK(buffer_data_to_buffer_.count(var))
345-
<< "buffer " << var << " is not found in the block "
346-
<< buffer_data_to_buffer_;
347-
auto buffer = buffer_data_to_buffer_[var];
348-
annotated_safe_value_map_.Set(buffer, safe_value);
349-
}
350-
}
351-
return IRMutatorWithAnalyzer::VisitStmt_(op);
352-
}
353-
354-
static bool HasInnerLoop(const Stmt &stmt) {
355-
LeafForFinder finder;
356-
finder(stmt);
357-
return !finder.leaf_for_nodes.empty();
358-
}
359-
360288
Map<Var, Buffer> buffer_data_to_buffer_;
361289
Map<Buffer, PrimExpr> annotated_safe_value_map_;
362290
};
@@ -371,7 +299,7 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
371299
if (disable_safe_memory_legalize) {
372300
return f;
373301
}
374-
return SafeMemoryLegalizer::Substitute(std::move(f));
302+
return SafeMemorysRewriter::Substitute(std::move(f));
375303
};
376304
// Create and return a PrimFunc pass with the transformation function
377305
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});

src/transform/simplify.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
465465
return std::move(store);
466466
}
467467

468+
Stmt VisitStmt_(const AttrStmtNode *op) override {
469+
if (op->attr_key == "tl.assume") {
470+
PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
471+
auto n = CopyOnWrite(op);
472+
n->node = std::move(condition);
473+
return Parent::VisitStmt_(n.get());
474+
}
475+
return Parent::VisitStmt_(op);
476+
}
477+
468478
private:
469479
bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
470480
if (lhs.size() != rhs.size()) {

0 commit comments

Comments
 (0)