Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Schedule][Bugfix] Fix decompose padding wrt the single child subtree #13646

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/tir/schedule/primitive/decompose_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {

// Step 3. Analyze in-bound write region.
PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
SetError("The in-bound predicate is trivial");
return false;
}
Array<Range> in_bound_region = this->EstimateInBoundRegion(
/*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
/*in_bound_predicate=*/in_bound_predicate);
Expand Down Expand Up @@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
analyzer.Bind(cur_loop->loop_var, range);
loops.push_back(cur_loop);

if (!found_const_filling_pos) {
if (cur_loop.same_as(const_filling_pos)) {
found_const_filling_pos = true;
if (cur_loop.same_as(const_filling_pos)) {
ICHECK(!found_const_filling_pos);
found_const_filling_pos = true;
if (!found_in_bound_filling_pos) {
found_in_bound_filling_pos = true;
in_bound_filling_pos = cur_loop;
}
}

if (!found_in_bound_filling_pos) {
} else if (!found_in_bound_filling_pos) {
if (!cur_loop->body->IsInstance<ForNode>() &&
!cur_loop->body->IsInstance<BlockRealizeNode>()) {
found_in_bound_filling_pos = true;
Expand Down
63 changes: 63 additions & 0 deletions tests/python/unittest/test_tir_schedule_decompose_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,5 +309,68 @@ def pooling_decompose_3(
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)


def test_decompose_wrt_single_child_subtree():
    """Decompose padding at a loop that sits inside a single-child loop subtree.

    ``pad_temp`` is decomposed at the third of its four loops and the scheduled
    module is compared against ``pad_op_after`` through ``check_decompose_padding``.
    """

    # Input: y (231x231 on the last two axes) takes x (225x225) shifted by 3 on
    # both of those axes wherever 3 <= index < 228, and int8 zero elsewhere.
    @T.prim_func
    def pad_op(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
    ):
        for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
            with T.block("pad_temp"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[ax0, ax1, ax2, ax3] = T.if_then_else(
                    3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
                    x[ax0, ax1, ax2 - 3, ax3 - 3],
                    T.int8(0),
                    dtype="int8",
                )

    # Expected: under the shared (i0, i1) loops, a constant-fill block covering
    # the full 231x231 extent followed by a copy block covering 225x225.
    @T.prim_func
    def pad_op_after(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
    ):
        for i0, i1 in T.grid(1, 16):
            for i2, i3 in T.grid(231, 231):
                with T.block("pad_temp_pad_const"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
                    y[ax0, ax1, ax2, ax3] = T.int8(0)
            for i2, i3 in T.grid(225, 225):
                with T.block("pad_temp"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
                    y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]

    schedule = tir.Schedule(pad_op, debug_mask="all")
    pad_block = schedule.get_block("pad_temp")
    # Only the third loop is needed; the unpacking still requires exactly four loops.
    _, _, h_loop, _ = schedule.get_loops(pad_block)
    schedule.decompose_padding(pad_block, h_loop)
    check_decompose_padding(pad_op, schedule.mod["main"], pad_op_after, check_run=True)


def test_not_to_decompose_trivial_predicate():
    """A padding condition that holds over the whole iteration space must not be decomposable.

    The ``if_then_else`` guard below (0 <= index < 225 on both trailing axes) covers
    the full 225x225 loop extent, so ``can_decompose_padding`` is expected to be falsy.
    """

    @T.prim_func
    def trivial_pad(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
    ):
        for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
            with T.block("pad_temp"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[ax0, ax1, ax2, ax3] = T.if_then_else(
                    0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
                    x[ax0, ax1, ax2, ax3],
                    T.int8(0),
                    dtype="int8",
                )

    schedule = tir.Schedule(trivial_pad, debug_mask="all")
    pad_block = schedule.get_block("pad_temp")
    # Query at the third loop; the unpacking still requires exactly four loops.
    _, _, h_loop, _ = schedule.get_loops(pad_block)
    assert not schedule.can_decompose_padding(pad_block, h_loop)


# Hand control to tvm.testing.main() when this file is executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()