Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broken loop partitioning due to recent changes. #4243

Merged
merged 1 commit into from
Nov 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
Expr cond = (body_begin - min >= 0);
Copy link
Contributor

@umangyadav umangyadav Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel
we can also think about changing this condition to body_begin - min > 0 instead of using >= that can also prevent recursing on zero extent loops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@umangyadav, we can do this one, however we still generate zero extent loop. Just that we dont recursively partition anything generated for zero extent loop.
I still prefer to generate cleaner output that has no such zero extent loops that need fixing afterwards, unless somehow we fix it in the loop partition itself. This can be done as you suggested via RemoveNoOp however it seems little convoluted way to go about this. I would rather have loop partition generate clean IR with valid loop iterations.
Is there a reason why you would prefer to generate zero extent loop?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel There is no specific reason or preference for generating the zero extent loops.

My point is that there is always possibility of having the zero extent loops.

We need to make sure not to recurse on those loops by having some mechanism e.g either guard conditions, RemoveNoOp etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@umangyadav, I understand. Do you have specific example where we can still generate zero extent loops despite this PR? Then it will be easier to see what would be a better fix.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@umangyadav I agree we might need a cleanup pass. But if we can avoid generating zero extent loops, it's desirable.

if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
}
} else {
body_begin = min;
Expand All @@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
} else {
post_doubt_begin = max + 1;
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,27 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))


def test_multilevel_splitting_with_indivisble_factors():
import topi
A = tvm.placeholder((130,), dtype="float32")
B = topi.nn.relu(A)
s = tvm.create_schedule(B.op)
(y,) = s[B].op.axis
(yo, yi) = s[B].split(y, factor=8)
(yoo, yoi) = s[B].split(yo, factor=16)
s[B].reorder(yoo, yoi, yi)
s[B].unroll(yi)

## But this does the right thing.
with tvm.build_config(partition_const_loop=True):
lowered_body = tvm.lower(s, [A, B]).body
def visit_stmt(op):
return(isinstance(op, tvm.expr.Max))
num_max = collect_visit(lowered_body, visit_stmt)
assert num_max.count(True) == 10


def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
Expand Down Expand Up @@ -443,4 +464,5 @@ def test_simple_rfactor():
test_cce_loop_3()
test_conv_tiling()
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()