From a8157e26e8d2449de82c37634db15059e2387d0f Mon Sep 17 00:00:00 2001
From: wrongtest
Date: Tue, 20 Dec 2022 14:44:47 +0800
Subject: [PATCH] Fix bug when decompose padding wrt the single child subtree

---
Reviewer notes (this section sits after the `---` separator and is not part of
the commit message):

* decompose_padding.cc, hunk 1 (PaddingInfoAnalyzer): after the in-bound
  predicate is rewritten, the analysis now records the error
  "The in-bound predicate is trivial" and returns false when the analyzer can
  prove the predicate equal to 1.
* decompose_padding.cc, hunk 2 (DecomposePaddingImpl): when the current loop is
  `const_filling_pos`, the code checks that this position has not been seen
  before, marks it as found, and - if no in-bound filling position has been
  found yet - also uses the current loop as `in_bound_filling_pos`. The
  previous stand-alone `!found_in_bound_filling_pos` check becomes the
  `else if` branch.
* test_tir_schedule_decompose_padding.py: adds
  `test_decompose_wrt_single_child_subtree` (decompose at loop `h` and compare
  against `pad_op_after`) and `test_not_to_decompose_trivial_predicate`
  (asserts `can_decompose_padding` is false for a predicate covering the whole
  225x225 extent).

 .../schedule/primitive/decompose_padding.cc   | 17 +++--
 .../test_tir_schedule_decompose_padding.py    | 63 +++++++++++++++++++
 2 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index c41760876722..e657b4f4663d 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {
 
     // Step 3. Analyze in-bound write region.
     PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
+    if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
+      SetError("The in-bound predicate is trivial");
+      return false;
+    }
     Array<Range> in_bound_region = this->EstimateInBoundRegion(
         /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
         /*in_bound_predicate=*/in_bound_predicate);
@@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
     analyzer.Bind(cur_loop->loop_var, range);
     loops.push_back(cur_loop);
 
-    if (!found_const_filling_pos) {
-      if (cur_loop.same_as(const_filling_pos)) {
-        found_const_filling_pos = true;
+    if (cur_loop.same_as(const_filling_pos)) {
+      ICHECK(!found_const_filling_pos);
+      found_const_filling_pos = true;
+      if (!found_in_bound_filling_pos) {
+        found_in_bound_filling_pos = true;
+        in_bound_filling_pos = cur_loop;
       }
-    }
-
-    if (!found_in_bound_filling_pos) {
+    } else if (!found_in_bound_filling_pos) {
       if (!cur_loop->body->IsInstance<ForNode>() &&
           !cur_loop->body->IsInstance<BlockRealizeNode>()) {
         found_in_bound_filling_pos = true;
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
index a3fc4326a3c9..ead8b0b33262 100644
--- a/tests/python/unittest/test_tir_schedule_decompose_padding.py
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -309,5 +309,68 @@ def pooling_decompose_3(
     check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
 
 
+def test_decompose_wrt_single_child_subtree():
+    """Test the case when the decompose position is under the single child subtree"""
+
+    @T.prim_func
+    def pad_op(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
+    ):
+        for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
+            with T.block("pad_temp"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                y[ax0, ax1, ax2, ax3] = T.if_then_else(
+                    3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
+                    x[ax0, ax1, ax2 - 3, ax3 - 3],
+                    T.int8(0),
+                    dtype="int8",
+                )
+
+    @T.prim_func
+    def pad_op_after(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
+    ):
+        for i0, i1 in T.grid(1, 16):
+            for i2, i3 in T.grid(231, 231):
+                with T.block("pad_temp_pad_const"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+                    y[ax0, ax1, ax2, ax3] = T.int8(0)
+            for i2, i3 in T.grid(225, 225):
+                with T.block("pad_temp"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+                    y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
+
+    sch = tir.Schedule(pad_op, debug_mask="all")
+    pad = sch.get_block("pad_temp")
+    _, _, h, _ = sch.get_loops(pad)
+    sch.decompose_padding(pad, h)
+    check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)
+
+
+def test_not_to_decompose_trivial_predicate():
+    """Test the case when the padding condition is trivial"""
+
+    @T.prim_func
+    def trivial_pad(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
+    ):
+        for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
+            with T.block("pad_temp"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                y[ax0, ax1, ax2, ax3] = T.if_then_else(
+                    0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
+                    x[ax0, ax1, ax2, ax3],
+                    T.int8(0),
+                    dtype="int8",
+                )
+
+    sch = tir.Schedule(trivial_pad, debug_mask="all")
+    pad = sch.get_block("pad_temp")
+    _, _, h, _ = sch.get_loops(pad)
+    assert not sch.can_decompose_padding(pad, h)
+
+
 if __name__ == "__main__":
     tvm.testing.main()