diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 76313f46d1c8..446c8ead7e8e 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -104,7 +104,13 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
   }
   // Cond 2. For a block that generates a constant tensor, ignore all other conditions
   if (inline_const_tensor && block->reads.empty()) {
-    return InlineType::kInlineIntoConsumer;
+    // A constant block can only be inlined if something consumes it and the
+    // schedule can compute-inline it; otherwise (e.g. it writes the function
+    // output directly) fall through to the remaining conditions.
+    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
+    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
+      return InlineType::kInlineIntoConsumer;
+    }
   }
   // Cond 3. The block doesn't contain any disallowed operators
   if (!is_pure_sptial && !disallow_op.empty() && HasOp(realize, disallow_op)) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index a8ffa6ff9d3f..fcf6a8571b7f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -320,6 +320,21 @@ def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: T.Buffer[(3052
                 T.writes(T_add[ax0, ax1, ax2])
                 T_add[ax0, ax1, ax2] = placeholder_1[T.min(T.max(T.int64(0), T.Select(T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0, placeholder[ax0, ax1] + T.int64(30522), placeholder[ax0, ax1])), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
 
+@tvm.script.ir_module
+class ConstConsumer:
+    @T.prim_func
+    def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2 in T.grid(1, 12, 4096):
+            with T.block("T_full"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads()
+                T.writes(T_full[ax0, ax1, ax2])
+                T_full[ax0, ax1, ax2] = T.int64(0)
+
 # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -383,8 +398,22 @@ def test_inline_pure_spatial():
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)
 
 
+def test_inline_constant_tensor():
+    """A constant-tensor block without consumers must be left as is by auto_inline."""
+    mod = ConstConsumer
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        mod=mod,
+        target=target,
+        rule=auto_inline(target=target),
+    )
+    (space,) = ctx.space_generator.generate_design_space(mod=mod)
+    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)
+
+
 if __name__ == "__main__":
     test_inline_consumer_chain()
     test_inline_into_cache()
     test_inline_into_multiple_consumers()
     test_inline_pure_spatial()
+    test_inline_constant_tensor()