Skip to content

Commit f2858fa

Browse files
authored
[Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399)
* [Build] Update CMake configuration for tilelang_cython_wrapper installation - Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib. - Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules. - Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects. * [Build] Standardize output directories for tilelang libraries - Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds. - This change enhances organization and ensures that all build artifacts are located in a unified directory structure. * [Refactor] Update TVM subproject and enhance pipeline loop handling - Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0. - Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management. - Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts. - Simplified access index calculations and strengthened analyzer constraints for loop bounds. * [Cleanup] Remove license block and unused includes from inject_pipeline.cc - Eliminated the Apache license block from the top of the file to streamline the code. - Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies. * [Refactor] Enhance transformation pipeline and test execution - Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization. - Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation.
1 parent bc084aa commit f2858fa

File tree

4 files changed

+125
-71
lines changed

4 files changed

+125
-71
lines changed

3rdparty/tvm

Submodule tvm updated from 3a32b76 to 90581fe

src/transform/inject_pipeline.cc

Lines changed: 122 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,3 @@
1-
/*
2-
* Licensed to the Apache Software Foundation (ASF) under one
3-
* or more contributor license agreements. See the NOTICE file
4-
* distributed with this work for additional information
5-
* regarding copyright ownership. The ASF licenses this file
6-
* to you under the Apache License, Version 2.0 (the
7-
* "License"); you may not use this file except in compliance
8-
* with the License. You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing,
13-
* software distributed under the License is distributed on an
14-
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15-
* KIND, either express or implied. See the License for the
16-
* specific language governing permissions and limitations
17-
* under the License.
18-
*/
19-
201
/*!
212
* \file inject_software_pipeline.cc
223
* \brief Transform annotated loops into pipelined one that parallelize
@@ -79,6 +60,8 @@ struct PipelineAnnotation {
7960
int stage;
8061
int order;
8162
bool async;
63+
// Index of the statement in the original loop body order (SeqStmt order)
64+
int original_idx = -1;
8265
};
8366

8467
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
@@ -304,15 +287,17 @@ class PipelineRewriter : public StmtExprMutator {
304287
}
305288

306289
// Step 2: Emit the pipeline prologue, body and epilogue.
307-
Stmt prologue = EmitImpl(pipeline_loop_->min,
308-
pipeline_loop_->min + max_stage_, true, true);
309-
Stmt body =
310-
EmitImpl(pipeline_loop_->min + max_stage_,
311-
pipeline_loop_->min + pipeline_loop_->extent, false, false);
312-
Stmt epilogue = EmitImpl(
313-
pipeline_loop_->min + pipeline_loop_->extent,
314-
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
315-
290+
Stmt prologue =
291+
EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
292+
true, false);
293+
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
294+
pipeline_loop_->min + pipeline_loop_->extent, false,
295+
false, false);
296+
297+
Stmt epilogue =
298+
EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
299+
pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
300+
true, true, true);
316301
SeqStmt stmt = SeqStmt({prologue, body, epilogue});
317302

318303
// Step 3: Make a new block that contains new buffer allocations after
@@ -515,69 +500,120 @@ class PipelineRewriter : public StmtExprMutator {
515500
// A symbolic expression representing the index the latest async operation
516501
// associated with this stage has written into, at the "current" iteration.
517502
Optional<PrimExpr> producer_head;
503+
// the commit block's predicate
504+
PrimExpr commit_predicate{nullptr};
518505
};
519506

520507
/*! Structure holding intermediate information for pipeline loop rewriting. */
521508
struct RewrittenBlockInfo {
522509
int stage;
523510
int order;
511+
PrimExpr start;
512+
PrimExpr end;
524513
PrimExpr predicate;
525514
Block block;
526515
PrimExpr access_index;
527516
bool is_async;
528517
};
529518

