Skip to content

Commit c3f1125

Browse files
committed
Refactor TileLangStorageAccessVisitor to replace buffer indices with buffer ranges for improved pointer access handling. Update AccessEntry structure to include buffer_ranges and adjust thread storage synchronization logic to account for pointer access conflicts.
1 parent c3c57fd commit c3f1125

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

src/transform/storage_access.cc

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
284284
const VarNode *buffer_var = buffer->data.as<VarNode>();
285285
buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
286286
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
287-
Array<PrimExpr> buffer_indices;
287+
Array<Range> buffer_ranges;
288288
// from load indices to per-dimension buffer ranges
289289
ICHECK(buffer->shape.size() == load->indices.size());
290290
for (size_t i = 0; i < buffer->shape.size(); ++i) {
291-
buffer_indices.push_back(Ramp(load->indices[i], 1, buffer->shape[i]));
291+
buffer_ranges.push_back(
292+
Range::FromMinExtent(load->indices[i], buffer->shape[i]));
292293
}
293294
if (Enabled(buffer_var, scope)) {
294295
ICHECK(allow_append_);
@@ -297,7 +298,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
297298
e.thread_range = this->ComputeThreadRange(e.threads);
298299
e.dtype = dtype;
299300
e.buffer = Downcast<Var>(buffer->data);
300-
e.buffer_indices = buffer_indices;
301+
e.buffer_ranges = buffer_ranges;
301302
for (const auto &index : load->indices) {
302303
e.touched.push_back(arith::IntSet::Vector(index));
303304
}
@@ -321,14 +322,45 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
321322
// The buffer scope.
322323
if (Enabled(buffer_var, scope)) {
323324
ICHECK(allow_append_);
324-
Array<PrimExpr> buffer_indices;
325-
buffer_indices = {Ramp(offset, 1, extent)};
325+
Array<Range> buffer_ranges;
326+
if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
327+
buffer_data_to_buffer_.end()) {
328+
// no Buffer is registered for this var; fall back to a single 1-D range
// built from offset and extent
329+
buffer_ranges = {Range::FromMinExtent(offset, extent)};
330+
} else {
331+
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
332+
auto buffer_shape = buffer->shape;
333+
// convert 1d offset to multi-dimensional index
334+
auto linear_to_indices = [this](PrimExpr offset,
335+
const Array<PrimExpr> &shape) {
336+
Array<PrimExpr> indices;
337+
PrimExpr remaining = offset;
338+
for (size_t i = 0; i < shape.size(); ++i) {
339+
PrimExpr stride = make_const(DataType::Int(32), 1);
340+
for (size_t j = i + 1; j < shape.size(); ++j) {
341+
stride = stride * shape[j];
342+
}
343+
PrimExpr idx = FloorDiv(remaining, stride);
344+
remaining = FloorMod(remaining, stride);
345+
indices.push_back(analyzer_.Simplify(idx));
346+
}
347+
return indices;
348+
};
349+
Array<PrimExpr> start_indices = linear_to_indices(offset, buffer_shape);
350+
Array<PrimExpr> end_indices =
351+
linear_to_indices(offset + extent, buffer_shape);
352+
for (size_t i = 0; i < buffer_shape.size(); ++i) {
353+
buffer_ranges.push_back(Range::FromMinExtent(
354+
start_indices[i],
355+
analyzer_.Simplify(end_indices[i] - start_indices[i])));
356+
}
357+
}
326358
AccessEntry e;
327359
e.threads = env_threads();
328360
e.thread_range = this->ComputeThreadRange(e.threads);
329361
e.dtype = dtype;
330362
e.buffer = GetRef<Var>(buffer_var);
331-
e.buffer_indices = buffer_indices;
363+
e.buffer_ranges = buffer_ranges;
332364
e.is_pointer_access = true;
333365
e.touched = {
334366
arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};

src/transform/storage_access.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer {
6565
Map<Var, Range> thread_range;
6666
/*! \brief The indices used to access the buffer, if any */
6767
Array<PrimExpr> buffer_indices;
68+
/*! \brief The buffer ranges for pointer access */
69+
Array<Range> buffer_ranges;
6870
Var buffer = NullValue<Var>();
6971
/*! \brief The access data type */
7072
DataType dtype;

src/transform/thread_storage_sync.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
239239
return true;
240240
}
241241

242+
if (prev.is_pointer_access || curr.is_pointer_access) {
243+
// If either access is a pointer access, conservatively assume a
244+
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
245+
// memory region, so we cannot safely determine if it overlaps with
246+
// previous accesses.
247+
return true;
248+
}
249+
242250
for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
243251
auto prev_dtype = prev.dtype;
244252
auto curr_dtype = curr.dtype;
@@ -316,12 +324,6 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
316324
range_is_overlap = false;
317325
break;
318326
}
319-
} else if (prev.is_pointer_access || curr.is_pointer_access) {
320-
// If either access is a pointer access, conservatively assume a
321-
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
322-
// memory region, so we cannot safely determine if it overlaps with
323-
// previous accesses.
324-
return true;
325327
}
326328

327329
if (!(has_same_index)) {

0 commit comments

Comments
 (0)