-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] Improve div/mod in rewrite simplifier #3149
Changes from all commits
0251237
150b110
46a6ded
b665deb
17b47ac
f14aa9a
978ae69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) { | |
return kLT; | ||
} | ||
} | ||
if (val == 0) { | ||
ModularSet dmod = parent_->modular_set(diff); | ||
if (dmod->base != 0) { | ||
return kNE; | ||
} | ||
} | ||
ConstIntBound dbound = parent_->const_int_bound(diff); | ||
if (dbound->min_value > val) { | ||
return kGT; | ||
|
@@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) { | |
if (dbound->max_value <= val) { | ||
return kLE; | ||
} | ||
if (val == 0) { | ||
ModularSet dmod = parent_->modular_set(diff); | ||
if (dmod->base != 0) { | ||
return kNE; | ||
} | ||
} | ||
return kUnknown; | ||
} | ||
|
||
|
@@ -284,11 +284,39 @@ Mutate_(const Sub* op, const Expr& self) { | |
CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0)); | ||
|
||
// modular-div simplification | ||
// Always pre-condition on positive integer domain | ||
// Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division | ||
TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, | ||
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); | ||
c1.Eval()->value != 0); | ||
TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), | ||
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); | ||
c1.Eval()->value != 0); | ||
TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y, | ||
c1.Eval()->value != 0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we replace There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be. I'll fix it and add more tests. |
||
TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1), | ||
c1.Eval()->value != 0); | ||
TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, | ||
c1.Eval()->value != 0); | ||
TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y, | ||
c1.Eval()->value != 0); | ||
|
||
TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2, | ||
c1.Eval()->value != 0 && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sgrechanik-h please update the comment to mark all the cases that require special truc div rule. So we can be sure about this later
|
||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The rules are all correct to me. Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this particular example works thanks to this canonicalization rule. I've added a couple of tests to be sure. |
||
TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, | ||
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, | ||
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && | ||
|
@@ -348,6 +376,7 @@ Mutate_(const Mul* op, const Expr& self) { | |
|
||
// canonicalization | ||
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); | ||
TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); | ||
TVM_TRY_RECURSIVE_REWRITE_IF( | ||
(x - y) * c1, (y - x) * (0 - c1), | ||
c1.Eval()->value < 0); | ||
|
@@ -396,6 +425,16 @@ Mutate_(const Div* op, const Expr& self) { | |
// We adopt the default C division uses truncation instead of floordiv. | ||
// This means most rules need to check non-negativeness of the operands. | ||
|
||
// TryConstFold doesn't work for negative cases because it is also used by legacy | ||
// parts of tvm which still assume euclidean div. In this simplifier we assume that the division | ||
// is truncated, so perform const folding again. | ||
// NOTE: trunc div required | ||
if ((c1 / c2).Match(ret)) { | ||
int64_t c1val = c1.Eval()->value; | ||
int64_t c2val = c2.Eval()->value; | ||
return make_const(op->type, c1val / c2val); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
// while it is always true for trunc div | ||
// restrict to common case(positive div) | ||
TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), | ||
|
@@ -608,6 +647,12 @@ Mutate_(const Mod* op, const Expr& self) { | |
CanProveGreaterEqual(x.Eval(), 0) && | ||
CanProveGreaterEqual(y.Eval(), 0)); | ||
|
||
// canonicalization: x % c == x % (-c) for truncated division | ||
// NOTE: trunc div required | ||
TVM_TRY_RECURSIVE_REWRITE_IF(x % c1, | ||
x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)), | ||
c1.Eval()->value < 0); | ||
|
||
// try modular analysis | ||
if ((x % c1).Match(ret)) { | ||
ModularSet mod = parent_->modular_set(x.Eval()); | ||
|
@@ -1025,20 +1070,53 @@ Mutate_(const LT* op, const Expr& self) { | |
TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, | ||
c1.Eval()->value < 0); | ||
|
||
// require c1 > 0 to work for any div mode | ||
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, | ||
c1.Eval()->value > 0 && | ||
c2.Eval()->value > 0); | ||
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, | ||
// NOTE: trunc div required | ||
TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, | ||
c1.Eval()->value <= 0 && | ||
c2.Eval()->value > 0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is true if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can, since these rules are guarded by the condition |
||
// NOTE: trunc div required (euclidean is ok too, floored is not) | ||
TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x, | ||
c1.Eval()->value > 0 && | ||
c2.Eval()->value < 0); | ||
// NOTE: trunc div required (floored is ok too, euclidean is not) | ||
TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x, | ||
c1.Eval()->value <= 0 && | ||
c2.Eval()->value < 0); | ||
|
||
// NOTE: trunc div required | ||
TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x, | ||
c1.Eval()->value < 0 && | ||
c2.Eval()->value > 0); | ||
|
||
TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, | ||
c1.Eval()->value >= 0 && | ||
c2.Eval()->value > 0); | ||
// NOTE: trunc div required (floored is ok too, euclidean is not) | ||
TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1, | ||
c1.Eval()->value < 0 && | ||
c2.Eval()->value < 0); | ||
// NOTE: trunc div required (euclidean is ok too, floored is not) | ||
TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2, | ||
c1.Eval()->value >= 0 && | ||
c2.Eval()->value < 0); | ||
|
||
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, | ||
c1.Eval()->value > 0 && | ||
c2.Eval()->value > 0); | ||
// NOTE: trunc div required | ||
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1, | ||
c1.Eval()->value > 0 && | ||
c2.Eval()->value <= 0); | ||
|
||
TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, | ||
c1.Eval()->value >= 0 && | ||
c2.Eval()->value > 0); | ||
// NOTE: trunc div required | ||
TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x, | ||
c1.Eval()->value < 0 && | ||
c2.Eval()->value > 0); | ||
|
||
// division related simplificationx | ||
// invariance for any div mod: x - (x / c1) * c1 == x % c1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for enhancement , do we need to change dbound->max_value <= val into dbound->max_value == val? similar existing issue in line 90.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is really an enhancement, current formulation seems more straightforward, e.g. if max_value were an Expr, we could write
can_prove(max_value <= val)
whereascan_prove(max_value == val)
would be "less complete".