@@ -246,9 +246,9 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
246246 const auto &curr_indice = curr.buffer_indices [i];
247247
248248 if (!ExprDeepEqual ()(prev_indice, curr_indice)) {
249- auto prev_indice_bytes =
249+ PrimExpr prev_indice_bytes =
250250 analyzer_.Simplify (prev_indice * prev_dtype.bytes ());
251- auto curr_indice_bytes =
251+ PrimExpr curr_indice_bytes =
252252 analyzer_.Simplify (curr_indice * curr_dtype.bytes ());
253253
254254 has_same_index = false ;
@@ -277,6 +277,32 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
277277 continue ;
278278 }
279279
// "Provably disjoint" means the two accesses cannot overlap. For example,
// with tx in [0, 128] we can prove that tx - 128 < tx + 128.
// The thread variable must be split first, though: tx < tx + 32 is
// trivially provable, yet the accesses are not disjoint, because
// [0, 128] overlaps [32, 160] once prev and curr take different tx values.
// So tx is replaced by two independent variables: tx1 (prev) and tx2 (curr).
286+
287+ struct ThreadVarInfo {
288+ const char * name_prev;
289+ const char * name_curr;
290+ IterVar iv;
291+ } thread_vars[] = {
292+ {" tx1" , " tx2" , tx_},
293+ {" ty1" , " ty2" , ty_},
294+ {" tz1" , " tz2" , tz_},
295+ };
296+
297+ for (const auto & info : thread_vars) {
298+ Var prev_var (info.name_prev , prev_indice.dtype ());
299+ Var curr_var (info.name_curr , curr_indice.dtype ());
300+ analyzer_.Bind (prev_var, info.iv ->dom );
301+ analyzer_.Bind (curr_var, info.iv ->dom );
302+ prev_indice_bytes = Substitute (prev_indice_bytes, {{info.iv ->var , prev_var}});
303+ curr_indice_bytes = Substitute (curr_indice_bytes, {{info.iv ->var , curr_var}});
304+ }
305+
280306 bool provably_disjoint =
281307 analyzer_.CanProve (prev_indice_bytes < curr_indice_bytes,
282308 arith::ProofStrength::kSymbolicBound ) ||
@@ -313,6 +339,16 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
313339 }
314340
315341 void VisitStmt_ (const AttrStmtNode *op) final {
342+ if (op->attr_key == tvm::tir::attr::thread_extent) {
343+ IterVar iv = Downcast<IterVar>(op->node );
344+ if (iv->thread_tag == " threadIdx.x" ) {
345+ tx_ = iv;
346+ } else if (iv->thread_tag == " threadIdx.y" ) {
347+ ty_ = iv;
348+ } else if (iv->thread_tag == " threadIdx.z" ) {
349+ tz_ = iv;
350+ }
351+ }
316352 TileLangStorageAccessVisitor::VisitStmt_ (op);
317353 }
318354
@@ -323,6 +359,15 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
323359 }
324360
325361private:
362+
363+
364+ // Member variables
365+ IterVar tx_ =
366+ IterVar (Range::FromMinExtent(0 , 1 ), Var(" tx" ), IterVarType::kDataPar );
367+ IterVar ty_ =
368+ IterVar (Range::FromMinExtent(0 , 1 ), Var(" ty" ), IterVarType::kDataPar );
369+ IterVar tz_ =
370+ IterVar (Range::FromMinExtent(0 , 1 ), Var(" tz" ), IterVarType::kDataPar );
326371 // synchronization scope
327372 StorageScope sync_scope_;
328373};
0 commit comments