Skip to content

Commit 722c2a8

Browse files
authored
[Bugfix] Consider buffer data type into indices provably disjoint analysis (#664)
1 parent a16f0cf commit 722c2a8

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

src/transform/thread_storage_sync.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -273,38 +273,48 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
273273
}
274274

275275
for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
276+
auto prev_dtype = prev.dtype;
277+
auto curr_dtype = curr.dtype;
278+
276279
const auto &prev_indice = prev.buffer_indices[i];
277280
const auto &curr_indice = curr.buffer_indices[i];
281+
278282
if (!ExprDeepEqual()(prev_indice, curr_indice)) {
283+
auto prev_indice_bytes =
284+
analyzer_.Simplify(prev_indice * prev_dtype.bytes());
285+
auto curr_indice_bytes =
286+
analyzer_.Simplify(curr_indice * curr_dtype.bytes());
287+
279288
has_same_index = false;
280289

281290
// If both are const, we can check if they are disjoint
282291
// by checking if the bounds are disjoint
283292
// [1024, 2048], [2048, 3072] are disjoint
284293
// [1024, 2048], [1024, 1024] are not disjoint
285-
auto prev_bound = analyzer_.const_int_bound(prev_indice);
286-
auto curr_bound = analyzer_.const_int_bound(curr_indice);
294+
auto prev_bound = analyzer_.const_int_bound(prev_indice_bytes);
295+
auto curr_bound = analyzer_.const_int_bound(curr_indice_bytes);
287296
if (prev_bound.defined() && curr_bound.defined()) {
288-
if (prev_bound->min_value > curr_bound->max_value ||
289-
curr_bound->min_value > prev_bound->max_value) {
297+
if ((prev_bound->min_value) > (curr_bound->max_value) ||
298+
(curr_bound->min_value) > (prev_bound->max_value)) {
290299
range_is_overlap = false;
291300
break;
292301
}
293302
}
294303

295304
// if we can prove prev_indice < curr_indice or prev_indice >
296305
// curr_indice, then they are not overlap
297-
auto prev_dtype = prev_indice.dtype();
298-
auto curr_dtype = curr_indice.dtype();
299-
if (prev_dtype.lanes() != curr_dtype.lanes()) {
306+
auto prev_indices_dtype = prev_indice.dtype();
307+
auto curr_indices_dtype = curr_indice.dtype();
308+
if (prev_indices_dtype.lanes() != curr_indices_dtype.lanes()) {
300309
// can not support different lanes binary op like <, >, <=, >=
301310
// skip otherwise it will lead to error
302311
continue;
303312
}
313+
304314
bool provably_disjoint =
305-
analyzer_.CanProve(prev_indice < curr_indice,
315+
analyzer_.CanProve(prev_indice_bytes < curr_indice_bytes,
306316
arith::ProofStrength::kSymbolicBound) ||
307-
analyzer_.CanProve(prev_indice > curr_indice,
317+
analyzer_.CanProve(prev_indice_bytes > curr_indice_bytes,
308318
arith::ProofStrength::kSymbolicBound);
309319

310320
if (provably_disjoint) {

0 commit comments

Comments
 (0)