Skip to content

Commit

Permalink
Replace gcd factoring with specialized rules
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed May 17, 2019
1 parent 4c17256 commit 98d0663
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 88 deletions.
40 changes: 19 additions & 21 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <utility>
#include <unordered_map>
#include "pattern_match.h"
#include "modular_set.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -338,6 +337,25 @@ class ModularSetAnalyzer::Impl :
return Nothing();
}
}
/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
/*!
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
Expand Down Expand Up @@ -377,25 +395,5 @@ ModularSetAnalyzer::~ModularSetAnalyzer() {
delete impl_;
}

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}

} // namespace arith
} // namespace tvm
43 changes: 0 additions & 43 deletions src/arithmetic/modular_set.h

This file was deleted.

43 changes: 21 additions & 22 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <algorithm>
#include "const_fold.h"
#include "pattern_match.h"
#include "modular_set.h"
#include "rewrite_simplify.h"

namespace tvm {
Expand Down Expand Up @@ -176,16 +175,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
TVM_TRY_REWRITE(y * x + z * x, x * (y + z));
// Factor out gcd
if ((x * c1 + y * c2).Match(ret)) {
auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value);
if (gcd != 1) {
auto b1 = PConstWithTypeLike<PVar<Expr>>(x, c1.Eval()->value / gcd);
auto b2 = PConstWithTypeLike<PVar<Expr>>(x, c2.Eval()->value / gcd);
auto pgcd = PConstWithTypeLike<PVar<Expr>>(x, gcd);
return ((x * b1 + y * b2) * pgcd).Eval();
}
}

// modular-div simplification
// Always pre-condition on positive integer domain
Expand Down Expand Up @@ -260,16 +249,6 @@ Mutate_(const Sub* op, const Expr& self) {
TVM_TRY_REWRITE(y * x - x * z, x * (y - z));
TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
TVM_TRY_REWRITE(y * x - z * x, x * (y - z));
// Factor out gcd
if ((x * c1 - y * c2).Match(ret)) {
auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value);
if (gcd != 1) {
auto b1 = PConstWithTypeLike<PVar<Expr>>(x, c1.Eval()->value / gcd);
auto b2 = PConstWithTypeLike<PVar<Expr>>(x, c2.Eval()->value / gcd);
auto pgcd = PConstWithTypeLike<PVar<Expr>>(x, gcd);
return ((x * b1 - y * b2) * pgcd).Eval();
}
}

// constant cancelation
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
Expand Down Expand Up @@ -316,8 +295,28 @@ Mutate_(const Sub* op, const Expr& self) {
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, ((y - x) % c1 - y),
TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, (y - x) % c1 - y,
c1.Eval()->value != 0);

TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, ((y - x) % c1 - y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
Expand Down
9 changes: 7 additions & 2 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ def test_sub_index_simplify():
ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z)
ck.verify(((y - z) / 5) * 5 - y, (z - y) % 5 - z)

ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3)
ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2))
ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5)
ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - y * 2, ((z - y) % 3 - z) * 2)

def test_mul_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
Expand Down Expand Up @@ -466,8 +473,6 @@ def test_cmp_simplify():
ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
ck.verify(2 * x <= 0, x <= 0)
ck.verify(2 * x - 4 * y <= 0, x + y*(-2) <= 0)
ck.verify(2 * x + 4 * y <= 0, x + y*2 <= 0)

ck.verify(x * 2 >= 3, tvm.expr.LE(2, x))
ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))
Expand Down

0 comments on commit 98d0663

Please sign in to comment.