From 29bc3e07cbba98eb1437ca071efc319867b0c835 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 22 Mar 2022 21:56:01 -0700
Subject: [PATCH] Use local complete block and local reduction block to identify compact dataflow (#10705)

* inint

* upd

* upd

* remove redundant print

* upd

* change the reads/writes region for argmin/val

* fix wrong push
---
 src/tir/schedule/analysis.h                   |   9 +-
 src/tir/schedule/analysis/analysis.cc         | 109 +++++++++++++-----
 src/tir/schedule/primitive/for_kind.cc        |   4 +-
 .../unittest/test_tir_schedule_for_kind.py    |  89 ++++++++++++++
 4 files changed, 173 insertions(+), 38 deletions(-)

diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 9c6d1e6e96da5..d398f22ed4672 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -168,16 +168,11 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
 /*!
  * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
  * rooted at its direct children, and this property requires all the blocks of the subtree
- * that the specified sref is in to be complete block or reduction block.
+ * that the specified sref is in to be local complete block or local reduction block.
  * \param self The schedule state
  * \param subtree_root The sref of the subtree root to be checked
- * \param scope_root_sref The scope root of the block
- * \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact
- *        dataflow condition, i.e. a block in the subtree is neither complete block nor
- *        reduction block
  */
-void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
-                                 const StmtSRef& scope_root_sref);
+void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root);
 /*!
  * \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
  * not allocated under the current scope
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index c7ed67187793b..388413d73b5f5 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -143,25 +143,53 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) {
   return std::move(visitor.result);
 }
 
+/*!
+ * \brief Check whether the given sref_a is higher than or equal to sref_b.
+ */
+void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) {
+  const StmtSRefNode* p = sref_b.get();
+  for (; p != nullptr; p = p->parent) {
+    if (p == sref_a.get()) {
+      return;
+    }
+  }
+  CHECK(false) << "Expect StmtSRef " << sref_a << " to be higher than or equal to " << sref_b;
+}
+
 /*!
  * \brief Check the dominant property of a block:
- * the block is the only writer of its output, dominating the reader of its output buffers
- * \param scope The block-scope of the block to be checked
- * \param block_sref The block whose dominant property is to be checked
- * \return A boolean indicating if the block is a dominant block
+ * the block is the only writer of its output, dominating the reader of its output buffers under
+ * the given root scope.
+ * \param self The schedule state.
+ * \param scope_root_sref The StmtSRef corresponding to the root scope.
+ * \param block_sref The block whose dominant property is to be checked.
+ * \return A boolean indicating if the block is a dominant block.
  */
-bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
+bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
+                     const StmtSRef& block_sref) {
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  CheckSRefHigherOrEqual(scope_root_sref, block_sref);
+  const BlockNode* maybe_root_block = scope_root_sref->StmtAs<BlockNode>();
+  if (maybe_root_block) {
+    BlockScope scope = self->GetBlockScope(scope_root_sref);
+    buffer_writers = scope->buffer_writers;
+  } else {
+    // Collect all child blocks of root sub-tree, and merge their buffer writers.
+    Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref);
+    for (const StmtSRef& child_block_sref : child_block_srefs) {
+      BlockScope child_scope = self->GetBlockScope(child_block_sref);
+      for (const auto& it : child_scope->buffer_writers) {
+        buffer_writers.insert(it);
+      }
+    }
+  }
   // Check whether the input block is the only writer of its outputs
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>& buffer_writers =
-      scope->buffer_writers;
   for (const BufferRegion& write_region : block->writes) {
-    ICHECK(buffer_writers.count(write_region->buffer))
-        << "InternalError: buffer \"" << write_region->buffer->name
-        << "\" does not exist in the current scope, when querying block:\n"
-        << GetRef<Block>(block);
-    if (buffer_writers.at(write_region->buffer).size() != 1) {
-      return false;
+    if (buffer_writers.count(write_region->buffer)) {
+      if (buffer_writers.at(write_region->buffer).size() != 1) {
+        return false;
+      }
     }
   }
   return true;
@@ -178,7 +206,6 @@ bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
  */
 int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
                                 const StmtSRef& scope_root_sref) {
-  BlockScope scope = self->GetBlockScope(scope_root_sref);
   // Cond 1. All block vars are data parallel
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   for (const IterVar& iter_var : block->iter_vars) {
@@ -188,7 +215,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block
   }
   // Cond 2. Dominant: the block is the only writer of its output,
   // dominating the reader of its output buffers
-  if (!IsDominantBlock(scope, block_sref)) {
+  if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
     return 2;
   }
   // Cond 3. No overlap between the buffers the block reads and writes
@@ -217,6 +244,18 @@ static const char* kReductionBlockDefinition = R"(Definition of a reduction bloc
 4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers
 5) The reduction block vars are not used to index the output buffers)";
 
+static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block:
+1) All block vars are data parallel
+2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
+3) No overlap between the buffers the block reads and writes)";
+
+static const char* kLocalReductionBlockDefinition = R"(Definition of a local reduction block:
+1) The block has the `init` statement
+2) All the block bindings are quasi-affine expressions
+3) All block vars are either data parallel block vars or reduction block vars
+4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
+5) The reduction block vars are not used to index the output buffers)";
+
 bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
                      const StmtSRef& scope_root_sref) {
   return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0;
@@ -260,7 +299,6 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
  */
 int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
                                  const StmtSRef& scope_root_sref) {
-  BlockScope scope = self->GetBlockScope(scope_root_sref);
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // Cond 1. The block has the `init` statement.
   if (!block->init.defined()) {
@@ -277,7 +315,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc
   }
   // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
   // output buffers.
-  if (!IsDominantBlock(scope, block_sref)) {
+  if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
     return 4;
   }
   // Cond 5. The reduction block vars are not used to index the output buffers.
@@ -363,24 +401,35 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
                                          reduction_block_error_code);
 }
 
-void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
-                                 const StmtSRef& scope_root_sref) {
+void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) {
   class NotCompactDataFlowError : public ScheduleError {
    public:
-    explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
+    explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block,
+                                     int local_complete_block_code, int local_reduction_block_code)
         : mod_(std::move(mod)),
           subtree_root_(std::move(subtree_root)),
-          violate_block_(std::move(violate_block)) {
+          violate_block_(std::move(violate_block)),
+          local_complete_block_code_(local_complete_block_code),
+          local_reduction_block_code_(local_reduction_block_code) {
       ICHECK(subtree_root_->IsInstance<ForNode>() || subtree_root_->IsInstance<BlockNode>());
     }
     String FastErrorString() const final {
       return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
-             "because some of its child block on SRef tree is neither a complete block nor a "
-             "reduction block";
+             "because one of its child blocks on SRef tree is neither a local complete block nor a "
+             "local reduction block.";
     }
     String DetailRenderTemplate() const final {
-      return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
-             "its child block {1} on SRef tree is neither a complete block nor a reduction block";
+      std::ostringstream os;
+      os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
+            "its child block {1} on SRef tree is neither a local complete block nor a local "
+            "reduction block.\n";
+      os << "It violates condition #" << local_complete_block_code_
+         << " as a local complete block.\n";
+      os << kLocalCompleteBlockDefinition << "\n";
+      os << "It violates condition #" << local_reduction_block_code_
+         << " as a local reduction block.\n";
+      os << kLocalReductionBlockDefinition << "\n";
+      return os.str();
     }
     IRModule mod() const final { return mod_; }
     Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }
@@ -388,15 +437,19 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt
     IRModule mod_;
     Stmt subtree_root_;
     Block violate_block_;
+    int local_complete_block_code_;
+    int local_reduction_block_code_;
   };
 
   Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
   for (const StmtSRef& block_sref : child_block_srefs) {
-    if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
-        !IsReductionBlock(self, block_sref, scope_root_sref)) {
+    int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
+        local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
+    if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
       const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
       throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
-                                    GetRef<Block>(block));
+                                    GetRef<Block>(block), local_complete_block_code,
+                                    local_reduction_block_code);
     }
   }
 }
diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc
index 333d783464537..ec337224e59d7 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -157,9 +157,7 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
   * parallelized/vectorized/bound.
   */
  // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
-  StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
-                                          /*require_stage_pipeline=*/true);
-  CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);
+  CheckSubtreeCompactDataflow(self, loop_sref);
 
   // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
   // underlying block.
diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py
index caecde05b40fc..ac8288901688a 100644
--- a/tests/python/unittest/test_tir_schedule_for_kind.py
+++ b/tests/python/unittest/test_tir_schedule_for_kind.py
@@ -330,6 +330,72 @@ def decomposed_gemm_after_vectorize(
         C[vi, vj] = local[vi, vj]
 
 
+@T.prim_func
+def decomposed_gemm_parallelize_init(
+    A: T.Buffer[(16, 16), "float32"],
+    B: T.Buffer[(16, 16), "float32"],
+    C: T.Buffer[(16, 16), "float32"],
+) -> None:
+    local = T.alloc_buffer([16, 16], dtype="float32")
+    for i, j in T.grid(4, 4):
+        for ii in T.serial(4):
+            for jj in T.vectorized(4):
+                with T.block("init"):
+                    vi = T.axis.spatial(16, i * 4 + ii)
+                    vj = T.axis.spatial(16, j * 4 + jj)
+                    T.reads()
+                    T.writes(local[vi, vj])
+                    local[vi, vj] = 0
+        for k, ii, jj in T.grid(16, 4, 4):
+            with T.block("update"):
+                vi = T.axis.spatial(16, i * 4 + ii)
+                vj = T.axis.spatial(16, j * 4 + jj)
+                vk = T.axis.reduce(16, k)
+                T.reads(local[vi, vj], A[vi, vk], B[vj, vk])
+                T.writes(local[vi, vj])
+                local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
+        for ii, jj in T.grid(4, 4):
+            with T.block("C"):
+                vi = T.axis.spatial(16, i * 4 + ii)
+                vj = T.axis.spatial(16, j * 4 + jj)
+                T.reads(local[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = local[vi, vj]
+
+
+@T.prim_func
+def scatter_compute(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]):
+    for i in T.grid(8):
+        with T.block("first_half"):
+            vi = T.axis.spatial(16, 8 + i)
+            B[vi] = A[vi - 8]
+
+    for i in T.grid(8):
+        with T.block("last_half"):
+            vi = T.axis.spatial(16, i)
+            B[vi] = A[vi + 8]
+
+
+@T.prim_func
+def scatter_compute_parallelize(
+    A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]
+) -> None:
+    # body
+    # with T.block("root")
+    for i in T.parallel(8):
+        with T.block("first_half"):
+            vi = T.axis.spatial(16, 8 + i)
+            T.reads(A[vi - 8])
+            T.writes(B[vi])
+            B[vi] = A[vi - 8]
+    for i in T.parallel(8):
+        with T.block("last_half"):
+            vi = T.axis.spatial(16, i)
+            T.reads(A[vi + 8])
+            T.writes(B[vi])
+            B[vi] = A[vi + 8]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -468,5 +534,28 @@ def test_vectorize_after_decompose():
     verify_trace_roundtrip(s, mod=decomposed_gemm)
 
 
+def test_vectorize_init():
+    s = tir.Schedule(decomposed_gemm, debug_mask="all")
+    init_blk = s.get_block("init")
+    upd_blk = s.get_block("update")
+    _, _, ii_0, jj_0 = s.get_loops(init_blk)
+    _, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk)
+    s.vectorize(jj_0)
+    tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init)
+    verify_trace_roundtrip(s, mod=decomposed_gemm)
+
+
+def test_scatter_parallelize():
+    s = tir.Schedule(scatter_compute, debug_mask="all")
+    first = s.get_block("first_half")
+    last = s.get_block("last_half")
+    (i_0,) = s.get_loops(first)
+    (i_1,) = s.get_loops(last)
+    s.parallel(i_0)
+    s.parallel(i_1)
+    tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize)
+    verify_trace_roundtrip(s, mod=scatter_compute)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
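
For illustration, a minimal sketch (not part of the patch) of the user-visible effect, mirroring the new test_scatter_parallelize test above: the two blocks each write a disjoint half of B, so neither is a complete block under the global scope root, yet each is a local complete block under the loop being parallelized. The function name `scatter` and the `T.Buffer[...]` annotation follow the style of the test file; running it assumes a TVM build that contains this change.

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def scatter(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None:
    # Each loop nest writes only half of B, so B has two writers in the global
    # scope, but exactly one writer inside each loop's own subtree.
    for i in T.serial(8):
        with T.block("first_half"):
            vi = T.axis.spatial(16, 8 + i)
            B[vi] = A[vi - 8]
    for i in T.serial(8):
        with T.block("last_half"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi + 8]


s = tir.Schedule(scatter, debug_mask="all")
s.parallel(s.get_loops(s.get_block("first_half"))[0])
s.parallel(s.get_loops(s.get_block("last_half"))[0])
print(s.mod.script())  # both loops are printed as T.parallel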
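The same local check is what lets a loop that covers only an init block be vectorized after a reduction is decomposed (compare test_vectorize_init). A sketch using TE to build the workload, again assuming a TVM build with this change; the block name "C" comes from the te.compute below, and the init block is whatever decompose_reduction produces:

from tvm import te, tir

# A plain 16x16 GEMM; the reduction block is named "C".
A = te.placeholder((16, 16), name="A")
B = te.placeholder((16, 16), name="B")
k = te.reduce_axis((0, 16), name="k")
C = te.compute((16, 16), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = tir.Schedule(te.create_prim_func([A, B, C]), debug_mask="all")
blk = s.get_block("C")
i, j, kk = s.get_loops(blk)
io, ii = s.split(i, factors=[4, 4])
jo, ji = s.split(j, factors=[4, 4])
s.reorder(io, jo, kk, ii, ji)
init = s.decompose_reduction(blk, kk)
# After decomposition both the init block and the update block write C, so the
# init block is not dominant under the global scope root.  It is, however, the
# only writer inside its own loop subtree, i.e. a local complete block, so the
# vectorize below satisfies the subtree compact-dataflow check in its updated form.
s.vectorize(s.get_loops(init)[-1])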