Skip to content

Commit

Permalink
[ARITH] Improve div/mod in rewrite simplifier
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed May 8, 2019
1 parent b131d83 commit ea6ba99
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 37 deletions.
8 changes: 8 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ class ModularSetAnalyzer {
Impl* impl_;
};

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
int64_t ZeroAwareGCD(int64_t a, int64_t b);

/*!
* \brief Rewrite-rule based simplifier.
*/
Expand Down
10 changes: 4 additions & 6 deletions src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,8 @@ template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
// We use the C/C++ truncated division in TVM
if (pa && pb && pb->value != 0) {
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
Expand All @@ -183,9 +182,8 @@ template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
// We use the C/C++ truncated division in TVM
if (pa && pb && pb->value != 0) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
Expand Down
39 changes: 20 additions & 19 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,25 +335,6 @@ class ModularSetAnalyzer::Impl :
return Nothing();
}
}
/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
/*!
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
Expand Down Expand Up @@ -393,5 +374,25 @@ ModularSetAnalyzer::~ModularSetAnalyzer() {
delete impl_;
}

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}

} // namespace arith
} // namespace tvm
79 changes: 67 additions & 12 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) {
return kLT;
}
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
}
ConstIntBound dbound = parent_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
Expand All @@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) {
if (dbound->max_value <= val) {
return kLE;
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
}
return kUnknown;
}

Expand Down Expand Up @@ -175,6 +175,16 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
TVM_TRY_REWRITE(y * x + z * x, x * (y + z));
// Factor out gcd
if ((x * c1 + y * c2).Match(ret)) {
auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value);
if (gcd != 1) {
auto b1 = PConstWithTypeLike<PVar<Expr>>(x, c1.Eval()->value / gcd);
auto b2 = PConstWithTypeLike<PVar<Expr>>(x, c2.Eval()->value / gcd);
auto pgcd = PConstWithTypeLike<PVar<Expr>>(x, gcd);
return ((x * b1 + y * b2) * pgcd).Eval();
}
}

// modular-div simplification
// Always pre-condition on positive integer domain
Expand Down Expand Up @@ -249,6 +259,16 @@ Mutate_(const Sub* op, const Expr& self) {
TVM_TRY_REWRITE(y * x - x * z, x * (y - z));
TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
TVM_TRY_REWRITE(y * x - z * x, x * (y - z));
// Factor out gcd
if ((x * c1 - y * c2).Match(ret)) {
auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value);
if (gcd != 1) {
auto b1 = PConstWithTypeLike<PVar<Expr>>(x, c1.Eval()->value / gcd);
auto b2 = PConstWithTypeLike<PVar<Expr>>(x, c2.Eval()->value / gcd);
auto pgcd = PConstWithTypeLike<PVar<Expr>>(x, gcd);
return ((x * b1 - y * b2) * pgcd).Eval();
}
}

// constant cancelation
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
Expand Down Expand Up @@ -284,11 +304,15 @@ Mutate_(const Sub* op, const Expr& self) {
CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));

// modular-div simplification
// Always pre-condition on positive integer domain
// Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division
TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x + c2) / c1) * c1, (x + c2) % c1 - c2,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 - x, c2 - ((x + c2) % c1),
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
Expand Down Expand Up @@ -348,6 +372,7 @@ Mutate_(const Mul* op, const Expr& self) {

// canonicalization
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
TVM_TRY_RECURSIVE_REWRITE_IF(
(x - y) * c1, (y - x) * (0 - c1),
c1.Eval()->value < 0);
Expand Down Expand Up @@ -617,6 +642,11 @@ Mutate_(const Mod* op, const Expr& self) {
return (mod->base % c1).Eval();
}
}

// canonicalization: x % c == x % (-c) for truncated division
TVM_TRY_RECURSIVE_REWRITE_IF(x % c1,
x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)),
c1.Eval()->value < 0);
}
return ret;
}
Expand Down Expand Up @@ -1025,20 +1055,45 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
c1.Eval()->value < 0);

// require c1 > 0 to work for any div mode
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2,
c1.Eval()->value <= 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x,
c1.Eval()->value > 0 &&
c2.Eval()->value < 0);
TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x,
c1.Eval()->value <= 0 &&
c2.Eval()->value < 0);

TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x,
c1.Eval()->value < 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1,
c1.Eval()->value < 0 &&
c2.Eval()->value < 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2,
c1.Eval()->value >= 0 &&
c2.Eval()->value < 0);

TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value <= 0);

TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x,
c1.Eval()->value < 0 &&
c2.Eval()->value > 0);

// division related simplificationx
// invariance for any div mod: x - (x / c1) * c1 == x % c1
Expand Down
56 changes: 56 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def test_sub_index_simplify():
ck.verify(x - (x / 3) * 3, x % 3)
ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)

ck.verify(y - (y / (-5)) * (-5), y % 5)
ck.verify((y / 3) * 3 - y, 0 - y % 3)
ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6)
ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5)

def test_mul_index_simplify():
ck = RewriteChecker()
Expand Down Expand Up @@ -292,6 +296,11 @@ def test_mod_index_simplify():
ck.verify((x + 10) % 2, x % 2)
ck.verify((x + y * 10) % 2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
ck.verify(x * 10 % -2, 0)
ck.verify((x * 10 + y) % -2, y % 2)
ck.verify((x + 10) % -2, x % 2)
ck.verify((x + y * 10) % -2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)


def test_min_index_simplify():
Expand Down Expand Up @@ -449,6 +458,52 @@ def test_cmp_simplify():
ck.verify(x / 2 < 3, x < 6)
ck.verify(x * 4 <= 2, x <= 0)
ck.verify(3 < x / 2, tvm.expr.LT(7, x))
ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
ck.verify(2 * x <= 0, x <= 0)
ck.verify(2 * x - 4 * y <= 0, x + y*(-2) <= 0)
ck.verify(2 * x + 4 * y <= 0, x + y*2 <= 0)

ck.verify(x * 2 >= 3, tvm.expr.LE(2, x))
ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 1, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 0, tvm.expr.LE(0, x))
ck.verify(x * 2 >= -1, tvm.expr.LE(0, x))
ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x))
ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x))

ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1))
ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1))
ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0))
ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0))
ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1))
ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1))
ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2))

ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2))
ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1))
ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1))
ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 0))
ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0))
ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1))
ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1))

ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x))
ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x))
ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x))
ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x))
ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x))
ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))

ck.verify(x / 2 >= 1, tvm.expr.LE(2, x))
ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x))
ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x))

ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3))
ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1))
ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2))

ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
Expand Down Expand Up @@ -480,6 +535,7 @@ def test_cmp_simplify():
ck.verify(x*y <= 0, tvm.const(1, "bool"))
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool"))
ck.verify(x*6 <= -3, tvm.const(0, "bool"))


def test_logical_simplify():
Expand Down

0 comments on commit ea6ba99

Please sign in to comment.