Skip to content

Commit acb8cf2

Browse files
committed
Refactor linear index conversion in TileLangStorageAccessVisitor to utilize the analyzer for simplification. Update buffer index calculations to ensure consistent simplification of range expressions.
1 parent 5559ba8 commit acb8cf2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/transform/storage_access.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
330330
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
331331
auto buffer_shape = buffer->shape;
332332
// convert 1d offset to multi-dimensional index
333-
auto linear_to_indices = [](PrimExpr offset,
334-
const Array<PrimExpr> &shape) {
333+
auto linear_to_indices = [this](PrimExpr offset,
334+
const Array<PrimExpr> &shape) {
335335
Array<PrimExpr> indices;
336336
PrimExpr remaining = offset;
337337
for (size_t i = 0; i < shape.size(); ++i) {
@@ -341,7 +341,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
341341
}
342342
PrimExpr idx = FloorDiv(remaining, stride);
343343
remaining = FloorMod(remaining, stride);
344-
indices.push_back(idx);
344+
indices.push_back(analyzer_.Simplify(idx));
345345
}
346346
return indices;
347347
};
@@ -350,7 +350,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
350350
linear_to_indices(offset + extent, buffer_shape);
351351
for (size_t i = 0; i < buffer_shape.size(); ++i) {
352352
buffer_indices.push_back(
353-
Ramp(start_indices[i], 1, end_indices[i] - start_indices[i]));
353+
Ramp(start_indices[i], 1,
354+
analyzer_.Simplify(end_indices[i] - start_indices[i])));
354355
}
355356
}
356357

0 commit comments

Comments
 (0)