Skip to content

Commit

Permalink
Fix bug when decompose padding wrt the single child subtree
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Dec 20, 2022

Verified

This commit was signed with the committer’s verified signature.
mrgrain Momo Kornher
1 parent 6161a8d commit 7da3a79
Showing 2 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/tir/schedule/primitive/decompose_padding.cc
Original file line number Diff line number Diff line change
@@ -442,6 +442,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
if (!found_const_filling_pos) {
if (cur_loop.same_as(const_filling_pos)) {
found_const_filling_pos = true;
found_in_bound_filling_pos = true;
}
}

43 changes: 43 additions & 0 deletions tests/python/unittest/test_tir_schedule_decompose_padding.py
Original file line number Diff line number Diff line change
@@ -309,5 +309,48 @@ def pooling_decompose_3(
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)


def test_decompose_wrt_single_child_subtree():
"""Test the case when the decompose position is under the same subtree"""

@T.prim_func
def pad_op(
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
):
for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
with T.block("pad_temp"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
y[ax0, ax1, ax2, ax3] = T.if_then_else(
3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
x[ax0, ax1, ax2 - 3, ax3 - 3],
T.int8(0),
dtype="int8",
)

@T.prim_func
def pad_op_after(
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
):
for i0, i1 in T.grid(1, 16):
for i2, i3 in T.grid(231, 231):
with T.block("pad_temp_pad_const"):
ax0 = T.axis.spatial(1, 0)
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
y[ax0, ax1, ax2, ax3] = T.int8(0)
for i2, i3 in T.grid(231, 225):
with T.block("pad_temp"):
T.where(3 <= i2 and i2 < 228)
ax0 = T.axis.spatial(1, 0)
ax1 = T.axis.spatial(16, i1)
ax2 = T.axis.spatial(225, i2 - 3)
ax3 = T.axis.spatial(225, i3)
y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]

sch = tir.Schedule(pad_op, debug_mask="all")
pad = sch.get_block("pad_temp")
_, _, h, _ = sch.get_loops(pad)
sch.decompose_padding(pad, h)
check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 7da3a79

Please sign in to comment.