From 2f26fb249077b626429f71379765b6a10c36ae1c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda <masahi129@gmail.com>
Date: Wed, 27 Jul 2022 09:38:48 +0900
Subject: [PATCH] introduce CompletePipelineLoopStatements function for further
 refactor

---
 .../transforms/inject_software_pipeline.cc    | 182 ++++++++++--------
 1 file changed, 101 insertions(+), 81 deletions(-)

diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index 3af09f2a9ab87..227935bf72dd2 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -561,10 +561,11 @@ class PipelineRewriter : public StmtExprMutator {
     // async_commit_queue for each producer. Thus, we need multiple sets of indices.
     std::vector<std::vector<size_t>> commit_groups;
 
-    // TODO
+    // This is set to true when we reach a stage that consumes this async stage.
     bool consumed{false};
   };
 
+  /*! Structure holding intermediate information for pipeline loop rewriting. */
   struct RewrittenBlockInfo {
     int stage;
     PrimExpr predicate;
@@ -573,15 +574,16 @@ class PipelineRewriter : public StmtExprMutator {
     bool is_async;
   };
 
-  void DetermineWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
-                           arith::Analyzer& ana_normalized,
-                           const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
-                           std::map<int, AsyncStateLocal>& async_states_local) {
+  // Determine where to insert async_wait and the corresponding wait count.
+  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
+                          arith::Analyzer* ana_normalized,
+                          const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
+                          std::map<int, AsyncStateLocal>* async_states_local) {
     for (size_t i = 0; i < new_blocks.size(); ++i) {
       if (new_blocks[i].is_async) {
         // Record the fact that we have encountered these write buffers.
        for (auto write_region : new_blocks[i].block->writes) {
-          async_states_local[new_blocks[i].stage].seen.insert(write_region->buffer.get());
+          (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get());
         }
       }
 
@@ -641,7 +643,7 @@ class PipelineRewriter : public StmtExprMutator {
      // done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The
      // sum of the two wait_counts gives 5.
 
-      auto& dep_local_state = async_states_local[producer_stage_idx];
+      auto& dep_local_state = (*async_states_local)[producer_stage_idx];
       const auto num_commit_group = dep_local_state.commit_groups.size();
       std::vector<Optional<PrimExpr>> producer_head_per_commit;
 
@@ -675,7 +677,7 @@ class PipelineRewriter : public StmtExprMutator {
        auto wait_count = [=, &ana_normalized]() {
          auto sum = PrimExpr(0);
          for (auto producer_head : producer_head_per_commit) {
-            if (producer_head && ana_normalized.CanProve(producer_head.value() >= 0)) {
+            if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) {
              // Here, new_blocks[i].access_index corresponds to "consumer_head".
              // The difference of producer_head and consumer_head is precisely the number of
              // async commit groups that can still be in flight after this wait.
@@ -699,6 +701,78 @@ class PipelineRewriter : public StmtExprMutator {
     }
   }
 
+  // Given pipelined blocks and async-related information, generate final loop statements with async
+  // scopes (if any).
+  Array<Stmt> CompletePipelineLoopStatements(
+      const std::vector<RewrittenBlockInfo>& blocks,
+      const std::map<int, AsyncStateLocal>& async_states_local,
+      arith::Analyzer* ana_normalized) const {
+    std::vector<RewrittenBlockInfo> new_blocks = blocks;
+    std::vector<int> commit_group_indices(new_blocks.size(), -1);
+    for (const auto& kv : async_states_local) {
+      const int stage_id = kv.first;
+      const AsyncStateLocal& state = kv.second;
+
+      if (!state.commit_groups.empty()) {
+        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
+          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
+            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
+            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
+          }
+        }
+      }
+
+      if (state.pending_wait.valid()) {
+        auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
+          auto& block = new_blocks[i].block;
+          BlockNode* n = block.CopyOnWrite();
+          auto zero = make_zero(DataType::Int(32));
+          n->body =
+              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
+                       AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
+        };
+
+        if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) {
+          // If the async operation that this wait_queue is waiting on is predicated, and we cannot
+          // prove that the predicate is always true, the precise wait count is only valid
+          // at iterations where the predicate is true.
+          auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
+                                 {state.predicate.value(), state.pending_wait.wait_count, 0});
+          attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
+        } else {
+          attach_wait_scope(state.pending_wait.insert_before, stage_id,
+                            state.pending_wait.wait_count);
+        }
+      }
+    }
+
+    Array<Stmt> stmts;
+
+    for (size_t i = 0; i < new_blocks.size();) {
+      if (commit_group_indices[i] == -1) {
+        // A synchronous block, not part of any commit group
+        stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
+        ++i;
+      } else {
+        Array<Stmt> group_bodies;
+        auto
stage_id = commit_group_indices[i];
+        auto predicate = new_blocks[i].predicate;
+        for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
+          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
+              << "Predicates in the same stage are expected to be identical";
+          group_bodies.push_back(new_blocks[i].block->body);
+        }
+        auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
+        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
+                                           tir::attr::async_commit_queue_scope, stage_id, body);
+        auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
+        stmts.push_back(BlockRealize({}, predicate, new_block));
+      }
+    }
+
+    return stmts;
+  }
+
   /*!
    * \brief Emit the pipeline loop in the given range.
    * \param start The start of the range
@@ -707,7 +781,6 @@ class PipelineRewriter : public StmtExprMutator {
    * \return The result loop.
    */
   Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) {
-    Array<Stmt> stmts;
     PrimExpr new_loop_var;
     PrimExpr extent = end - start;
 
@@ -811,52 +884,36 @@ class PipelineRewriter : public StmtExprMutator {
        for (auto kv : async_states) {
          int producer_stage_id = kv.first;
          if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) {
-          async_states_local[producer_stage_id].consumed = true;
+            async_states_local[producer_stage_id].consumed = true;
          }
        }
      }
    }
 
