@@ -284,11 +284,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
284284 const VarNode *buffer_var = buffer->data .as <VarNode>();
285285 buffer_data_to_buffer_.Set (GetRef<Var>(buffer_var), buffer);
286286 StorageScope scope = GetScope (GetRef<Var>(buffer_var));
287- Array<PrimExpr> buffer_indices ;
287+ Array<Range> buffer_ranges ;
288288 // from indices to buffer indices
289289 ICHECK (buffer->shape .size () == load->indices .size ());
290290 for (size_t i = 0 ; i < buffer->shape .size (); ++i) {
291- buffer_indices.push_back (Ramp (load->indices [i], 1 , buffer->shape [i]));
291+ buffer_ranges.push_back (
292+ Range::FromMinExtent (load->indices [i], buffer->shape [i]));
292293 }
293294 if (Enabled (buffer_var, scope)) {
294295 ICHECK (allow_append_);
@@ -297,7 +298,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
297298 e.thread_range = this ->ComputeThreadRange (e.threads );
298299 e.dtype = dtype;
299300 e.buffer = Downcast<Var>(buffer->data );
300- e.buffer_indices = buffer_indices ;
301+ e.buffer_ranges = buffer_ranges ;
301302 for (const auto &index : load->indices ) {
302303 e.touched .push_back (arith::IntSet::Vector (index));
303304 }
@@ -321,14 +322,45 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
321322 // The buffer scope.
322323 if (Enabled (buffer_var, scope)) {
323324 ICHECK (allow_append_);
324- Array<PrimExpr> buffer_indices;
325- buffer_indices = {Ramp (offset, 1 , extent)};
325+ Array<Range> buffer_ranges;
326+ if (buffer_data_to_buffer_.find (GetRef<Var>(buffer_var)) ==
327+ buffer_data_to_buffer_.end ()) {
328+ // cannot find buffer map, use the default buffer
329+ buffer_ranges = {Range::FromMinExtent (offset, extent)};
330+ } else {
331+ Buffer buffer = buffer_data_to_buffer_.at (GetRef<Var>(buffer_var));
332+ auto buffer_shape = buffer->shape ;
333+ // convert 1d offset to multi-dimensional index
334+ auto linear_to_indices = [this ](PrimExpr offset,
335+ const Array<PrimExpr> &shape) {
336+ Array<PrimExpr> indices;
337+ PrimExpr remaining = offset;
338+ for (size_t i = 0 ; i < shape.size (); ++i) {
339+ PrimExpr stride = make_const (DataType::Int (32 ), 1 );
340+ for (size_t j = i + 1 ; j < shape.size (); ++j) {
341+ stride = stride * shape[j];
342+ }
343+ PrimExpr idx = FloorDiv (remaining, stride);
344+ remaining = FloorMod (remaining, stride);
345+ indices.push_back (analyzer_.Simplify (idx));
346+ }
347+ return indices;
348+ };
349+ Array<PrimExpr> start_indices = linear_to_indices (offset, buffer_shape);
350+ Array<PrimExpr> end_indices =
351+ linear_to_indices (offset + extent, buffer_shape);
352+ for (size_t i = 0 ; i < buffer_shape.size (); ++i) {
353+ buffer_ranges.push_back (Range::FromMinExtent (
354+ start_indices[i],
355+ analyzer_.Simplify (end_indices[i] - start_indices[i])));
356+ }
357+ }
326358 AccessEntry e;
327359 e.threads = env_threads ();
328360 e.thread_range = this ->ComputeThreadRange (e.threads );
329361 e.dtype = dtype;
330362 e.buffer = GetRef<Var>(buffer_var);
331- e.buffer_indices = buffer_indices ;
363+ e.buffer_ranges = buffer_ranges ;
332364 e.is_pointer_access = true ;
333365 e.touched = {
334366 arith::IntSet::FromRange (Range::FromMinExtent (offset, extent))};
0 commit comments