1- /*
2- * Licensed to the Apache Software Foundation (ASF) under one
3- * or more contributor license agreements. See the NOTICE file
4- * distributed with this work for additional information
5- * regarding copyright ownership. The ASF licenses this file
6- * to you under the Apache License, Version 2.0 (the
7- * "License"); you may not use this file except in compliance
8- * with the License. You may obtain a copy of the License at
9- *
10- * http://www.apache.org/licenses/LICENSE-2.0
11- *
12- * Unless required by applicable law or agreed to in writing,
13- * software distributed under the License is distributed on an
14- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15- * KIND, either express or implied. See the License for the
16- * specific language governing permissions and limitations
17- * under the License.
18- */
19-
201/* !
212 * \file inject_software_pipeline.cc
223 * \brief Transform annotated loops into pipelined one that parallelize
@@ -79,6 +60,8 @@ struct PipelineAnnotation {
7960 int stage;
8061 int order;
8162 bool async;
63+ // Index of the statement in the original loop body order (SeqStmt order)
64+ int original_idx = -1 ;
8265};
8366
8467using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
@@ -304,15 +287,17 @@ class PipelineRewriter : public StmtExprMutator {
304287 }
305288
306289 // Step 2: Emit the pipeline prologue, body and epilogue.
307- Stmt prologue = EmitImpl (pipeline_loop_->min ,
308- pipeline_loop_->min + max_stage_, true , true );
309- Stmt body =
310- EmitImpl (pipeline_loop_->min + max_stage_,
311- pipeline_loop_->min + pipeline_loop_->extent , false , false );
312- Stmt epilogue = EmitImpl (
313- pipeline_loop_->min + pipeline_loop_->extent ,
314- pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true , true );
315-
290+ Stmt prologue =
291+ EmitImpl (pipeline_loop_->min , pipeline_loop_->min + max_stage_, true ,
292+ true , false );
293+ Stmt body = EmitImpl (pipeline_loop_->min + max_stage_,
294+ pipeline_loop_->min + pipeline_loop_->extent , false ,
295+ false , false );
296+
297+ Stmt epilogue =
298+ EmitImpl (pipeline_loop_->min + pipeline_loop_->extent ,
299+ pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
300+ true , true , true );
316301 SeqStmt stmt = SeqStmt ({prologue, body, epilogue});
317302
318303 // Step 3: Make a new block that contains new buffer allocations after
@@ -515,69 +500,120 @@ class PipelineRewriter : public StmtExprMutator {
515500 // A symbolic expression representing the index the latest async operation
516501 // associated with this stage has written into, at the "current" iteration.
517502 Optional<PrimExpr> producer_head;
503+ // the commit block's predicate
504+ PrimExpr commit_predicate{nullptr };
518505 };
519506
520507 /* ! Structure holding intermediate information for pipeline loop rewriting. */
521508 struct RewrittenBlockInfo {
522509 int stage;
523510 int order;
511+ PrimExpr start;
512+ PrimExpr end;
524513 PrimExpr predicate;
525514 Block block;
526515 PrimExpr access_index;
527516 bool is_async;
528517 };
529518
530519 void PopulateWaitCounts (const std::vector<RewrittenBlockInfo> &new_blocks,
531- std::map<int , AsyncStateLocal> *async_states_local) {
520+ std::map<int , AsyncStateLocal> *async_states_local,
521+ bool is_epilogue = false ) {
522+ // Precompute which orders are present in this emit, and their access_index
523+ std::unordered_map<int , PrimExpr> order_to_access_index;
524+ std::unordered_set<int > present_orders;
525+ for (const auto &nb : new_blocks) {
526+ order_to_access_index[nb.order ] = nb.access_index ;
527+ present_orders.insert (nb.order );
528+ }
532529 for (size_t i = 0 ; i < new_blocks.size (); ++i) {
530+ // 1. Find the unique async producer stage
533531 int producer_stage_idx = -1 ;
534- for (auto read_region : new_blocks[i].block ->reads ) {
532+ for (const auto & read_region : new_blocks[i].block ->reads ) {
535533 for (const auto &[stage, state] : async_states) {
536534 if (stage <= new_blocks[i].stage &&
537535 state.writes (read_region->buffer )) {
538- // Found an earlier stage where read_region->buffer was
539- // asynchronously written
536+ // Currently only a single async stage dependency is supported
540537 ICHECK (producer_stage_idx == -1 || producer_stage_idx == stage)
541538 << " A dependency on multiple async stages is not supported" ;
542539 producer_stage_idx = stage;
543540 }
544541 }
545542 }
546- if (producer_stage_idx == -1 )
543+ if (producer_stage_idx == -1 ) {
544+ // This block does not depend on any async producer
547545 continue ;
546+ }
548547 const auto &state = async_states[producer_stage_idx];
548+
549549 auto &dep_local_state = (*async_states_local)[producer_stage_idx];
550- PrimExpr in_flight_cnt = 0 ;
551- for (const auto &group : state.commit_groups ) {
552- PrimExpr consumer_head = new_blocks[i].access_index ;
553- PrimExpr producer_head;
554- if (dep_local_state.producer_head .defined ()) {
555- producer_head = dep_local_state.producer_head .value ();
556- // if the group is after the wait point, minus by 1
557- if (group.front () > new_blocks[i].order )
558- producer_head -= 1 ;
559- } else {
560- producer_head = state.producer_head ;
561- }
562- in_flight_cnt += producer_head - consumer_head;
563- }
564550
565- // We can relax the in-flight-count by the number of independent commit.
551+ // 2. Use buffer_to_commit_group_ to find all actually dependent commit
552+ // groups
566553 std::unordered_set<int > dependent_groups;
567554 for (const auto &read_region : new_blocks[i].block ->reads ) {
568- if (state.buffer_to_commit_group_ .count (read_region->buffer .get ()))
569- dependent_groups.insert (
570- state.buffer_to_commit_group_ .at (read_region->buffer .get ()));
555+ auto it = state.buffer_to_commit_group_ .find (read_region->buffer .get ());
556+ if (it != state.buffer_to_commit_group_ .end ()) {
557+ dependent_groups.insert (it->second );
558+ }
571559 }
572- for (int i = int (state.commit_groups .size ()) - 1 ; i >= 0 ; i--) {
573- if (dependent_groups.count (i) == 0 )
574- in_flight_cnt += 1 ;
575- else
576- break ; // stop relaxing
560+
561+ // If there is no dependent commit group, no wait needs to be inserted
562+ if (dependent_groups.empty ()) {
563+ continue ;
564+ }
565+
566+ // 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
567+ PrimExpr t_consumer = new_blocks[i].access_index ;
568+ PrimExpr wait_expr = make_zero (t_consumer.dtype ());
569+
570+ PrimExpr current_head = dep_local_state.producer_head .defined ()
571+ ? dep_local_state.producer_head .value ()
572+ : state.producer_head ;
573+ int consumer_order = new_blocks[i].order ;
574+
575+ for (int g : dependent_groups) {
576+ const auto &group = state.commit_groups [g];
577+ if (group.empty ())
578+ continue ;
579+ int commit_order = group.back ();
580+ bool commit_present = present_orders.count (commit_order) > 0 ;
581+
582+ PrimExpr committed_before;
583+ if (commit_present && commit_order <= consumer_order) {
584+ // Commit point is in this iteration and earlier than the current
585+ // consumer; this iteration's head is visible
586+ auto commit_predicate = dep_local_state.commit_predicate ;
587+ if (analyzer_.CanProve (!commit_predicate,
588+ arith::ProofStrength::kSymbolicBound )) {
589+ // it means the commit block is not executed in this iteration
590+ committed_before = new_blocks[i].start - 1 ;
591+ } else if (is_epilogue) {
592+ committed_before = new_blocks[i].start - 1 ;
593+ } else {
594+ committed_before = order_to_access_index.at (commit_order);
595+ }
596+ } else {
597+ // Commit point is later than the current consumer or not in this
598+ // iteration; only the previous iteration's head is visible
599+ if (dep_local_state.producer_head .defined ()) {
600+ auto commit_predicate = dep_local_state.commit_predicate ;
601+ if (analyzer_.CanProve (!commit_predicate,
602+ arith::ProofStrength::kSymbolicBound )) {
603+ committed_before = new_blocks[i].start - 1 ;
604+ } else if (is_epilogue) {
605+ committed_before = new_blocks[i].start - 1 ;
606+ } else {
607+ committed_before = current_head - 1 ;
608+ }
609+ }
610+ }
611+
612+ wait_expr = analyzer_.Simplify (committed_before - t_consumer);
577613 }
578- in_flight_cnt = analyzer_. Simplify (in_flight_cnt);
579- dep_local_state. pending_waits . push_back (
580- {static_cast <int >(i), in_flight_cnt });
614+
615+ wait_expr = analyzer_. Simplify (wait_expr);
616+ dep_local_state. pending_waits . push_back ( {static_cast <int >(i), wait_expr });
581617 }
582618 }
583619
@@ -630,7 +666,7 @@ class PipelineRewriter : public StmtExprMutator {
630666 * \return The result loop.
631667 */
632668 Stmt EmitImpl (const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
633- bool need_bound_check) {
669+ bool need_bound_check, bool is_epilogue = false ) {
634670 PrimExpr new_loop_var;
635671 PrimExpr extent = end - start;
636672 auto make_nop = []() {
@@ -642,7 +678,20 @@ class PipelineRewriter : public StmtExprMutator {
642678 new_loop_var = start; // use constants as the loop var for unit loops
643679 } else {
644680 new_loop_var = pipeline_loop_->loop_var .copy_with_suffix (" " );
645- analyzer_.Bind (Downcast<Var>(new_loop_var), Range (start, end));
681+ // Bind the iteration domain [start, end) to strengthen analyzer facts.
682+ analyzer_.Bind (Downcast<Var>(new_loop_var),
683+ Range::FromMinExtent (start, end - start));
684+ }
685+ // Keep the bound constraints active for all analysis below.
686+ // Only meaningful when the loop var is symbolic (non-unit loop).
687+ std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
688+ std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
689+ if (!is_unit_loop) {
690+ Var loop_iter = Downcast<Var>(new_loop_var);
691+ ctx_lb_guard.reset (
692+ new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
693+ ctx_ub_guard.reset (
694+ new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
646695 }
647696
648697 std::vector<RewrittenBlockInfo> new_blocks;
@@ -653,15 +702,14 @@ class PipelineRewriter : public StmtExprMutator {
653702 for (const Block &block : ordered_stmts_) {
654703 int stage = pipeline_info_.at (block).stage ;
655704 int order = pipeline_info_.at (block).order ;
705+
656706 PrimExpr inbound = Bool (true );
657707 PrimExpr skewed_loop_var = new_loop_var - stage;
658708 if (need_bound_check)
659- inbound =
660- analyzer_.Simplify (pipeline_loop_->min <= skewed_loop_var) &&
661- (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent );
662- if (analyzer_.CanProve (!inbound)) {
663- continue ;
664- }
709+ inbound = And (
710+ pipeline_loop_->min <= skewed_loop_var,
711+ (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent ));
712+
665713 Block new_block = Downcast<Block>(
666714 PipelineBodyRewriter (buffer_data_to_buffer_, buffer_remap_,
667715 pipeline_loop_, max_stage_ != 1 )(block));
@@ -674,6 +722,8 @@ class PipelineRewriter : public StmtExprMutator {
674722 PrimExpr normalized_access_index =
675723 is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
676724
725+ normalized_access_index = analyzer_.Simplify (normalized_access_index);
726+
677727 // Adjust the block predicate and the body according to the final loop
678728 // bound
679729 // [pipeline_loop_->min, extent).
@@ -701,17 +751,18 @@ class PipelineRewriter : public StmtExprMutator {
701751 if (pipeline_info_[block].async ) {
702752 auto &local_state = async_states_local[stage];
703753 local_state.producer_head = normalized_access_index;
754+ local_state.commit_predicate = inbound;
704755 BlockNode *n = new_block.CopyOnWrite ();
705756 n->body = AttrStmt (make_zero (DataType::Int (32 )), tir::attr::async_scope,
706757 1 , n->body );
707758 }
708759
709- new_blocks.push_back ({stage, order, inbound, new_block,
760+ new_blocks.push_back ({stage, order, start, end, inbound, new_block,
710761 normalized_access_index,
711762 pipeline_info_[block].async });
712763 }
713764
714- PopulateWaitCounts (new_blocks, &async_states_local);
765+ PopulateWaitCounts (new_blocks, &async_states_local, is_epilogue );
715766
716767 auto stmts = CompletePipelineLoopStatements (new_blocks, async_states_local);
717768
@@ -1008,7 +1059,8 @@ class PipelineInjector : private StmtExprMutator {
10081059 pipeline_async_stages.find (stage) != pipeline_async_stages.end ();
10091060 PipelineAnnotation stage_order{
10101061 stage,
1011- /* order=*/ static_cast <int >(pipeline_orders[i]->value ), is_async};
1062+ /* order=*/ static_cast <int >(pipeline_orders[i]->value ), is_async,
1063+ /* original_idx=*/ static_cast <int >(i)};
10121064 pipeline_info.emplace (original_order[i], stage_order);
10131065 }
10141066
0 commit comments