diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index ff0b11a73c9b..dd1fce0fbef7 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -117,10 +117,13 @@ class LCADetector : public StmtExprVisitor { ancestor_scopes_.push_back(current_scope); - // For each accessed buffer of the block, update the buffer's lca to + // For each accessed buffer of the block + // If it accesses the opaque block iter vars, update the buffer's lca to // the lowest inclusive stmt position, which should dominate all loops - // related to the accessed opaque block iter vars in buffer indices. - UpdateDominateScopeOfOpaqueIter(op); + // related to the accessed opaque block iter vars. + // If it is the reduction block write buffer, update the buffer's lca to + // dominate all reduction iter var related loops. + UpdateDominateScopeOfNonDataParIter(op); // Update match_buffers for (const MatchBufferRegion& match_buffer : block->match_buffers) { @@ -132,43 +135,70 @@ class LCADetector : public StmtExprVisitor { ancestor_scopes_.pop_back(); } - void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) { - // map opaque iter var to the scope which dominate all loop carried dependencies. - std::unordered_map itervar_to_dom_scope; + void UpdateDominateScopeOfNonDataParIter(const BlockRealizeNode* block_realize) { + // map iter var to the scope which dominates all loop carried dependencies. + std::unordered_map opaque_var_scope; + // maintain highest scope which dominates all reduce loop iters. null denotes non-reduce block. + const ScopeInfo* highest_reduce_scope = nullptr; // function to collect `itervar_to_dom_scope`, the result scope for each block // iter var should be above all loop scopes the opaque iter var binding relates to. 
- auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar, - const PrimExpr& binding) { - PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) { + auto do_collect_itervar_scope = [this](const IterVar& itervar, + const PrimExpr& binding) -> const ScopeInfo* { + const ScopeInfo* highest_scope = nullptr; + PostOrderVisit(binding, [this, &itervar, &highest_scope](const ObjectRef& obj) { if (const VarNode* loop_var = obj.as()) { auto it = loop_scope_map_.find(loop_var); if (it == loop_scope_map_.end()) { return; } const ScopeInfo* scope = it->second->parent_scope_info; - // find the highest loop scope the iter var binding has related to. - auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get()); - if (dom_scope_it == itervar_to_dom_scope.end()) { - itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope}); - } else if (scope->depth < dom_scope_it->second->depth) { - dom_scope_it->second = scope; + if (highest_scope == nullptr) { + highest_scope = scope; + } else if (scope->depth < highest_scope->depth) { + highest_scope = scope; } } }); + return highest_scope; }; + // collect non-data-parallel block iteration's dominating scope. + // for reduction iter type, we maintain the highest dominating scope for all reduce iters. + // for other iter types, we maintain the dict for each individual iter. 
+ const Block& block = block_realize->block; + bool is_reduce_block = false; + for (size_t i = 0; i < block_realize->iter_values.size(); ++i) { + const IterVar& iter_var = block->iter_vars[i]; + if (iter_var->iter_type != IterVarType::kDataPar) { + const auto* scope = do_collect_itervar_scope(iter_var, block_realize->iter_values[i]); + if (scope == nullptr) continue; + if (iter_var->iter_type == IterVarType::kCommReduce) { + is_reduce_block = true; + if (highest_reduce_scope == nullptr || scope->depth < highest_reduce_scope->depth) { + highest_reduce_scope = scope; + } + } else { + opaque_var_scope[iter_var->var.get()] = scope; + for (const auto& write : block->writes) { + UpdateBufferLCA(write->buffer.get(), scope); + } + } + } + } + // function to update lca scope of the buffer with loop carried dependent buffer accesses. // the result scope should be above all loop scopes the accessed opaque block iter vars // relate to, which is record in `itervar_to_dom_scope`. - auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) { + auto do_update = [this, &opaque_var_scope, highest_reduce_scope](const BufferRegion& region, + bool is_reduce_write = false) { const Buffer& buffer = region->buffer; const ScopeInfo* scope = ancestor_scopes_.back(); - auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) { + auto handle_itervar = [&opaque_var_scope, &scope](const ObjectRef& obj) { if (const VarNode* iter_var = obj.as()) { - auto dom_scope_it = itervar_to_dom_scope.find(iter_var); - if (dom_scope_it == itervar_to_dom_scope.end()) { + auto dom_scope_it = opaque_var_scope.find(iter_var); + if (dom_scope_it == opaque_var_scope.end()) { return; } // find the highest loop scope the accessed buffer index has @@ -184,24 +214,25 @@ class LCADetector : public StmtExprVisitor { PostOrderVisit(range->min, handle_itervar); PostOrderVisit(range->min + range->extent - 1, handle_itervar); } + + // the scope should be above `highest_reduce_scope` for 
reduce output buffer. + if (is_reduce_write && highest_reduce_scope != nullptr && + scope->depth > highest_reduce_scope->depth) { + scope = highest_reduce_scope; + } UpdateBufferLCA(buffer.get(), scope); }; - // do collect and update - const Block& block = block_realize->block; - for (size_t i = 0; i < block_realize->iter_values.size(); ++i) { - const IterVar& iter_var = block->iter_vars[i]; - if (iter_var->iter_type != IterVarType::kDataPar && - iter_var->iter_type != IterVarType::kCommReduce) { - do_collect_itervar_scope(iter_var, block_realize->iter_values[i]); - } - } - if (!itervar_to_dom_scope.empty()) { + if (!opaque_var_scope.empty()) { for (const auto& read : block->reads) { do_update(read); } for (const auto& write : block->writes) { - do_update(write); + do_update(write, /*is_reduce_write=*/is_reduce_block); + } + } else if (is_reduce_block && highest_reduce_scope != nullptr) { + for (const auto& write : block->writes) { + do_update(write, /*is_reduce_write=*/true); } } } diff --git a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py index a1808c841303..b3ce7efd0593 100644 --- a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py @@ -116,9 +116,10 @@ def test_buffer_load_store(): root_block = func.body.block assert lca[A] == func.body.block - # LCA of Buffer B is reduction block - reduce_block = root_block.body[1].body.body.body.block - assert lca[B] == reduce_block + # LCA of Buffer B is the loop dominating all reduction loops + reduce_dom_loop = root_block.body[1].body + reduce_block = reduce_dom_loop.body.body.block + assert lca[B] == reduce_dom_loop # LCA of Buffer C is the second loop kk loop_jj = reduce_block.body.body diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py 
b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py index 8500f114610c..ff3fa8cf7092 100644 --- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py @@ -402,5 +402,55 @@ def before(dlpack_handle: T.handle, axis: T.int64) -> T.int64: _check(before, after) +def test_reduce_buffer_dominate_reduce_loops(): + """Reduction write buffer allocation should dominate all reduce loops""" + + @T.prim_func + def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")): + x_red_ = T.alloc_buffer((256, 256)) + for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4): + for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64): + with T.block("x_red"): + v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1) + v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1) + v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1) + if v_k1 == 0: + x_red_[v_ax0, v_ax1] = T.float32(0.0) + x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1] + for ax0, ax1 in T.grid(64, 64): + with T.block("x_red_"): + v0 = T.axis.spatial(256, ax0_0 * 64 + ax0) + v1 = T.axis.spatial(256, ax1_0 * 64 + ax1) + x_red[v0, v1] = x_red_[v0, v1] + + @T.prim_func + def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")): + for ax0_0 in range(4): + with T.block(""): + T.reads(x[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256, 0:256]) + T.writes(x_red[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256]) + x_red_ = T.alloc_buffer((256, 256)) + for k1_0, ax1_0 in T.grid(4, 4): + for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64): + with T.block("x_red"): + v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1) + v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1) + v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1) + T.reads(x_red_[v_ax0, v_ax1], x[v_ax0, v_k1, v_ax1]) + T.writes(x_red_[v_ax0, v_ax1]) + if v_k1 == 0: + x_red_[v_ax0, v_ax1] = T.float32(0.0) + x_red_[v_ax0, v_ax1] = 
x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1] + for ax0, ax1 in T.grid(64, 64): + with T.block("x_red_"): + v0 = T.axis.spatial(256, ax0_0 * 64 + ax0) + v1 = T.axis.spatial(256, ax1_0 * 64 + ax1) + T.reads(x_red_[v0, v1]) + T.writes(x_red[v0, v1]) + x_red[v0, v1] = x_red_[v0, v1] + + _check(before, after) + + if __name__ == "__main__": tvm.testing.main()