Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 3a32b7 to 90581f
192 changes: 122 additions & 70 deletions src/transform/inject_pipeline.cc
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file inject_pipeline.cc
* \brief Transform annotated loops into pipelined ones that parallelize
Expand Down Expand Up @@ -79,6 +60,8 @@ struct PipelineAnnotation {
int stage;
int order;
bool async;
// Index of the statement in the original loop body order (SeqStmt order)
int original_idx = -1;
};

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
Expand Down Expand Up @@ -304,15 +287,17 @@ class PipelineRewriter : public StmtExprMutator {
}

// Step 2: Emit the pipeline prologue, body and epilogue.
Stmt prologue = EmitImpl(pipeline_loop_->min,
pipeline_loop_->min + max_stage_, true, true);
Stmt body =
EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(
pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

Stmt prologue =
EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
true, false);
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false,
false, false);

Stmt epilogue =
EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
true, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue});

// Step 3: Make a new block that contains new buffer allocations after
Expand Down Expand Up @@ -515,69 +500,120 @@ class PipelineRewriter : public StmtExprMutator {
// A symbolic expression representing the index the latest async operation
// associated with this stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head;
// the commit block's predicate
PrimExpr commit_predicate{nullptr};
};

/*!
 * Structure holding intermediate information for pipeline loop rewriting.
 *
 * One record is appended per emitted block in EmitImpl via positional
 * brace-initialization, so the field order below must not be changed.
 */
struct RewrittenBlockInfo {
// Pipeline stage of the block, read from its pipeline annotation.
int stage;
// Position of the block in the pipeline order, read from the same annotation.
int order;
// `start` argument of the EmitImpl call that produced this record, i.e. the
// first iteration of the loop range being emitted.
PrimExpr start;
// `end` argument of that same EmitImpl call.
PrimExpr end;
// Bound-check condition (`inbound`) computed for this block in EmitImpl;
// it stays `true` when no bound check is requested.
PrimExpr predicate;
// The rewritten block that will be placed in the emitted loop.
Block block;
// Simplified, stage-skewed loop index of the block; PopulateWaitCounts reads
// it as the consumer position when computing wait counts.
PrimExpr access_index;
// Whether the block's pipeline annotation marks it as async.
bool is_async;
};

void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
std::map<int, AsyncStateLocal> *async_states_local) {
std::map<int, AsyncStateLocal> *async_states_local,
bool is_epilogue = false) {
// Precompute which orders are present in this emit, and their access_index
std::unordered_map<int, PrimExpr> order_to_access_index;
std::unordered_set<int> present_orders;
for (const auto &nb : new_blocks) {
order_to_access_index[nb.order] = nb.access_index;
present_orders.insert(nb.order);
}
for (size_t i = 0; i < new_blocks.size(); ++i) {
// 1. Find the unique async producer stage
int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) {
for (const auto &read_region : new_blocks[i].block->reads) {
for (const auto &[stage, state] : async_states) {
if (stage <= new_blocks[i].stage &&
state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was
// asynchronously written
// Currently only a single async stage dependency is supported
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported";
producer_stage_idx = stage;
}
}
}
if (producer_stage_idx == -1)
if (producer_stage_idx == -1) {
// This block does not depend on any async producer
continue;
}
const auto &state = async_states[producer_stage_idx];

auto &dep_local_state = (*async_states_local)[producer_stage_idx];
PrimExpr in_flight_cnt = 0;
for (const auto &group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
PrimExpr producer_head;
if (dep_local_state.producer_head.defined()) {
producer_head = dep_local_state.producer_head.value();
// if the group is after the wait point, minus by 1
if (group.front() > new_blocks[i].order)
producer_head -= 1;
} else {
producer_head = state.producer_head;
}
in_flight_cnt += producer_head - consumer_head;
}

// We can relax the in-flight-count by the number of independent commit.
// 2. Use buffer_to_commit_group_ to find all actually dependent commit
// groups
std::unordered_set<int> dependent_groups;
for (const auto &read_region : new_blocks[i].block->reads) {
if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
dependent_groups.insert(
state.buffer_to_commit_group_.at(read_region->buffer.get()));
auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
if (it != state.buffer_to_commit_group_.end()) {
dependent_groups.insert(it->second);
}
}
for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
if (dependent_groups.count(i) == 0)
in_flight_cnt += 1;
else
break; // stop relaxing

// If there is no dependent commit group, no wait needs to be inserted
if (dependent_groups.empty()) {
continue;
}

// 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
PrimExpr t_consumer = new_blocks[i].access_index;
PrimExpr wait_expr = make_zero(t_consumer.dtype());

PrimExpr current_head = dep_local_state.producer_head.defined()
? dep_local_state.producer_head.value()
: state.producer_head;
int consumer_order = new_blocks[i].order;

for (int g : dependent_groups) {
const auto &group = state.commit_groups[g];
if (group.empty())
continue;
int commit_order = group.back();
bool commit_present = present_orders.count(commit_order) > 0;

PrimExpr committed_before;
if (commit_present && commit_order <= consumer_order) {
// Commit point is in this iteration and earlier than the current
// consumer; this iteration's head is visible
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
// it means the commit block is not executed in this iteration
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = order_to_access_index.at(commit_order);
}
} else {
// Commit point is later than the current consumer or not in this
// iteration; only the previous iteration's head is visible
if (dep_local_state.producer_head.defined()) {
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = current_head - 1;
}
}
}

wait_expr = analyzer_.Simplify(committed_before - t_consumer);
}
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back(
{static_cast<int>(i), in_flight_cnt});

