diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c44083108d45..5ff1c65ca9e9 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,13 +36,10 @@ namespace tir { class ScriptCompleter : public StmtMutator { public: explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} - /*! \brief Whether the stmt contains at least one block. */ - bool contains_block = false; private: Map* buffer_var_map_; - Stmt VisitStmt_(const BlockRealizeNode* op) override { - contains_block = true; + Stmt VisitStmt_(const BlockRealizeNode* op) final { for (const PrimExpr& value : op->iter_values) { CHECK(value.dtype().is_int()) << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); @@ -50,7 +47,7 @@ class ScriptCompleter : public StmtMutator { return StmtMutator::VisitStmt_(op); } - Stmt VisitStmt_(const BlockNode* op) override { + Stmt VisitStmt_(const BlockNode* op) final { // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); @@ -59,7 +56,12 @@ class ScriptCompleter : public StmtMutator { const Buffer& target_buffer = match_buffer->buffer; buffer_var_map_->Set(target_buffer->data, target_buffer); } + + bool is_root_block = this->is_root_block_; + this->is_root_block_ = false; Block block = Downcast(StmtMutator::VisitStmt_(op)); + this->is_root_block_ = is_root_block; + // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); @@ -85,8 +87,10 @@ class ScriptCompleter : public StmtMutator { << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " "direct access by buffer data. Please annotation the access region manually"; auto n = CopyOnWrite(block.operator->()); - if (mask & 1) n->reads = reads; - if (mask & 2) n->writes = writes; + if (!is_root_block) { + if (mask & 1) n->reads = reads; + if (mask & 2) n->writes = writes; + } n->annotations = op->annotations; n->annotations.erase(attr::script_parsing_detect_access); return Block(n); @@ -94,6 +98,8 @@ class ScriptCompleter : public StmtMutator { return std::move(block); } } + + bool is_root_block_ = true; }; PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index fe7482665def..932a5d156c28 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -235,8 +235,6 @@ def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "f @T.prim_func def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "float32"), C: T.Buffer((4, 4, 32), "float32"), T_layer_norm: T.Buffer((1, 4, 4, 32), "float32")): with T.block("root"): - T.reads(A[0, 0:4, 0:4, 0:32], B[0:4, 0:4, 0:32], C[0:4, 0:4, 0:32]) - T.writes(T_layer_norm[0, 0:4, 0:4, 0:32]) A_red_temp_v0 = T.alloc_buffer((1,)) A_red_temp_v1 = T.alloc_buffer((1,)) for ax0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): diff --git a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py index e45c879c42f1..2657863f7619 100644 --- a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py +++ b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py @@ -315,8 +315,6 @@ def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (1024, 1024)) B = T.match_buffer(b, (1024, 1024)) with T.block("root"): - T.reads(A[0:1024, 0:1024]) - T.writes(B[0:1024, 0:1024]) T.block_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): @@ -583,8 +581,6 @@ class TransformedWmmaToGlobal: @T.prim_func def main(C: T.Buffer((1024, 1024), "float32")): with T.block("root"): - T.reads() - T.writes(C[0:1024, 0:1024]) T.block_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): @@ -785,8 +781,6 @@ def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024, 1024), "float32")) s1 = T.int32() # body with T.block("root"): - T.reads(A[0:1024]) - T.writes(C[0:1024, 0:1024]) T.block_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): @@ -1009,8 +1003,6 @@ class TransformedMmaToGlobal: @T.prim_func def main(C: T.Buffer((1024, 1024), "float32")): with T.block("root"): - T.reads() - T.writes(C[0:1024, 0:1024]) T.block_attr({"warp_execution": T.bool(True)}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 2f81b0302626..6d435a906e37 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -153,6 +153,10 @@ def test_complete_matmul_original(): def _check_elementwise(func): A, B, C = [func.buffer_map[x] for x in func.params] + root_block = func.body.block + assert len(root_block.reads) == 0 + assert len(root_block.writes) == 0 + block1 = func.body.block.body[0].body.body.block assert isinstance(block1, tvm.tir.Block) vi, vj = [x.var for x in block1.iter_vars]