@@ -273,38 +273,48 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
273273 }
274274
275275 for (size_t i = 0 ; i < prev.buffer_indices .size (); i++) {
276+ auto prev_dtype = prev.dtype ;
277+ auto curr_dtype = curr.dtype ;
278+
276279 const auto &prev_indice = prev.buffer_indices [i];
277280 const auto &curr_indice = curr.buffer_indices [i];
281+
278282 if (!ExprDeepEqual ()(prev_indice, curr_indice)) {
283+ auto prev_indice_bytes =
284+ analyzer_.Simplify (prev_indice * prev_dtype.bytes ());
285+ auto curr_indice_bytes =
286+ analyzer_.Simplify (curr_indice * curr_dtype.bytes ());
287+
279288 has_same_index = false ;
280289
281290 // If both are const, we can check if they are disjoint
282291 // by checking if the bounds are disjoint
283292 // [1024, 2048], [2048, 3072] are disjoint
284293 // [1024, 2048], [1024, 1024] are not disjoint
285- auto prev_bound = analyzer_.const_int_bound (prev_indice );
286- auto curr_bound = analyzer_.const_int_bound (curr_indice );
294+ auto prev_bound = analyzer_.const_int_bound (prev_indice_bytes );
295+ auto curr_bound = analyzer_.const_int_bound (curr_indice_bytes );
287296 if (prev_bound.defined () && curr_bound.defined ()) {
288- if (prev_bound->min_value > curr_bound->max_value ||
289- curr_bound->min_value > prev_bound->max_value ) {
297+ if (( prev_bound->min_value ) > ( curr_bound->max_value ) ||
298+ ( curr_bound->min_value ) > ( prev_bound->max_value ) ) {
290299 range_is_overlap = false ;
291300 break ;
292301 }
293302 }
294303
295304 // if we can prove prev_indice < curr_indice or prev_indice >
296305 // curr_indice, then they are not overlap
297- auto prev_dtype = prev_indice.dtype ();
298- auto curr_dtype = curr_indice.dtype ();
299- if (prev_dtype .lanes () != curr_dtype .lanes ()) {
306+ auto prev_indices_dtype = prev_indice.dtype ();
307+ auto curr_indices_dtype = curr_indice.dtype ();
308+ if (prev_indices_dtype .lanes () != curr_indices_dtype .lanes ()) {
300309 // can not support different lanes binary op like <, >, <=, >=
301310 // skip otherwise it will lead to error
302311 continue ;
303312 }
313+
304314 bool provably_disjoint =
305- analyzer_.CanProve (prev_indice < curr_indice ,
315+ analyzer_.CanProve (prev_indice_bytes < curr_indice_bytes ,
306316 arith::ProofStrength::kSymbolicBound ) ||
307- analyzer_.CanProve (prev_indice > curr_indice ,
317+ analyzer_.CanProve (prev_indice_bytes > curr_indice_bytes ,
308318 arith::ProofStrength::kSymbolicBound );
309319
310320 if (provably_disjoint) {
0 commit comments