wait_expr = analyzer_.Simplify(wait_expr);
dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
}
}

Expand Down Expand Up @@ -630,7 +666,7 @@ class PipelineRewriter : public StmtExprMutator {
* \return The result loop.
*/
Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
bool need_bound_check) {
bool need_bound_check, bool is_epilogue = false) {
PrimExpr new_loop_var;
PrimExpr extent = end - start;
auto make_nop = []() {
Expand All @@ -642,7 +678,20 @@ class PipelineRewriter : public StmtExprMutator {
new_loop_var = start; // use constants as the loop var for unit loops
} else {
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
// Bind the iteration domain [start, end) to strengthen analyzer facts.
analyzer_.Bind(Downcast<Var>(new_loop_var),
Range::FromMinExtent(start, end - start));
}
// Keep the bound constraints active for all analysis below.
// Only meaningful when the loop var is symbolic (non-unit loop).
std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
if (!is_unit_loop) {
Var loop_iter = Downcast<Var>(new_loop_var);
ctx_lb_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
ctx_ub_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
}

std::vector<RewrittenBlockInfo> new_blocks;
Expand All @@ -653,15 +702,14 @@ class PipelineRewriter : public StmtExprMutator {
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order;

PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage;
if (need_bound_check)
inbound =
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
if (analyzer_.CanProve(!inbound)) {
continue;
}
inbound = And(
pipeline_loop_->min <= skewed_loop_var,
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));

Block new_block = Downcast<Block>(
PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
pipeline_loop_, max_stage_ != 1)(block));
Expand All @@ -674,6 +722,8 @@ class PipelineRewriter : public StmtExprMutator {
PrimExpr normalized_access_index =
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

normalized_access_index = analyzer_.Simplify(normalized_access_index);

// Adjust the block predicate and the body according to the final loop
// bound
// [pipeline_loop_->min, extent).
Expand Down Expand Up @@ -701,17 +751,18 @@ class PipelineRewriter : public StmtExprMutator {
if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
local_state.commit_predicate = inbound;
BlockNode *n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
1, n->body);
}

new_blocks.push_back({stage, order, inbound, new_block,
new_blocks.push_back({stage, order, start, end, inbound, new_block,
normalized_access_index,
pipeline_info_[block].async});
}

PopulateWaitCounts(new_blocks, &async_states_local);
PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);

auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);

Expand Down Expand Up @@ -1008,7 +1059,8 @@ class PipelineInjector : private StmtExprMutator {
pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{
stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
/*original_idx=*/static_cast<int>(i)};
pipeline_info.emplace(original_order[i], stage_order);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def _check(original, transformed):
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod)
mod = tl.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)

Expand Down
1 change: 1 addition & 0 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectFenceProxy()(mod)

mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer
Expand Down
Loading