Skip to content

Commit b858874

Browse files
committed
fix
1 parent d95e708 commit b858874

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

src/transform/thread_storage_sync.cc

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
8383
// simulation based approach to find dependencies
8484
for (size_t i = 0; i < seq.size(); ++i) {
8585
const StmtEntry &s = seq[i];
86-
const bool pre_marked_sync = (syncs_inserted_.count(s.stmt) != 0);
87-
bool sync_before_stmt = pre_marked_sync;
86+
// check if sync before statement is needed.
87+
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
88+
// Apply the syncs added already.
8889

8990
if (sync_before_stmt) {
9091
reads.clear();
@@ -108,12 +109,12 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
108109
writes.clear();
109110
}
110111
}
111-
112+
// If a sync is inserted, remove the irrelevant things.
112113
if (sync_before_stmt) {
113114
reads.clear();
114115
writes.clear();
115116
}
116-
117+
// Add the read/write of current statement
117118
for (const AccessEntry &acc : s.access) {
118119
if (acc.type == kRead) {
119120
reads.push_back(acc);
@@ -125,7 +126,7 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
125126
}
126127
}
127128

128-
if (sync_before_stmt && !pre_marked_sync) {
129+
if (sync_before_stmt) {
129130
insert_syncs(s.stmt);
130131
}
131132
}
@@ -244,10 +245,14 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
244245
return true;
245246
}
246247
if (prev.is_pointer_access || curr.is_pointer_access) {
247-
// If either access is a pointer access, conservatively assume a
248-
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
249-
// memory region, so we cannot safely determine if it overlaps with
250-
// previous accesses.
248+
// For accesses created via tvm_access_ptr we may still be able to prove
249+
// disjointness using their byte ranges. If both sides expose a touched
250+
// interval and we can show they don't overlap, skip the conflict.
251+
if (prev.is_pointer_access && curr.is_pointer_access &&
252+
PointerAccessIsDisjoint(prev, curr)) {
253+
return false;
254+
}
255+
// Otherwise fall back to the conservative answer: treat them as overlapping.
251256
return true;
252257
}
253258

@@ -353,6 +358,27 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
353358
return range_is_overlap;
354359
}
355360

361+
// Tries to prove that two accesses touch non-overlapping intervals.
//
// Returns true only when each access has exactly one entry in `touched` and
// the analyzer can prove that one interval ends strictly before the other
// begins (lhs.max < rhs.min, or rhs.max < lhs.min). Returns false in every
// other case, i.e. "could not prove disjoint" rather than "proved
// overlapping", so callers should treat false as the conservative answer.
//
// NOTE(review): the call site describes these intervals as byte ranges of
// tvm_access_ptr accesses; the unit of `touched` is not visible here --
// confirm before relying on it.
bool PointerAccessIsDisjoint(const AccessEntry &lhs,
362+
const AccessEntry &rhs) {
363+
// Only the single-interval case is handled; anything else is not analyzed.
if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
364+
return false;
365+
}
366+
// Simplify the interval endpoints before asking the analyzer to compare them.
PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
367+
PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
368+
PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
369+
PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
370+
371+
// lhs lies entirely before rhs.
if (analyzer_.CanProve(lhs_max < rhs_min,
372+
arith::ProofStrength::kSymbolicBound)) {
373+
return true;
374+
}
375+
// rhs lies entirely before lhs.
if (analyzer_.CanProve(rhs_max < lhs_min,
376+
arith::ProofStrength::kSymbolicBound)) {
377+
return true;
378+
}
379+
// Neither ordering could be proven: report "not provably disjoint".
return false;
380+
}
381+
356382
void VisitStmt_(const AttrStmtNode *op) final {
357383
if (op->attr_key == tvm::tir::attr::thread_extent) {
358384
IterVar iv = Downcast<IterVar>(op->node);

0 commit comments

Comments
 (0)