Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Check producer predicate in ReverseComputeInline #13338

Merged
merged 5 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,32 @@ class OpaqueAccessError : public ScheduleError {
Block scope_root_;
};

class ProducerHasNonTrivialPredicateError : public ScheduleError {
public:
explicit ProducerHasNonTrivialPredicateError(IRModule mod, BlockRealize producer,
PrimExpr new_predicate)
: mod_(mod), producer_(producer), new_predicate_(new_predicate) {}

String FastErrorString() const final {
return "ScheduleError: The producer block has a non-trivial predicate.";
}

String DetailRenderTemplate() const final {
return "ScheduleError: The producer block {0} has a non-trivial predicate " +
PrettyPrint(producer_->predicate) +
" that cannot be implied "
"by the synthesized predicate " +
PrettyPrint(new_predicate_) + " of the new inlined block.";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }

IRModule mod_;
BlockRealize producer_;
PrimExpr new_predicate_;
};

/*!
* \brief The base class of the inliner, which handles:
* 1) Substitute a subtree with the specific block being inlined
Expand Down Expand Up @@ -533,10 +559,11 @@ class ReverseComputeInliner : public BaseInliner {
public:
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
const BlockRealize& consumer_block_realize,
const StmtSRef& scope_root_sref)
const StmtSRef& scope_root_sref, const IRModule& mod)
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block),
consumer_block_(consumer_block_realize->block.get()) {
consumer_block_(consumer_block_realize->block.get()),
mod_(mod) {
// Initialize the predicates to ensure consumer block iters are in-bound
consumer_iter_in_bound_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
Expand Down Expand Up @@ -632,8 +659,15 @@ class ReverseComputeInliner : public BaseInliner {
Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (op->block.get() == producer_block_) {
new_block_realize.CopyOnWrite()->predicate =
BuildInlinedConsumerPredicate(new_block_realize.get());
auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get());

With<arith::ConstraintContext> ctx(&analyzer_, new_predicate);
if (!analyzer_.CanProve(op->predicate)) {
// We do not allow cases where the new predicate for the inlined block cannot
// imply the original predicate in the producer block.
throw ProducerHasNonTrivialPredicateError(mod_, GetRef<BlockRealize>(op), new_predicate);
}
new_block_realize.CopyOnWrite()->predicate = new_predicate;
}
return std::move(new_block_realize);
}
Expand Down Expand Up @@ -749,6 +783,8 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr consumer_iter_in_bound_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
/*! \brief The target module, only used for error reporting. */
const IRModule& mod_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Expand Down Expand Up @@ -814,7 +850,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
// Step 4. Analyze the block body
ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
consumer_block_realize, scope_root_sref);
consumer_block_realize, scope_root_sref, self->mod);
if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
throw BodyAnalysisError(true, self->mod, consumer_block);
}
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_meta_schedule_trace_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3037,8 +3037,8 @@ def test_inline_order():
# reverse-inlined at the very end of ScheduleUsingAnchorTrace, where its producer block
# "conv2d_nhwc_reindex_shared" has the predicate
# T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) due to anchor-block scheduling
# (see Conv2dInt8_with_predicate_scheduled). Currently, if we try to reverse-inline a block to
# its producer that has a predicate, the predicate disappears after reverse inlining.
# (see Conv2dInt8_with_predicate_scheduled). ReverseComputeInline cannot be applied in
# such cases.

def apply_trace(sch: Schedule) -> None:
b0 = sch.get_block(name="pad_temp", func_name="main")
Expand Down
Loading