diff --git a/3rdparty/tvm b/3rdparty/tvm index 3a32b763e..90581fe9e 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3a32b763e9d8393b14e4d0f824b2846f70041bc1 +Subproject commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0 diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 511ebc573..79e78add9 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -1,22 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - /*! * \file inject_software_pipeline.cc * \brief Transform annotated loops into pipelined one that parallelize @@ -79,6 +60,8 @@ struct PipelineAnnotation { int stage; int order; bool async; + // Index of the statement in the original loop body order (SeqStmt order) + int original_idx = -1; }; using PipelineInfo = std::unordered_mapmin, - pipeline_loop_->min + max_stage_, true, true); - Stmt body = - EmitImpl(pipeline_loop_->min + max_stage_, - pipeline_loop_->min + pipeline_loop_->extent, false, false); - Stmt epilogue = EmitImpl( - pipeline_loop_->min + pipeline_loop_->extent, - pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true); - + Stmt prologue = + EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, + true, false); + Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false, + false, false); + + Stmt epilogue = + EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, + true, true, true); SeqStmt stmt = SeqStmt({prologue, body, epilogue}); // Step 3: Make a new block that contains new buffer allocations after @@ -515,12 +500,16 @@ class PipelineRewriter : public StmtExprMutator { // A symbolic expression representing the index the latest async operation // associated with this stage has written into, at the "current" iteration. Optional producer_head; + // the commit block's predicate + PrimExpr commit_predicate{nullptr}; }; /*! Structure holding intermediate information for pipeline loop rewriting. */ struct RewrittenBlockInfo { int stage; int order; + PrimExpr start; + PrimExpr end; PrimExpr predicate; Block block; PrimExpr access_index; @@ -528,56 +517,103 @@ class PipelineRewriter : public StmtExprMutator { }; void PopulateWaitCounts(const std::vector &new_blocks, - std::map *async_states_local) { + std::map *async_states_local, + bool is_epilogue = false) { + // Precompute which orders are present in this emit, and their access_index + std::unordered_map order_to_access_index; + std::unordered_set present_orders; + for (const auto &nb : new_blocks) { + order_to_access_index[nb.order] = nb.access_index; + present_orders.insert(nb.order); + } for (size_t i = 0; i < new_blocks.size(); ++i) { + // 1. Find the unique async producer stage int producer_stage_idx = -1; - for (auto read_region : new_blocks[i].block->reads) { + for (const auto &read_region : new_blocks[i].block->reads) { for (const auto &[stage, state] : async_states) { if (stage <= new_blocks[i].stage && state.writes(read_region->buffer)) { - // Found an earlier stage where read_region->buffer was - // asynchronously written + // Currently only a single async stage dependency is supported ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage) << "A dependency on multiple async stages is not supported"; producer_stage_idx = stage; } } } - if (producer_stage_idx == -1) + if (producer_stage_idx == -1) { + // This block does not depend on any async producer continue; + } const auto &state = async_states[producer_stage_idx]; + auto &dep_local_state = (*async_states_local)[producer_stage_idx]; - PrimExpr in_flight_cnt = 0; - for (const auto &group : state.commit_groups) { - PrimExpr consumer_head = new_blocks[i].access_index; - PrimExpr producer_head; - if (dep_local_state.producer_head.defined()) { - producer_head = dep_local_state.producer_head.value(); - // if the group is after the wait point, minus by 1 - if (group.front() > new_blocks[i].order) - producer_head -= 1; - } else { - producer_head = state.producer_head; - } - in_flight_cnt += producer_head - consumer_head; - } - // We can relax the in-flight-count by the number of independent commit. + // 2. Use buffer_to_commit_group_ to find all actually dependent commit + // groups std::unordered_set dependent_groups; for (const auto &read_region : new_blocks[i].block->reads) { - if (state.buffer_to_commit_group_.count(read_region->buffer.get())) - dependent_groups.insert( - state.buffer_to_commit_group_.at(read_region->buffer.get())); + auto it = state.buffer_to_commit_group_.find(read_region->buffer.get()); + if (it != state.buffer_to_commit_group_.end()) { + dependent_groups.insert(it->second); + } } - for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) { - if (dependent_groups.count(i) == 0) - in_flight_cnt += 1; - else - break; // stop relaxing + + // If there is no dependent commit group, no wait needs to be inserted + if (dependent_groups.empty()) { + continue; + } + + // 3. Compute wait = max_g max(0, t_consumer - committed_before[g]) + PrimExpr t_consumer = new_blocks[i].access_index; + PrimExpr wait_expr = make_zero(t_consumer.dtype()); + + PrimExpr current_head = dep_local_state.producer_head.defined() + ? dep_local_state.producer_head.value() + : state.producer_head; + int consumer_order = new_blocks[i].order; + + for (int g : dependent_groups) { + const auto &group = state.commit_groups[g]; + if (group.empty()) + continue; + int commit_order = group.back(); + bool commit_present = present_orders.count(commit_order) > 0; + + PrimExpr committed_before; + if (commit_present && commit_order <= consumer_order) { + // Commit point is in this iteration and earlier than the current + // consumer; this iteration's head is visible + auto commit_predicate = dep_local_state.commit_predicate; + if (analyzer_.CanProve(!commit_predicate, + arith::ProofStrength::kSymbolicBound)) { + // it means the commit block is not executed in this iteration + committed_before = new_blocks[i].start - 1; + } else if (is_epilogue) { + committed_before = new_blocks[i].start - 1; + } else { + committed_before = order_to_access_index.at(commit_order); + } + } else { + // Commit point is later than the current consumer or not in this + // iteration; only the previous iteration's head is visible + if (dep_local_state.producer_head.defined()) { + auto commit_predicate = dep_local_state.commit_predicate; + if (analyzer_.CanProve(!commit_predicate, + arith::ProofStrength::kSymbolicBound)) { + committed_before = new_blocks[i].start - 1; + } else if (is_epilogue) { + committed_before = new_blocks[i].start - 1; + } else { + committed_before = current_head - 1; + } + } + } + + wait_expr = analyzer_.Simplify(committed_before - t_consumer); } - in_flight_cnt = analyzer_.Simplify(in_flight_cnt); - dep_local_state.pending_waits.push_back( - {static_cast(i), in_flight_cnt}); + + wait_expr = analyzer_.Simplify(wait_expr); + dep_local_state.pending_waits.push_back({static_cast(i), wait_expr}); } } @@ -630,7 +666,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The result loop. */ Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop, - bool need_bound_check) { + bool need_bound_check, bool is_epilogue = false) { PrimExpr new_loop_var; PrimExpr extent = end - start; auto make_nop = []() { @@ -642,7 +678,20 @@ class PipelineRewriter : public StmtExprMutator { new_loop_var = start; // use constants as the loop var for unit loops } else { new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); - analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); + // Bind the iteration domain [start, end) to strengthen analyzer facts. + analyzer_.Bind(Downcast(new_loop_var), + Range::FromMinExtent(start, end - start)); + } + // Keep the bound constraints active for all analysis below. + // Only meaningful when the loop var is symbolic (non-unit loop). + std::unique_ptr> ctx_lb_guard; + std::unique_ptr> ctx_ub_guard; + if (!is_unit_loop) { + Var loop_iter = Downcast(new_loop_var); + ctx_lb_guard.reset( + new With(&analyzer_, loop_iter >= start)); + ctx_ub_guard.reset( + new With(&analyzer_, loop_iter < end)); } std::vector new_blocks; @@ -653,15 +702,14 @@ class PipelineRewriter : public StmtExprMutator { for (const Block &block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; int order = pipeline_info_.at(block).order; + PrimExpr inbound = Bool(true); PrimExpr skewed_loop_var = new_loop_var - stage; if (need_bound_check) - inbound = - analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && - (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); - if (analyzer_.CanProve(!inbound)) { - continue; - } + inbound = And( + pipeline_loop_->min <= skewed_loop_var, + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent)); + Block new_block = Downcast( PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block)); @@ -674,6 +722,8 @@ class PipelineRewriter : public StmtExprMutator { PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; + normalized_access_index = analyzer_.Simplify(normalized_access_index); + // Adjust the block predicate and the body according to the final loop // bound // [pipeline_loop_->min, extent). @@ -701,17 +751,18 @@ class PipelineRewriter : public StmtExprMutator { if (pipeline_info_[block].async) { auto &local_state = async_states_local[stage]; local_state.producer_head = normalized_access_index; + local_state.commit_predicate = inbound; BlockNode *n = new_block.CopyOnWrite(); n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); } - new_blocks.push_back({stage, order, inbound, new_block, + new_blocks.push_back({stage, order, start, end, inbound, new_block, normalized_access_index, pipeline_info_[block].async}); } - PopulateWaitCounts(new_blocks, &async_states_local); + PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue); auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local); @@ -1008,7 +1059,8 @@ class PipelineInjector : private StmtExprMutator { pipeline_async_stages.find(stage) != pipeline_async_stages.end(); PipelineAnnotation stage_order{ stage, - /*order=*/static_cast(pipeline_orders[i]->value), is_async}; + /*order=*/static_cast(pipeline_orders[i]->value), is_async, + /*original_idx=*/static_cast(i)}; pipeline_info.emplace(original_order[i], stage_order); } diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index c0444043d..7cb1b5517 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -10,6 +10,7 @@ def _check(original, transformed): mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.Simplify()(mod) mod = tl.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index b688ad9fa..cd205a6db 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) + mod = tilelang.transform.Simplify()(mod) mod = tir.transform.NarrowDataType(32)(mod) mod = tilelang.transform.FlattenBuffer()(mod) # ConfigIndexBitwidth must be applied after FlattenBuffer