-    DetermineWaitCounts(new_blocks, ana_normalized, buffer_to_commit_group, async_states_local);
+    PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local);
+    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized);
 
-    std::vector<int> commit_group_indices(new_blocks.size(), -1);
+    Stmt new_loop{nullptr};
+
+    if (stmts.empty()) {
+      return make_nop();
+    }
+    if (stmts.size() == 1) {
+      new_loop = stmts[0];
+    } else {
+      new_loop = SeqStmt(stmts);
+    }
+
+    if (!is_unit_loop) {
+      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min,
extent,
+                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
+    }
+
     // Update producer heads in the global async states.
     for (const auto& kv : async_states_local) {
       const int stage_id = kv.first;
       const AsyncStateLocal& state = kv.second;
 
-      if (!state.commit_groups.empty()) {
-        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
-          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
-            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
-            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
-          }
-        }
-      }
-
-      if (state.pending_wait.valid()) {
-        auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
-          auto& block = new_blocks[i].block;
-          BlockNode* n = block.CopyOnWrite();
-          auto zero = make_zero(DataType::Int(32));
-          n->body =
-              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
-                       AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
-        };
-
-        if (state.predicate && !ana_normalized.CanProve(state.predicate.value())) {
-          // If the async operation that this wait_queue is waiting on is predicated, and we cannot
-          // prove that the predicate is always true, the precise wait count is only valid
-          // at iterations where the predicate is true;
-          auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
-                                 {state.predicate.value(), state.pending_wait.wait_count, 0});
-          attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
-        } else {
-          attach_wait_scope(state.pending_wait.insert_before, stage_id,
-                            state.pending_wait.wait_count);
-        }
-      }
-
       if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
           async_states[stage_id].producer_head) {
         // Advance the "global" producer head if it is still valid and we know exactly how much we
@@ -869,43 +926,6 @@ class PipelineRewriter : public StmtExprMutator {
       }
     }
 
-    for (size_t i = 0; i < new_blocks.size();) {
-      if (commit_group_indices[i] == -1) {
-        // A synchrnous block, not part of
any commit group
-        stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
-        ++i;
-      } else {
-        Array<Stmt> group_bodies;
-        auto stage_id = commit_group_indices[i];
-        auto predicate = new_blocks[i].predicate;
-        for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
-          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
-              << "Predicates in the same stage are expected to be identical";
-          group_bodies.push_back(new_blocks[i].block->body);
-        }
-        auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
-        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
-                                           tir::attr::async_commit_queue_scope, stage_id, body);
-        auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
-        stmts.push_back(BlockRealize({}, predicate, new_block));
-      }
-    }
-
-    Stmt new_loop{nullptr};
-
-    if (stmts.empty()) {
-      return make_nop();
-    }
-    if (stmts.size() == 1) {
-      new_loop = stmts[0];
-    } else {
-      new_loop = SeqStmt(stmts);
-    }
-
-    if (!is_unit_loop) {
-      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
-                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
-    }
     return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
   }