Skip to content

Commit

Permalink
[MetaSchedule] Fix autoinline for single const consumer block (#12668)
Browse files Browse the repository at this point in the history
fix autoinline and add test
  • Loading branch information
Yuanjing Shi authored Sep 1, 2022
1 parent 038f15b commit 32f9a5f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
}
// Cond 2. For a block that generates a constant tensor, ignore all other conditions
if (inline_const_tensor && block->reads.empty()) {
return InlineType::kInlineIntoConsumer;
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
return InlineType::kInlineIntoConsumer;
}
}
// Cond 3. The block doesn't contain any disallowed operators
if (!is_pure_sptial && !disallow_op.empty() && HasOp(realize, disallow_op)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,21 @@ def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: T.Buffer[(3052
T.writes(T_add[ax0, ax1, ax2])
T_add[ax0, ax1, ax2] = placeholder_1[T.min(T.max(T.int64(0), T.Select(T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0, placeholder[ax0, ax1] + T.int64(30522), placeholder[ax0, ax1])), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]

# Fixture module: a single PrimFunc whose only block fills the output buffer
# with a constant. The block declares no reads, and no other block in this
# module consumes T_full. Used by test_inline_constant_tensor as both the
# input module and the expected result.
@tvm.script.ir_module
class ConstConsumer:
    @T.prim_func
    def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(1, 12, 4096):
            with T.block("T_full"):
                # All three axes are bound as spatial ("S") axes.
                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                # Empty read region: the stored value depends on no input buffer.
                T.reads()
                T.writes(T_full[ax0, ax1, ax2])
                # Every element is set to the int64 constant 0.
                T_full[ax0, ax1, ax2] = T.int64(0)

# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

Expand Down Expand Up @@ -383,8 +398,21 @@ def test_inline_pure_spatial():
tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)


def test_inline_constant_tensor():
    """Check that auto_inline leaves ``ConstConsumer`` structurally unchanged.

    The design space generated for the CUDA target must contain exactly one
    schedule, and that schedule's module must be structurally equal to the
    input module.
    """
    cuda_target = Target("cuda", host="llvm")
    tune_ctx = _create_context(
        mod=ConstConsumer,
        target=cuda_target,
        rule=auto_inline(target=cuda_target),
    )
    design_spaces = tune_ctx.space_generator.generate_design_space(mod=ConstConsumer)
    # Pattern-unpack so that anything other than exactly one design space fails.
    [only_space] = design_spaces
    tvm.ir.assert_structural_equal(lhs=only_space.mod, rhs=ConstConsumer)


if __name__ == "__main__":
    # Run every test in this file, in declaration order, when executed as a script.
    for _run_test in (
        test_inline_consumer_chain,
        test_inline_into_cache,
        test_inline_into_multiple_consumers,
        test_inline_pure_spatial,
        test_inline_constant_tensor,
    ):
        _run_test()

0 comments on commit 32f9a5f

Please sign in to comment.