530519
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
531-
std::map<int, AsyncStateLocal> *async_states_local) {
520+
std::map<int, AsyncStateLocal> *async_states_local,
521+
bool is_epilogue = false) {
522+
// Precompute which orders are present in this emit, and their access_index
523+
std::unordered_map<int, PrimExpr> order_to_access_index;
524+
std::unordered_set<int> present_orders;
525+
for (const auto &nb : new_blocks) {
526+
order_to_access_index[nb.order] = nb.access_index;
527+
present_orders.insert(nb.order);
528+
}
532529
for (size_t i = 0; i < new_blocks.size(); ++i) {
530+
// 1. Find the unique async producer stage
533531
int producer_stage_idx = -1;
534-
for (auto read_region : new_blocks[i].block->reads) {
532+
for (const auto &read_region : new_blocks[i].block->reads) {
535533
for (const auto &[stage, state] : async_states) {
536534
if (stage <= new_blocks[i].stage &&
537535
state.writes(read_region->buffer)) {
538-
// Found an earlier stage where read_region->buffer was
539-
// asynchronously written
536+
// Currently only a single async stage dependency is supported
540537
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
541538
<< "A dependency on multiple async stages is not supported";
542539
producer_stage_idx = stage;
543540
}
544541
}
545542
}
546-
if (producer_stage_idx == -1)
543+
if (producer_stage_idx == -1) {
544+
// This block does not depend on any async producer
547545
continue;
546+
}
548547
const auto &state = async_states[producer_stage_idx];
548+
549549
auto &dep_local_state = (*async_states_local)[producer_stage_idx];
550-
PrimExpr in_flight_cnt = 0;
551-
for (const auto &group : state.commit_groups) {
552-
PrimExpr consumer_head = new_blocks[i].access_index;
553-
PrimExpr producer_head;
554-
if (dep_local_state.producer_head.defined()) {
555-
producer_head = dep_local_state.producer_head.value();
556-
// if the group is after the wait point, minus by 1
557-
if (group.front() > new_blocks[i].order)
558-
producer_head -= 1;
559-
} else {
560-
producer_head = state.producer_head;
561-
}
562-
in_flight_cnt += producer_head - consumer_head;
563-
}
564550

565-
// We can relax the in-flight-count by the number of independent commit.
551+
// 2. Use buffer_to_commit_group_ to find all actually dependent commit
552+
// groups
566553
std::unordered_set<int> dependent_groups;
567554
for (const auto &read_region : new_blocks[i].block->reads) {
568-
if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
569-
dependent_groups.insert(
570-
state.buffer_to_commit_group_.at(read_region->buffer.get()));
555+
auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
556+
if (it != state.buffer_to_commit_group_.end()) {
557+
dependent_groups.insert(it->second);
558+
}
571559
}
572-
for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
573-
if (dependent_groups.count(i) == 0)
574-
in_flight_cnt += 1;
575-
else
576-
break; // stop relaxing
560+
561+
// If there is no dependent commit group, no wait needs to be inserted
562+
if (dependent_groups.empty()) {
563+
continue;
564+
}
565+
566+
// 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
567+
PrimExpr t_consumer = new_blocks[i].access_index;
568+
PrimExpr wait_expr = make_zero(t_consumer.dtype());
569+
570+
PrimExpr current_head = dep_local_state.producer_head.defined()
571+
? dep_local_state.producer_head.value()
572+
: state.producer_head;
573+
int consumer_order = new_blocks[i].order;
574+
575+
for (int g : dependent_groups) {
576+
const auto &group = state.commit_groups[g];
577+
if (group.empty())
578+
continue;
579+
int commit_order = group.back();
580+
bool commit_present = present_orders.count(commit_order) > 0;
581+
582+
PrimExpr committed_before;
583+
if (commit_present && commit_order <= consumer_order) {
584+
// Commit point is in this iteration and earlier than the current
585+
// consumer; this iteration's head is visible
586+
auto commit_predicate = dep_local_state.commit_predicate;
587+
if (analyzer_.CanProve(!commit_predicate,
588+
arith::ProofStrength::kSymbolicBound)) {
589+
// it means the commit block is not executed in this iteration
590+
committed_before = new_blocks[i].start - 1;
591+
} else if (is_epilogue) {
592+
committed_before = new_blocks[i].start - 1;
593+
} else {
594+
committed_before = order_to_access_index.at(commit_order);
595+
}
596+
} else {
597+
// Commit point is later than the current consumer or not in this
598+
// iteration; only the previous iteration's head is visible
599+
if (dep_local_state.producer_head.defined()) {
600+
auto commit_predicate = dep_local_state.commit_predicate;
601+
if (analyzer_.CanProve(!commit_predicate,
602+
arith::ProofStrength::kSymbolicBound)) {
603+
committed_before = new_blocks[i].start - 1;
604+
} else if (is_epilogue) {
605+
committed_before = new_blocks[i].start - 1;
606+
} else {
607+
committed_before = current_head - 1;
608+
}
609+
}
610+
}
611+
612+
wait_expr = analyzer_.Simplify(committed_before - t_consumer);
577613
}
578-
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
579-
dep_local_state.pending_waits.push_back(
580-
{static_cast<int>(i), in_flight_cnt});
614+
615+
wait_expr = analyzer_.Simplify(wait_expr);
616+
dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
581617
}
582618
}
583619

@@ -630,7 +666,7 @@ class PipelineRewriter : public StmtExprMutator {
630666
* \return The result loop.
631667
*/
632668
Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
633-
bool need_bound_check) {
669+
bool need_bound_check, bool is_epilogue = false) {
634670
PrimExpr new_loop_var;
635671
PrimExpr extent = end - start;
636672
auto make_nop = []() {
@@ -642,7 +678,20 @@ class PipelineRewriter : public StmtExprMutator {
642678
new_loop_var = start; // use constants as the loop var for unit loops
643679
} else {
644680
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
645-
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
681+
// Bind the iteration domain [start, end) to strengthen analyzer facts.
682+
analyzer_.Bind(Downcast<Var>(new_loop_var),
683+
Range::FromMinExtent(start, end - start));
684+
}
685+
// Keep the bound constraints active for all analysis below.
686+
// Only meaningful when the loop var is symbolic (non-unit loop).
687+
std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
688+
std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
689+
if (!is_unit_loop) {
690+
Var loop_iter = Downcast<Var>(new_loop_var);
691+
ctx_lb_guard.reset(
692+
new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
693+
ctx_ub_guard.reset(
694+
new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
646695
}
647696

648697
std::vector<RewrittenBlockInfo> new_blocks;
@@ -653,15 +702,14 @@ class PipelineRewriter : public StmtExprMutator {
653702
for (const Block &block : ordered_stmts_) {
654703
int stage = pipeline_info_.at(block).stage;
655704
int order = pipeline_info_.at(block).order;
705+
656706
PrimExpr inbound = Bool(true);
657707
PrimExpr skewed_loop_var = new_loop_var - stage;
658708
if (need_bound_check)
659-
inbound =
660-
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
661-
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
662-
if (analyzer_.CanProve(!inbound)) {
663-
continue;
664-
}
709+
inbound = And(
710+
pipeline_loop_->min <= skewed_loop_var,
711+
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));
712+
665713
Block new_block = Downcast<Block>(
666714
PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
667715
pipeline_loop_, max_stage_ != 1)(block));
@@ -674,6 +722,8 @@ class PipelineRewriter : public StmtExprMutator {
674722
PrimExpr normalized_access_index =
675723
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
676724

725+
normalized_access_index = analyzer_.Simplify(normalized_access_index);
726+
677727
// Adjust the block predicate and the body according to the final loop
678728
// bound
679729
// [pipeline_loop_->min, extent).
@@ -701,17 +751,18 @@ class PipelineRewriter : public StmtExprMutator {
701751
if (pipeline_info_[block].async) {
702752
auto &local_state = async_states_local[stage];
703753
local_state.producer_head = normalized_access_index;
754+
local_state.commit_predicate = inbound;
704755
BlockNode *n = new_block.CopyOnWrite();
705756
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
706757
1, n->body);
707758
}
708759

709-
new_blocks.push_back({stage, order, inbound, new_block,
760+
new_blocks.push_back({stage, order, start, end, inbound, new_block,
710761
normalized_access_index,
711762
pipeline_info_[block].async});
712763
}
713764

714-
PopulateWaitCounts(new_blocks, &async_states_local);
765+
PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);
715766

716767
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
717768

@@ -1008,7 +1059,8 @@ class PipelineInjector : private StmtExprMutator {
10081059
pipeline_async_stages.find(stage) != pipeline_async_stages.end();
10091060
PipelineAnnotation stage_order{
10101061
stage,
1011-
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
1062+
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
1063+
/*original_idx=*/static_cast<int>(i)};
10121064
pipeline_info.emplace(original_order[i], stage_order);
10131065
}
10141066

testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def _check(original, transformed):
1010
mod = tl.transform.InjectSoftwarePipeline()(mod)
1111
mod = tl.transform.Simplify()(mod)
1212
mod = tl.transform.LowerOpaqueBlock()(mod)
13+
mod = tl.transform.Simplify()(mod)
1314
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
1415
True)
1516

tilelang/engine/phase.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
217217
mod = tilelang.transform.InjectFenceProxy()(mod)
218218

219219
mod = tilelang.transform.LowerOpaqueBlock()(mod)
220+
mod = tilelang.transform.Simplify()(mod)
220221
mod = tir.transform.NarrowDataType(32)(mod)
221222
mod = tilelang.transform.FlattenBuffer()(mod)
222223
# ConfigIndexBitwidth must be applied after FlattenBuffer

0 commit comments

Comments
 (0)