@@ -83,8 +83,9 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
8383 // simulation based approach to find dependencies
8484 for (size_t i = 0 ; i < seq.size (); ++i) {
8585 const StmtEntry &s = seq[i];
86- const bool pre_marked_sync = (syncs_inserted_.count (s.stmt ) != 0 );
87- bool sync_before_stmt = pre_marked_sync;
86+ // check if sync before statement is needed.
87+ bool sync_before_stmt = (syncs_inserted_.count (s.stmt ) != 0 );
88+ // Apply the syncs added already.
8889
8990 if (sync_before_stmt) {
9091 reads.clear ();
@@ -108,12 +109,12 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
108109 writes.clear ();
109110 }
110111 }
111-
112+ // If sync is inserted. remove the irrelevant things.
112113 if (sync_before_stmt) {
113114 reads.clear ();
114115 writes.clear ();
115116 }
116-
117+ // Add the read/write of current statement
117118 for (const AccessEntry &acc : s.access ) {
118119 if (acc.type == kRead ) {
119120 reads.push_back (acc);
@@ -125,7 +126,7 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
125126 }
126127 }
127128
128- if (sync_before_stmt && !pre_marked_sync ) {
129+ if (sync_before_stmt) {
129130 insert_syncs (s.stmt );
130131 }
131132 }
@@ -244,10 +245,14 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
244245 return true ;
245246 }
246247 if (prev.is_pointer_access || curr.is_pointer_access ) {
247- // If either access is a pointer access, conservatively assume a
248- // conflict. For example, address_of(A[0, 0]) may refer to an unknown
249- // memory region, so we cannot safely determine if it overlaps with
250- // previous accesses.
248+ // For accesses created via tvm_access_ptr we may still be able to prove
249+ // disjointness using their byte ranges. If both sides expose a touched
250+ // interval and we can show they don't overlap, skip the conflict.
251+ if (prev.is_pointer_access && curr.is_pointer_access &&
252+ PointerAccessIsDisjoint (prev, curr)) {
253+ return false ;
254+ }
255+ // Otherwise fall back to the conservative answer: treat them as overlapping.
251256 return true ;
252257 }
253258
@@ -353,6 +358,27 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
353358 return range_is_overlap;
354359 }
355360
361+ bool PointerAccessIsDisjoint (const AccessEntry &lhs,
362+ const AccessEntry &rhs) {
363+ if (lhs.touched .size () != 1 || rhs.touched .size () != 1 ) {
364+ return false ;
365+ }
366+ PrimExpr lhs_min = analyzer_.Simplify (lhs.touched [0 ].min ());
367+ PrimExpr lhs_max = analyzer_.Simplify (lhs.touched [0 ].max ());
368+ PrimExpr rhs_min = analyzer_.Simplify (rhs.touched [0 ].min ());
369+ PrimExpr rhs_max = analyzer_.Simplify (rhs.touched [0 ].max ());
370+
371+ if (analyzer_.CanProve (lhs_max < rhs_min,
372+ arith::ProofStrength::kSymbolicBound )) {
373+ return true ;
374+ }
375+ if (analyzer_.CanProve (rhs_max < lhs_min,
376+ arith::ProofStrength::kSymbolicBound )) {
377+ return true ;
378+ }
379+ return false ;
380+ }
381+
356382 void VisitStmt_ (const AttrStmtNode *op) final {
357383 if (op->attr_key == tvm::tir::attr::thread_extent) {
358384 IterVar iv = Downcast<IterVar>(op->node );
0 commit comments