# Patch summary (this text sits ahead of the first "diff --git" header, outside every hunk).
#
# src/tir/analysis/block_access_region_detector.cc
#   BlockReadWriteDetector::VisitExpr_(const BufferLoadNode*) and
#   BlockReadWriteDetector::VisitStmt_(const BufferStoreNode*) receive the same change:
#   - the loop variable over op->indices becomes a by-value PrimExpr instead of a
#     const reference, because the new loop body reassigns it;
#   - Substitute(index, let_bindings_) is re-applied until its result is same_as()
#     its input, and that final expression is what gets passed to arith::EvalSet.
#   NOTE(review): termination presumably relies on let_bindings_ containing no cyclic
#   bindings, and on Substitute handing back its input object when nothing is
#   replaced -- neither is visible in these hunks; confirm against the full file.
#
# tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
#   Adds test_buffer_access_with_nested_let_binding. The block declares
#   T.reads(A[vi, vs], B[vi, vs]) and T.writes(C[vi, vs]) while its body indexes the
#   buffers through chained bindings (vi1 = vi, vi2 = vi1, vs1 = vs, vs2 = vs1,
#   vs3 = vs2). The test asserts that tir.analysis.get_block_access_region returns
#   regions structurally equal to the declared reads and writes.
#
# NOTE(review): "std::vector relaxed_region;" and the hunk-header context that ends in
# "UpdateOpaque(GetRef" look as if their <...> template arguments were lost during
# extraction -- restore them from the upstream file before trying to apply this patch.
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index ce13ac56c81d..8b3598e3563d 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -153,8 +153,12 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef
 
 void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
   std::vector relaxed_region;
-  for (const PrimExpr& index : op->indices) {
+  for (PrimExpr index : op->indices) {
     PrimExpr remapped_index = Substitute(index, let_bindings_);
+    while (!remapped_index.same_as(index)) {
+      index = remapped_index;
+      remapped_index = Substitute(index, let_bindings_);
+    }
     relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
   }
   Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
@@ -236,8 +240,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
 
 void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
   std::vector relaxed_region;
-  for (const PrimExpr& index : op->indices) {
+  for (PrimExpr index : op->indices) {
     PrimExpr remapped_index = Substitute(index, let_bindings_);
+    while (!remapped_index.same_as(index)) {
+      index = remapped_index;
+      remapped_index = Substitute(index, let_bindings_);
+    }
     relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
   }
   Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
index a65277df612d..1fa013399e12 100644
--- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
@@ -385,5 +385,31 @@ def func(
     tvm.ir.assert_structural_equal(block.writes, ret[1])
 
 
+def test_buffer_access_with_nested_let_binding():
+    @T.prim_func
+    def func(
+        A: T.Buffer((16, 16), "float32"),
+        B: T.Buffer((16, 16), "float32"),
+        C: T.Buffer((16, 16), "float32"),
+    ):
+        for i, s in T.grid(16, 16):
+            with T.block("copy"):
+                vi, vs = T.axis.remap("SS", [i, s])
+                T.reads(A[vi, vs], B[vi, vs])
+                T.writes(C[vi, vs])
+                vi1: T.int32 = vi
+                vi2: T.int32 = vi1
+                vs1: T.int32 = vs
+                vs2: T.int32 = vs1
+                vs3: T.int32 = vs2
+                C[vi, vs1] = A[vi1, vs2] + B[vi2, vs3]
+
+    block = func.body.block.body.body.body.block
+    buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
+    ret = tir.analysis.get_block_access_region(block, buffer_var_map)
+    tvm.ir.assert_structural_equal(block.reads, ret[0])
+    tvm.ir.assert_structural_equal(block.writes, ret[1])
+
+
 if __name__ == "__main__":
     tvm.testing.main()