Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Improve div/mod in rewrite simplifier #3149

Merged
merged 7 commits into from
May 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/expr_operator.h>
#include <tvm/ir_functor_ext.h>
#include <limits>
#include <utility>
#include <unordered_map>
#include "pattern_match.h"

namespace tvm {
Expand Down
102 changes: 90 additions & 12 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) {
return kLT;
}
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
}
ConstIntBound dbound = parent_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
Expand All @@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) {
if (dbound->max_value <= val) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for enhancement , do we need to change dbound->max_value <= val into dbound->max_value == val? similar existing issue in line 90.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is really an enhancement, current formulation seems more straightforward, e.g. if max_value were an Expr, we could write can_prove(max_value <= val) whereas can_prove(max_value == val) would be "less complete".

return kLE;
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
}
return kUnknown;
}

Expand Down Expand Up @@ -284,11 +284,39 @@ Mutate_(const Sub* op, const Expr& self) {
CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));

// modular-div simplification
// Always pre-condition on positive integer domain
// Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division
TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y,
c1.Eval()->value != 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we replace c2 by y (a non-const), would the rule still be correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be. I'll fix it and add more tests.

TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1),
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y,
c1.Eval()->value != 0);

TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2,
c1.Eval()->value != 0 &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgrechanik-h please update the comment to mark all the cases that require special truc div rule. So we can be sure about this later

// NOTE: trunc div required

c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rules are all correct to me. Thanks.
What happens if we have something like c2 * x - c3 * (x / c1) ( I used commutativity) ? Does the simplifier automatically use TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); in its attempt to simplify it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this particular example works thanks to this canonicalization rule. I've added a couple of tests to be sure.

TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
Expand Down Expand Up @@ -348,6 +376,7 @@ Mutate_(const Mul* op, const Expr& self) {

// canonicalization
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
TVM_TRY_RECURSIVE_REWRITE_IF(
(x - y) * c1, (y - x) * (0 - c1),
c1.Eval()->value < 0);
Expand Down Expand Up @@ -396,6 +425,16 @@ Mutate_(const Div* op, const Expr& self) {
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.

// TryConstFold doesn't work for negative cases because it is also used by legacy
// parts of tvm which still assume euclidean div. In this simplifier we assume that the division
// is truncated, so perform const folding again.
// NOTE: trunc div required
if ((c1 / c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
return make_const(op->type, c1val / c2val);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// while it is always true for trunc div
// restrict to common case(positive div)
TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2),
Expand Down Expand Up @@ -608,6 +647,12 @@ Mutate_(const Mod* op, const Expr& self) {
CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0));

// canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required
TVM_TRY_RECURSIVE_REWRITE_IF(x % c1,
x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)),
c1.Eval()->value < 0);

// try modular analysis
if ((x % c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval());
Expand Down Expand Up @@ -1025,20 +1070,53 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
c1.Eval()->value < 0);

// require c1 > 0 to work for any div mode
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2,
c1.Eval()->value <= 0 &&
c2.Eval()->value > 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true if x is an integer. Can we assume that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can, since these rules are guarded by the condition IsIndexType(op->a.type()), and also all constants are declared as integers (PVar<Integer>).

// NOTE: trunc div required (euclidean is ok too, floored is not)
TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x,
c1.Eval()->value > 0 &&
c2.Eval()->value < 0);
// NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x,
c1.Eval()->value <= 0 &&
c2.Eval()->value < 0);

// NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x,
c1.Eval()->value < 0 &&
c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
// NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1,
c1.Eval()->value < 0 &&
c2.Eval()->value < 0);
// NOTE: trunc div required (euclidean is ok too, floored is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2,
c1.Eval()->value >= 0 &&
c2.Eval()->value < 0);

TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value <= 0);

TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x,
c1.Eval()->value < 0 &&
c2.Eval()->value > 0);

// division related simplificationx
// invariance for any div mod: x - (x / c1) * c1 == x % c1
Expand Down
69 changes: 69 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,25 @@ def test_sub_index_simplify():
ck.verify(x - (x / 3) * 3, x % 3)
ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)

ck.verify(y - (y / (-5)) * (-5), y % 5)
ck.verify((y / 3) * 3 - y, 0 - y % 3)
ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6)
ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5)
ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z)
ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5)
ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z)
ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z)

ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3)
ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2))
ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5)
ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2)
ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5)
ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)

def test_mul_index_simplify():
ck = RewriteChecker()
Expand Down Expand Up @@ -292,6 +311,11 @@ def test_mod_index_simplify():
ck.verify((x + 10) % 2, x % 2)
ck.verify((x + y * 10) % 2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
ck.verify(x * 10 % -2, 0)
ck.verify((x * 10 + y) % -2, y % 2)
ck.verify((x + 10) % -2, x % 2)
ck.verify((x + y * 10) % -2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)


def test_min_index_simplify():
Expand Down Expand Up @@ -449,6 +473,50 @@ def test_cmp_simplify():
ck.verify(x / 2 < 3, x < 6)
ck.verify(x * 4 <= 2, x <= 0)
ck.verify(3 < x / 2, tvm.expr.LT(7, x))
ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
ck.verify(2 * x <= 0, x <= 0)

ck.verify(x * 2 >= 3, tvm.expr.LE(2, x))
ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 1, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 0, tvm.expr.LE(0, x))
ck.verify(x * 2 >= -1, tvm.expr.LE(0, x))
ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x))
ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x))

ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1))
ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1))
ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0))
ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0))
ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1))
ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1))
ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2))

ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2))
ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1))
ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1))
ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 0))
ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0))
ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1))
ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1))

ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x))
ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x))
ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x))
ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x))
ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x))
ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))

ck.verify(x / 2 >= 1, tvm.expr.LE(2, x))
ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x))
ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x))

ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3))
ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1))
ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2))

ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
Expand Down Expand Up @@ -480,6 +548,7 @@ def test_cmp_simplify():
ck.verify(x*y <= 0, tvm.const(1, "bool"))
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool"))
ck.verify(x*6 <= -3, tvm.const(0, "bool"))


def test_logical_simplify():
Expand Down