Skip to content

Commit

Permalink
This fixes a specific case when loop partitioning with indivisble
Browse files Browse the repository at this point in the history
factors and resulting nested loop is broken.
This is due to the fact that we are creating zero extent loops which
are fixed afterwards. However unroll pass breaks due to the zero extent
loop.
  • Loading branch information
kimishpatel committed Nov 14, 2019
1 parent a3ca1a4 commit 8627ad8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
52 changes: 28 additions & 24 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
}
} else {
body_begin = min;
Expand All @@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
} else {
post_doubt_begin = max + 1;
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,27 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))


def test_multilevel_splitting_with_indivisble_factors():
import topi
A = tvm.placeholder((130,), dtype="float32")
B = topi.nn.relu(A)
s = tvm.create_schedule(B.op)
(y,) = s[B].op.axis
(yo, yi) = s[B].split(y, factor=8)
(yoo, yoi) = s[B].split(yo, factor=16)
s[B].reorder(yoo, yoi, yi)
s[B].unroll(yi)

## But this does the right thing.
with tvm.build_config(partition_const_loop=True):
lowered_body = tvm.lower(s, [A, B]).body
def visit_stmt(op):
return(isinstance(op, tvm.expr.Max))
num_max = collect_visit(lowered_body, visit_stmt)
assert num_max.count(True) == 10


def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
Expand Down Expand Up @@ -443,4 +464,5 @@ def test_simple_rfactor():
test_cce_loop_3()
test_conv_tiling()
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()

0 comments on commit 8627ad8

Please sign in to comment.