diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 831dbcdd4aa8..79e14feee47d 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -269,7 +269,7 @@ Buffer AllocBuffer(Array shape, DataType dtype, Optional data, Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); IRBuilder builder = IRBuilder::Current(); - if (Optional frame = builder->GetLastFrame()) { + if (Optional frame = builder->FindFrame()) { frame.value()->alloc_buffers.push_back(buffer); } else if (Optional frame = builder->GetLastFrame()) { frame.value()->root_alloc_buffers.push_back(buffer); diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index d5ee2e07729a..68e9adeff267 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -588,5 +588,28 @@ def func3(): T.evaluate(0) +def test_alloc_inside_block(): + @T.prim_func(private=True) + def func() -> None: + with T.block(): + A = T.alloc_buffer([10], "float32") + for i in T.serial(0, 10): + B = T.alloc_buffer([10], "float32") + for j in T.serial(0, 10): + B[j] = T.float32(j) + A[i] += B[j] + + @T.prim_func(private=True) + def expected() -> None: + with T.block(): + A = T.alloc_buffer([10], "float32") + B = T.alloc_buffer([10], "float32") + for i, j in T.grid(10, 10): + B[j] = T.float32(j) + A[i] += B[j] + + tvm.ir.assert_structural_equal(func, expected) + + if __name__ == "__main__": tvm.testing.main()