Skip to content

Commit f35958f

Browse files
committed
Update pyproject.toml to add Cython as a build dependency. Enhance thread storage synchronization in thread_storage_sync.cc by introducing new thread variable handling and improving index disjointness checks.
1 parent 25cfe30 commit f35958f

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ requires = [
1010
"auditwheel",
1111
"patchelf",
1212
"ninja",
13+
"Cython",
1314
]
1415
build-backend = "setuptools.build_meta"
1516

src/transform/thread_storage_sync.cc

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
246246
const auto &curr_indice = curr.buffer_indices[i];
247247

248248
if (!ExprDeepEqual()(prev_indice, curr_indice)) {
249-
auto prev_indice_bytes =
249+
PrimExpr prev_indice_bytes =
250250
analyzer_.Simplify(prev_indice * prev_dtype.bytes());
251-
auto curr_indice_bytes =
251+
PrimExpr curr_indice_bytes =
252252
analyzer_.Simplify(curr_indice * curr_dtype.bytes());
253253

254254
has_same_index = false;
@@ -277,6 +277,32 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
277277
continue;
278278
}
279279

280+
// provably disjoint means no overlap, for example:
281+
// we can prove that tx - 128 < tx + 128, tx in [0, 128]
282+
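// However, we should split tx into separate variables for the two accesses, because
283+
// tx < tx + 32 is provable for a single shared tx in [0, 128], yet the accesses are not disjoint:
284+
// the range [0, 128] of tx is not disjoint with the range [32, 160] of tx + 32,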
285+
// so we should split tx into tx1 and tx2.
286+
287+
struct ThreadVarInfo {
288+
const char* name_prev;
289+
const char* name_curr;
290+
IterVar iv;
291+
} thread_vars[] = {
292+
{"tx1", "tx2", tx_},
293+
{"ty1", "ty2", ty_},
294+
{"tz1", "tz2", tz_},
295+
};
296+
297+
for (const auto& info : thread_vars) {
298+
Var prev_var(info.name_prev, prev_indice.dtype());
299+
Var curr_var(info.name_curr, curr_indice.dtype());
300+
analyzer_.Bind(prev_var, info.iv->dom);
301+
analyzer_.Bind(curr_var, info.iv->dom);
302+
prev_indice_bytes = Substitute(prev_indice_bytes, {{info.iv->var, prev_var}});
303+
curr_indice_bytes = Substitute(curr_indice_bytes, {{info.iv->var, curr_var}});
304+
}
305+
280306
bool provably_disjoint =
281307
analyzer_.CanProve(prev_indice_bytes < curr_indice_bytes,
282308
arith::ProofStrength::kSymbolicBound) ||
@@ -313,6 +339,16 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
313339
}
314340

315341
void VisitStmt_(const AttrStmtNode *op) final {
342+
if (op->attr_key == tvm::tir::attr::thread_extent) {
343+
IterVar iv = Downcast<IterVar>(op->node);
344+
if (iv->thread_tag == "threadIdx.x") {
345+
tx_ = iv;
346+
} else if (iv->thread_tag == "threadIdx.y") {
347+
ty_ = iv;
348+
} else if (iv->thread_tag == "threadIdx.z") {
349+
tz_ = iv;
350+
}
351+
}
316352
TileLangStorageAccessVisitor::VisitStmt_(op);
317353
}
318354

@@ -323,6 +359,15 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
323359
}
324360

325361
private:
362+
363+
364+
// Member variables
365+
IterVar tx_ =
366+
IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar);
367+
IterVar ty_ =
368+
IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar);
369+
IterVar tz_ =
370+
IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar);
326371
// synchronization scope
327372
StorageScope sync_scope_;
328373
};

0 commit comments

Comments
 (0)