From aead39754d4be8c69777e821ef97225a3bc61267 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 24 Sep 2019 11:01:37 -0700 Subject: [PATCH] [ARITH] Explicitly state truncdiv/mod in pattern matching. (#3986) * [ARITH] Explicitly state truncdiv/mod in pattern matching. * Fix the dependent cpp test --- include/tvm/expr_operator.h | 34 +++ src/arithmetic/canonical_simplify.cc | 4 +- src/arithmetic/int_operator.h | 20 ++ src/arithmetic/modular_set.cc | 4 +- src/arithmetic/pattern_match.h | 16 +- src/arithmetic/rewrite_simplify.cc | 228 ++++++++------- src/lang/expr_operator.cc | 10 +- tests/cpp/pattern_match_test.cc | 10 +- .../unittest/test_arith_rewrite_simplify.py | 267 ++++++++++-------- 9 files changed, 350 insertions(+), 243 deletions(-) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index b0e82e7fb50c..5f0f8495a5c0 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -332,6 +332,20 @@ TVM_DLL Expr operator||(Expr a, Expr b); * \note This operator does eager constant folding. */ TVM_DLL Expr operator!(Expr a); +/*! + * \brief compute division in C semantics. + * + * a / b as in C/C++. + * + * When operands are integers, it directly corresponds to truncdiv. + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr div(Expr a, Expr b); /*! * \brief compute trunc(a / b) * @@ -640,6 +654,21 @@ inline Expr make_zero(Type t) { return make_const(t, 0); } +/*! + * \brief Helper function to raise a compiler error about division ambiguity. + * \note The call to this function will always results in a compiler error. + * \tparam TA Any class type. + */ +template +inline void DivAmbiguityError(const TA& a) { + constexpr bool div_ambiguity = !std::is_class::value; + static_assert(div_ambiguity, + "TVM supports multiple types of integer divisions, " + "please call div, floordiv/floormod or truncdiv/truncmod directly " + "to avoid ambiguity in the code. " + "Checkout these functions in expr_operator.h."); +} + // additional const expression overloading #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ inline Expr Name(Expr& a, Expr b) { \ @@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=); // integer related ops TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index e1fa6d6f84c5..8b0d29e28708 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -67,7 +67,7 @@ enum DivMode { inline Expr ModImpl(Expr a, Expr b, DivMode mode) { if (mode == kTruncDiv) { - return a % b; + return truncmod(a, b); } else { CHECK_EQ(mode, kFloorDiv); return floormod(a, b); @@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) { inline Expr DivImpl(Expr a, Expr b, DivMode mode) { if (mode == kTruncDiv) { - return a / b; + return truncdiv(a, b); } else { CHECK_EQ(mode, kFloorDiv); return floordiv(a, b); diff --git a/src/arithmetic/int_operator.h b/src/arithmetic/int_operator.h index d92094415eba..e1694a3fcd20 100644 --- a/src/arithmetic/int_operator.h +++ b/src/arithmetic/int_operator.h @@ -92,6 +92,26 @@ inline bool WillOverflow(int64_t x, return y == 0; } +/*! + * \brief Peform trunc division of two integers. + * \param x The left operand. + * \param y The right operand. + * \return the result. + */ +inline int64_t truncdiv(int64_t x, int64_t y) { + return x / y; +} + +/*! + * \brief Compute the truncdiv remainder of two integers. + * \param x The left operand. + * \param y The right operand. + * \return the result. + */ +inline int64_t truncmod(int64_t x, int64_t y) { + return x % y; +} + /*! * \brief Peform floor division of two integers. * \param x The left operand. diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index c072f0986c5a..08454dd0ef5a 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file modular_set.cc * \brief Modular set analysis */ @@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl : PVar var; PVar coeff, base; // pattern match interesting constraints - if (((var % coeff) == base).Match(constraint)) { + if ((truncmod(var, coeff) == base).Match(constraint) || + (floormod(var, coeff) == base).Match(constraint)) { Entry entry(coeff.Eval()->value, base.Eval()->value); return UpdateByIntersect(var.Eval(), entry); } diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 1278c7d32ee5..f7d5483cf6de 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -300,31 +300,41 @@ class PConstWithTypeLike : }; -#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ +#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ template \ inline PBinaryExpr \ FuncName(const Pattern& a, const Pattern& b) { \ + CheckStep; \ return PBinaryExpr(a.derived(), b.derived()); \ } \ template \ inline PBinaryExpr > \ FuncName(const Pattern& a, int64_t b) { \ + CheckStep; \ return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ } \ template \ inline PBinaryExpr, TA> \ FuncName(int64_t b, const Pattern& a) { \ + CheckStep; \ return FuncName(PConstWithTypeLike(a.derived(), b), a); \ } +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ + TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) + + +// raise ambiguity error for operator overload of / and % +TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a)); + // arithmetic expressions TVM_PATTERN_BINARY_OP(operator+, ir::Add); TVM_PATTERN_BINARY_OP(operator-, ir::Sub); TVM_PATTERN_BINARY_OP(operator*, ir::Mul); -TVM_PATTERN_BINARY_OP(operator/, ir::Div); -TVM_PATTERN_BINARY_OP(operator%, ir::Mod); TVM_PATTERN_BINARY_OP(min, ir::Min); TVM_PATTERN_BINARY_OP(max, ir::Max); +TVM_PATTERN_BINARY_OP(div, ir::Div); TVM_PATTERN_BINARY_OP(truncdiv, ir::Div); TVM_PATTERN_BINARY_OP(truncmod, ir::Mod); TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index a567f502f766..e3b3e7aed09c 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -194,7 +194,7 @@ Mutate_(const Add* op, const Expr& self) { // DivMod rules // truc div - TVM_TRY_REWRITE((x / c1) * c1 + x % c1, x); + TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); @@ -208,7 +208,7 @@ Mutate_(const Add* op, const Expr& self) { // DivMod rules // truc div - TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1)); + TVM_TRY_RECURSIVE_REWRITE(truncmod(y, c1) + x * c1, x * c1 + truncmod(y, c1)); // floor div TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1)); } @@ -314,48 +314,49 @@ Mutate_(const Sub* op, const Expr& self) { // DivMod rules // trucdiv // NOTE: c*(x/c) + x % c == x is true all division mode. - TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, + TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y, + TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1), + TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1), c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, + TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y, + TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, + TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, + TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, + TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, + TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2, + TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); // Proof in the case of floordiv, need positive condition. // let x = a * c3 + r // (x + c1) / c3 - x / c3 => (r + c1) / c3 - TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, - ((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3, + // NOTE: the use of floormod(c2, c3) was intentional to simplify the const. + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), + truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3), CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3, - (x % c3 + c1) / c3, + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3), + truncdiv(truncmod(x, c3) + c1, c3), CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0); @@ -478,14 +479,15 @@ Mutate_(const Div* op, const Expr& self) { // Vector rules if (op->type.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes), - broadcast(x / y, lanes)); + // NOTE: use div as the pattern also works for float. + TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), + broadcast(div(x, y), lanes)); // ramp / bcast - if ((ramp(b1, c1, lanes) / broadcast(c2, lanes)).Match(ret)) { + if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; if (c1val % c2val == 0) { - return ramp(b1 / c2, c1 / c2, lanes).Eval(); + return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. if (CanProveGreaterEqual(b1.Eval(), 0)) { @@ -493,7 +495,7 @@ Mutate_(const Div* op, const Expr& self) { int64_t ramp_min = bmod->base / c2val; int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { - return broadcast(b1 / c2, lanes).Eval(); + return broadcast(div(b1, c2), lanes).Eval(); } } } @@ -508,73 +510,79 @@ Mutate_(const Div* op, const Expr& self) { // parts of tvm which still assume euclidean div. In this simplifier we assume that the division // is truncated, so perform const folding again. // NOTE: trunc div required - if ((c1 / c2).Match(ret)) { + if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return make_const(op->type, c1val / c2val); + return make_const(op->type, truncdiv(c1val, c2val)); } // while it is always true for trunc div // restrict to common case(positive div) - TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), + TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2), c1.Eval()->value > 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF((x / c1 + c2) / c3, (x + c1 * c2) / (c1 * c3), + TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3), c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 && CanProveGreaterEqual(x.Eval(), 0)); - if (((x * c1) / c2).Match(ret)) { + if (truncdiv(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; if (c1val > 0 && c2val > 0) { - if (c1val % c2val == 0) return (x * (c1 / c2)).Eval(); - if (c2val % c1val == 0) return (x / (c2 / c1)).Eval(); + if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval(); + if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval(); } } - TVM_TRY_REWRITE(x / x, OneWithTypeLike(x)); - TVM_TRY_REWRITE(x * c1 / x, c1); - TVM_TRY_REWRITE(c1 * x / x, c1); + TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x)); + TVM_TRY_REWRITE(truncdiv(x * c1, x), c1); + TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2, + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), + x * truncdiv(c1, c2) + truncdiv(y, c2), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(min(x * c1, y) / c2, min(x * (c1 / c2), y / c2), + TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), + min(x * truncdiv(c1, c2), truncdiv(y, c2)), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(max(x * c1, y) / c2, max(x * (c1 / c2), y / c2), + TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), + max(x * truncdiv(c1, c2), truncdiv(y, c2)), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF((y + x * c1) / c2, y / c2 + x * (c1 / c2), + TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), + truncdiv(y, c2) + x * truncdiv(c1, c2), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(min(y, x * c1) / c2, min(y / c2, x * (c1 / c2)), + TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), + min(truncdiv(y, c2), x * truncdiv(c1, c2)), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(max(y, x * c1) / c2, max(y / c2, x * (c1 / c2)), + TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), + max(truncdiv(y, c2), x * truncdiv(c1, c2)), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && @@ -582,80 +590,89 @@ Mutate_(const Div* op, const Expr& self) { CanProveGreaterEqual(y.Eval(), 0)); // Rules involving 3-operands. - TVM_TRY_REWRITE_IF((x * c1 + y + z) / c2, x * (c1 / c2) + (y + z)/ c2, + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2), + x * truncdiv(c1, c2) + truncdiv(y + z, c2), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF((x * c1 - y + z) / c2, x * (c1 / c2) + (z - y)/ c2, + TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2), + x * truncdiv(c1, c2) + truncdiv(z - y, c2), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0)); - TVM_TRY_REWRITE_IF((x * c1 + y - z) / c2, x * (c1 / c2) + (y - z)/ c2, + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2), + x * truncdiv(c1, c2) + truncdiv(y - z, c2), c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0)); - TVM_TRY_REWRITE_IF((y + x * c1 + z) / c2, x * (c1 / c2) + (y + z) / c2, + TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2), + x * truncdiv(c1, c2) + truncdiv(y + z, c2), c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF((x + c1) / c2, x / c2 + c1 / c2, + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), + truncdiv(x, c2) + truncdiv(c1, c2), c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF((x + y) / x, y / x + 1, + TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF((y + x) / x, y / x + 1, + TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(((x + y) + z) / x, (y + z) / x + 1, + TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x), + truncdiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(((y + x) + z) / x, (y + z) / x + 1, + TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x), + truncdiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF((y + (z + x)) / x, (y + z) / x + 1, + TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x), + truncdiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF((y + (x + z)) / x, (y + z) / x + 1, + TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x), + truncdiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF((x * y) / y, x, + TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF((y * x) / y, x, + TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF((x * z + y) / z, x + y / z, + TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z), CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(z.Eval(), 0)); - TVM_TRY_REWRITE_IF((z * x + y) / z, x + y / z, + TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z), CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(z.Eval(), 0)); - TVM_TRY_REWRITE_IF((y + x * z) / z, y / z + x, + TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(z.Eval(), 0)); - TVM_TRY_REWRITE_IF((y + z * x) / z, y / z + x, + TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x, CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(z.Eval(), 0)); @@ -679,15 +696,15 @@ Mutate_(const Mod* op, const Expr& self) { // Vector rules if (op->type.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) % broadcast(y, lanes), - broadcast(x % y, lanes)); + TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)), + broadcast(truncmod(x, y), lanes)); // ramp % bcast - if ((ramp(b1, c1, lanes) % broadcast(c2, lanes)).Match(ret)) { + if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; if (c1val % c2val == 0) { - return broadcast(b1 % c2, lanes).Eval(); + return broadcast(truncmod(b1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. if (CanProveGreaterEqual(b1.Eval(), 0)) { @@ -696,9 +713,10 @@ Mutate_(const Mod* op, const Expr& self) { int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; if (bmod->coeff % c2val == 0) { if (ramp_min == ramp_max) { - return ramp(bmod->base % c2, c1, lanes).Eval(); + return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); } else { - return (ramp(bmod->base % c2, c1, lanes) % broadcast(c2, lanes)).Eval(); + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), + broadcast(c2, lanes)).Eval(); } } } @@ -709,23 +727,23 @@ Mutate_(const Mod* op, const Expr& self) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. - TVM_TRY_REWRITE_IF((x * c1) % c2, ZeroWithTypeLike(x), + TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x), c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, + TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual((x * c1).Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, + TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value >= 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2, + TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && @@ -733,18 +751,18 @@ Mutate_(const Mod* op, const Expr& self) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required - TVM_TRY_RECURSIVE_REWRITE_IF(x % c1, - x % PConst(make_const(op->type, -c1.Eval()->value)), + TVM_TRY_RECURSIVE_REWRITE_IF(truncmod(x, c1), + truncmod(x, PConst(make_const(op->type, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis - if ((x % c1).Match(ret)) { + if (truncmod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); int64_t c1val = c1.Eval()->value; if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) { - return (mod->base % c1).Eval(); + return truncmod(mod->base, c1).Eval(); } } } @@ -798,7 +816,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { int64_t c2val = c2.Eval()->value; if (c1val > 0 && c2val > 0) { if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval(); - if (c2val % c1val == 0) return (floordiv(x, floordiv(c2, c1))).Eval(); + if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval(); } } @@ -1025,18 +1043,18 @@ Mutate_(const Min* op, const Expr& self) { // DivMod rules // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) - TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x, + TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2), + TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x, + TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2), + TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2), c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && CanProveGreaterEqual(x.Eval(), 0)); @@ -1104,11 +1122,11 @@ Mutate_(const Min* op, const Expr& self) { TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2))); // scaling rule - if (min(x / c1, y / c1).Match(ret)) { + if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) { if (c1.Eval()->value > 0) { - return (min(x, y) / c1).Eval(); + return truncdiv(min(x, y), c1).Eval(); } else { - return (max(x, y) / c1).Eval(); + return truncdiv(max(x, y), c1).Eval(); } } if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) { @@ -1210,10 +1228,12 @@ Mutate_(const Max* op, const Expr& self) { // DivMod rules // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) - TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2, + TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), + truncdiv(x + c1, c2) * c2, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2, + TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), + truncdiv(x + c1, c2) * c2, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); @@ -1276,11 +1296,11 @@ Mutate_(const Max* op, const Expr& self) { TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2))); // scaling rule - if (max(x / c1, y / c1).Match(ret)) { + if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) { if (c1.Eval()->value > 0) { - return (max(x, y) / c1).Eval(); + return truncdiv(max(x, y), c1).Eval(); } else { - return (min(x, y) / c1).Eval(); + return truncdiv(min(x, y), c1).Eval(); } } if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) { @@ -1425,70 +1445,70 @@ Mutate_(const LT* op, const Expr& self) { // constant cancelation: only need to make use of one mod // truc div - TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, + TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, + TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2), c1.Eval()->value <= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (euclidean is ok too, floored is not) - TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x, + TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 && c2.Eval()->value < 0); // NOTE: trunc div required (floored is ok too, euclidean is not) - TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x, + TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x, c1.Eval()->value <= 0 && c2.Eval()->value < 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x, + TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x, c1.Eval()->value < 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, + TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (floored is ok too, euclidean is not) - TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1, + TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1, c1.Eval()->value < 0 && c2.Eval()->value < 0); // NOTE: trunc div required (euclidean is ok too, floored is not) - TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2, + TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2), c1.Eval()->value >= 0 && c2.Eval()->value < 0); // DivMod rules // trucdiv - TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1, c1.Eval()->value > 0 && c2.Eval()->value <= 0); - TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, + TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x, + TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x, c1.Eval()->value < 0 && c2.Eval()->value > 0); // invariance for any div mod: x - (x / c1) * c1 == x % c1 - TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y, c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1, + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1), c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x, - c2 < (x + c2) % c1, + TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x, + c2 < truncmod(x + c2, c1), c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x + y, - c2 < (x + c2) % c1 + y, + TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y, + c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x - y, - y < (x + c2) % c1 + (0 - c2), + TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y, + y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // floordiv diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index d7a40c133784..f66b997bb934 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -178,13 +178,19 @@ Expr operator*(Expr a, Expr b) { return ir::Mul::make(a, b); } -Expr truncdiv(Expr a, Expr b) { +Expr div(Expr a, Expr b) { BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::Div::make(a, b); } +Expr truncdiv(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + return div(a, b); +} + Expr truncmod(Expr a, Expr b) { BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); @@ -193,7 +199,7 @@ Expr truncmod(Expr a, Expr b) { } Expr operator/(Expr a, Expr b) { - return truncdiv(a, b); + return div(a, b); } Expr operator%(Expr a, Expr b) { diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 934ac620fb73..7fb654b5d9d4 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -47,9 +47,9 @@ TEST(Pattern, Basic) { } CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); CHECK((px + min(py, px)).Match(z + min(y, z))); - CHECK((px + py / (px * py)).Match(x + 2 / (x * 2))); - CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2))); - CHECK((px - py % (px * PConst(2))).Match(x - 2 % (x * 2))); + CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); + CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); + CHECK((px - floormod(py, px * PConst(2))).Match(x - floormod(2, x * 2))); // logicals CHECK((px == pz).Match(x == 1)); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index ca303544157c..246ac1339fb2 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -56,24 +56,26 @@ def test_vector_simplify(): tvm.expr.Ramp(x * 2, 8, 4)) ## DivMod rules + tdiv = tvm.truncdiv + tmod = tvm.truncmod # truc div - ck.verify(y.astype("int32x2") / x.astype("int32x2"), - (y / x).astype("int32x2")) - ck.verify(tvm.expr.Ramp(x, 4, 4) / 2, - tvm.expr.Ramp(x/ 2, 2, 4)) + ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")), + tdiv(y, x).astype("int32x2")) + ck.verify(tdiv(tvm.expr.Ramp(x, 4, 4), 2), + tvm.expr.Ramp(tdiv(x, 2), 2, 4)) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) - ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8, + ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) - ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8, - tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8) - ck.verify(y.astype("int32x2") % x.astype("int32x2"), - (y % x).astype("int32x2")) - ck.verify(tvm.expr.Ramp(x, 4, 4) % 2, - tvm.expr.Broadcast(x % 2, 4)) - ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8, + ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8), + tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8)) + ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")), + tmod(y, x).astype("int32x2")) + ck.verify(tmod(tvm.expr.Ramp(x, 4, 4), 2), + tvm.expr.Broadcast(tmod(x, 2), 4)) + ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8), tvm.expr.Ramp(1, 1, 4)) - ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8, - tvm.expr.Ramp(1, 15, 4) % 8) + ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8), + tmod(tvm.expr.Ramp(1, 15, 4), 8)) # floor div fld = tvm.floordiv @@ -187,10 +189,12 @@ def test_add_index_simplify(): ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9); # DivMod rules + tdiv = tvm.truncdiv + tmod = tvm.truncmod # truc div - ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10)) + ck.verify(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10)) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True) - ck.verify((x / 8) * 8 + x % 8, x) + ck.verify(tdiv(x, 8) * 8 + tmod(x, 8), x) # floor div fld = tvm.floordiv @@ -256,31 +260,33 @@ def test_sub_index_simplify(): # DivMod patterns # truc div + tdiv = tvm.truncdiv + tmod = tvm.truncmod ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) - ck.verify(x - (x / 3) * 3, x % 3) - - ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3) - ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3) - - ck.verify(y - (y / (-5)) * (-5), y % 5) - ck.verify((y / 3) * 3 - y, 0 - y % 3) - ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6) - ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5) - ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z) - ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5) - ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z) - ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z) - - ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3) - ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2)) - ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) - ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5) - ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2) - ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2) - ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) - ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5) - ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2) - ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2) + ck.verify(x - tdiv(x, 3) * 3, tmod(x, 3)) + + ck.verify(tdiv(x + 5, 3) - tdiv(x, 3), tdiv(tmod(x, 3) + 5, 3)) + ck.verify(tdiv(x + 5, 3) - tdiv(x + 1, 3), tdiv(tmod(x + 1, 3) + 4, 3)) + + ck.verify(y - tdiv(y, (-5)) * (-5), tmod(y, 5)) + ck.verify(tdiv(y, 3) * 3 - y, 0 - tmod(y, 3)) + ck.verify(y - tdiv(y - 6, 5) * 5, tmod(y + (-6), 5) + 6) + ck.verify(tdiv(y - 6, 5) * 5 - y, (-6) - tmod(y + (-6), 5)) + ck.verify(y - tdiv(y + z, 5) * 5, tmod(y + z, 5) - z) + ck.verify(tdiv(y + z, 5) * 5 - y, z - tmod(y + z, 5)) + ck.verify(y - tdiv(y - z, 5) * 5, tmod(y - z, 5) + z) + ck.verify(tdiv(y - z, 5) * 5 - y, 0 - tmod(y - z, 5) - z) + + ck.verify(y * 3 - tdiv(y, 2) * 6, tmod(y, 2) * 3) + ck.verify(tdiv(y, 3) * 6 - y * 2, tmod(y, 3) * (-2)) + ck.verify(y * 5 - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5) + ck.verify(y * 5 - tdiv(y - z, 2) * 10, (tmod(y - z, 2) + z) * 5) + ck.verify(tdiv(y + z, 3) * 6 - y * 2, (z - tmod(y + z, 3)) * 2) + ck.verify(tdiv(y - z, 3) * 6 - y * 2, (0 - tmod(y - z, 3) - z) * 2) + ck.verify(5 * y - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5) + ck.verify(5 * y - 10 * tdiv(y - z, 2), (tmod(y - z, 2) + z) * 5) + ck.verify(6 * tdiv(y + z, 3) - y * 2, (z - tmod(y + z, 3)) * 2) + ck.verify(tdiv(y - z, 3) * 6 - 2 * y, (0 - tmod(y - z, 3) - z) * 2) # floor div fld = tvm.floordiv @@ -323,46 +329,48 @@ def test_mul_index_simplify(): def test_div_index_simplify(): ck = RewriteChecker() x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + tdiv = tvm.truncdiv + tmod = tvm.truncmod - ck.verify(x / x, 1) + ck.verify(tdiv(x, x), 1) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True) - ck.verify(x / 2 / 3, x / 6) - ck.verify((x / 2 + 1) / 3, (x + 2) / 6) - ck.verify(x * 2 / 4, x / 2) - ck.verify(x * 4 / 2, x * 2) + ck.verify(tdiv(tdiv(x, 2), 3), tdiv(x, 6)) + ck.verify(tdiv(tdiv(x, 2) + 1, 3), tdiv(x + 2, 6)) + ck.verify(tdiv(x * 2, 4), tdiv(x, 2)) + ck.verify(tdiv(x * 4, 2), x * 2) - ck.verify((x * 4 + y) / 2, x * 2 + y / 2) - ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2)) - ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2)) + ck.verify(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2)) + ck.verify(tdiv(tvm.min(x * 6, y), 2), tvm.min(x * 3, tdiv(y, 2))) + ck.verify(tdiv(tvm.max(x * 6, y), 2), tvm.max(x * 3, tdiv(y, 2))) - ck.verify((y + x * 4) / 2, y / 2 + x * 2) - ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3)) - ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3)) + ck.verify(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2) + ck.verify(tdiv(tvm.min(y, x * 6), 2), tvm.min(tdiv(y, 2), x * 3)) + ck.verify(tdiv(tvm.max(y, x * 6), 2), tvm.max(tdiv(y, 2), x * 3)) # 3-operands - ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2) - ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1) - ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1) - ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2) - ck.verify((x + 4) / 2, x / 2 + 2) + ck.verify(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2)) + ck.verify(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1) + ck.verify(tdiv(x * 6 + (y + 3) - y, 2), x * 3 + 1) + ck.verify(tdiv(y + x * 6 + z, 2), x * 3 + tdiv(y + z, 2)) + ck.verify(tdiv(x + 4, 2), tdiv(x, 2) + 2) - ck.verify((x + y) / x, y / x + 1) - ck.verify((y + x) / x, y / x + 1) - ck.verify(((x + y) + z) / x, (y + z) / x + 1) - ck.verify(((y + x) + z) / x, (y + z) / x + 1) - ck.verify((y + (x + z)) / x, (y + z) / x + 1) - ck.verify((y + (z + x)) / x, (y + z) / x + 1) + ck.verify(tdiv(x + y, x), tdiv(y, x) + 1) + ck.verify(tdiv(y + x, x), tdiv(y, x) + 1) + ck.verify(tdiv((x + y) + z, x), tdiv(y + z, x) + 1) + ck.verify(tdiv((y + x) + z, x), tdiv(y + z, x) + 1) + ck.verify(tdiv(y + (x + z), x), tdiv(y + z, x) + 1) + ck.verify(tdiv(y + (z + x), x), tdiv(y + z, x) + 1) - ck.verify((x * y) / y, x) - ck.verify((y * x) / y, x) + ck.verify(tdiv(x * y, y), x) + ck.verify(tdiv(y * x, y), x) - ck.verify((x * z + y) / z, x + y / z) - ck.verify((z * x + y) / z, x + y / z) - ck.verify((y + x * z) / z, y / z + x) - ck.verify((y + z * x) / z, y / z + x) + ck.verify(tdiv(x * z + y, z), x + tdiv(y, z)) + ck.verify(tdiv(z * x + y, z), x + tdiv(y, z)) + ck.verify(tdiv(y + x * z, z), tdiv(y, z) + x) + ck.verify(tdiv(y + z * x, z), tdiv(y, z) + x) def test_floordiv_index_simplify(): @@ -417,31 +425,33 @@ def test_mod_index_simplify(): ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True) ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True) - - ck.verify(x * 10 % 2, 0) - ck.verify((x * 10 + y) % 2, y % 2) - ck.verify((x + 10) % 2, x % 2) - ck.verify((x + y * 10) % 2, x % 2) - ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1) - ck.verify(x * 10 % -2, 0) - ck.verify((x * 10 + y) % -2, y % 2) - ck.verify((x + 10) % -2, x % 2) - ck.verify((x + y * 10) % -2, x % 2) - ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1) - - ck.verify(x * (-10) % 2, 0) - ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2) - ck.verify((x + (-10)) % 2, (x + (-10)) % 2) - ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2) - ck.verify(x * (-10) % -2, 0) - - ck.verify(nx * 10 % 2, 0) - ck.verify((nx * (-10) + y) % 2, y % 2) - ck.verify((x + ny * (-10)) % 2, x % 2) - ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1) - ck.verify(nx * 10 % -2, 0) - ck.verify((nx * (-10) + y) % -2, y % 2) - ck.verify((x + ny * (-10)) % -2, x % 2) + tdiv = tvm.truncdiv + tmod = tvm.truncmod + + ck.verify(tmod(x * 10, 2), 0) + ck.verify(tmod(x * 10 + y, 2), tmod(y, 2)) + ck.verify(tmod(x + 10, 2), tmod(x, 2)) + ck.verify(tmod(x + y * 10, 2), tmod(x, 2)) + ck.verify(tmod(x* 10 + 1 + y * 2 + 2, 2), 1) + ck.verify(tmod(x * 10, -2), 0) + ck.verify(tmod(x * 10 + y, -2), tmod(y, 2)) + ck.verify(tmod(x + 10, -2), tmod(x, 2)) + ck.verify(tmod(x + y * 10, -2), tmod(x, 2)) + ck.verify(tmod(x* 10 + 1 + y * 2 + 2, -2), 1) + + ck.verify(tmod(x * (-10), 2), 0) + ck.verify(tmod(x * (-10) + y, 2), tmod(x * (-10) + y, 2)) + ck.verify(tmod(x + (-10), 2), tmod(x + (-10), 2)) + ck.verify(tmod(x + y * (-10), 2), tmod(x + y * (-10), 2)) + ck.verify(tmod(x * (-10), -2), 0) + + ck.verify(tmod(nx * 10, 2), 0) + ck.verify(tmod(nx * (-10) + y, 2), tmod(y, 2)) + ck.verify(tmod(x + ny * (-10), 2), tmod(x, 2)) + ck.verify(tmod(nx * (-10) + 1 + ny * (-2) + 2, 2), 1) + ck.verify(tmod(nx * 10, -2), 0) + ck.verify(tmod(nx * (-10) + y, -2), tmod(y, 2)) + ck.verify(tmod(x + ny * (-10), -2), tmod(x, 2)) def test_floormod_index_simplify(): @@ -468,8 +478,10 @@ def test_min_index_simplify(): x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") fld = tvm.floordiv flm = tvm.floormod + tdiv = tvm.truncdiv + tmod = tvm.truncmod # const int bound - ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2) + ck.verify(tvm.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)) ck.verify(tvm.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)) ck.verify(tvm.min(x + 1, x + 10), x + 1) @@ -521,13 +533,14 @@ def test_min_index_simplify(): # DivMod rules # truc div ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000)) - ck.verify(tvm.min((x + 3) / 4 * 4, x), x) - ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4)) - ck.verify(tvm.min(x, (x + 3) / 4 * 4), x) - ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4)) + ck.verify(tvm.min(tdiv(x + 3, 4) * 4, x), x) + ck.verify(tvm.min(tdiv(x + 3, 4) * 4, tvm.max(x, 4)), tvm.max(x, 4)) + ck.verify(tvm.min(x, tdiv(x + 3, 4) * 4), x) + ck.verify(tvm.min(tvm.max(x, 4), tdiv(x + 3, 4) * 4), tvm.max(x, 4)) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) - ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10) - ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10)) + ck.verify(tvm.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.min(x, y), 10)) + ck.verify(tvm.min(tdiv(x, (-10)), tdiv(y, (-10))), + tdiv(tvm.max(x, y), (-10))) # floor div ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) @@ -545,8 +558,10 @@ def test_max_index_simplify(): x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") flm = tvm.floormod fld = tvm.floordiv + tdiv = tvm.truncdiv + tmod = tvm.truncmod # const int bound - ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10) + ck.verify(tvm.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10) ck.verify(tvm.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10) ck.verify(tvm.max(x + 1, x + 10), x + 10) @@ -597,9 +612,9 @@ def test_max_index_simplify(): # DivMod rules # truc div - ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10) - ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10)) - ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4) + ck.verify(tvm.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.max(x, y), 10)) + ck.verify(tvm.max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.min(x, y), (-10))) + ck.verify(tvm.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4) # floordiv ck.verify(tvm.max(fld(x, 10), fld(y, 10)), fld(tvm.max(x, y), 10)) @@ -614,11 +629,13 @@ def test_cmp_simplify(): x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") flm = tvm.floormod fld = tvm.floordiv + tdiv = tvm.truncdiv + tmod = tvm.truncmod # const int bound - ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool")) - ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool")) - ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool")) - ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool")) + ck.verify((tmod(x, 2) + 10).equal(0), tvm.const(0, "bool")) + ck.verify(tvm.expr.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool")) + ck.verify(tmod(x, 2) + 10 > 1, tvm.const(1, "bool")) + ck.verify(tmod(x, 2) + 10 <= 1, tvm.const(0, "bool")) ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool")) ck.verify(flm(x, 2) + 10 <= 1, tvm.const(0, "bool")) @@ -688,26 +705,26 @@ def test_cmp_simplify(): # DivMod rules # truc div - ck.verify(x / 2 < 3, x < 6) - ck.verify(3 < x / 2, tvm.expr.LT(7, x)) - ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x)) - ck.verify(x / 2 >= 1, tvm.expr.LE(2, x)) - ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x)) - ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x)) + ck.verify(tdiv(x, 2) < 3, x < 6) + ck.verify(3 < tdiv(x, 2), tvm.expr.LT(7, x)) + ck.verify(tdiv(x, 3) >= 0, tvm.expr.LE(-2, x)) + ck.verify(tdiv(x, 2) >= 1, tvm.expr.LE(2, x)) + ck.verify(tdiv(x, 2) >= 0, tvm.expr.LE(-1, x)) + ck.verify(tdiv(x, 2) >= -1, tvm.expr.LE(-3, x)) - ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3)) - ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1)) - ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2)) + ck.verify(tdiv(x, 2) <= 1, tvm.expr.LE(x, 3)) + ck.verify(tdiv(x, 2) <= 0, tvm.expr.LE(x, 1)) + ck.verify(tdiv(x, 2) <= -1, tvm.expr.LE(x, -2)) - ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4)) - ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0)) + ck.verify(tdiv(x, 4) * 4 < x, tvm.expr.LT(0, tmod(x, 4))) + ck.verify(tdiv(x, 4) * 4 >= x, tvm.expr.LE(tmod(x, 4), 0)) - ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y)) - ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4)) + ck.verify(tdiv(x, 4) * 4 < x + y, tvm.expr.LT(0, tmod(x, 4) + y)) + ck.verify(tdiv(x, 4) * 4 < x - y, tvm.expr.LT(y, tmod(x, 4))) - ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2)) - ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2)) - ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y)) + ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.expr.LE(tmod(x + 2, 4), 2)) + ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.expr.LE(tmod(x + 2, 4) + y, 2)) + ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.expr.LE(tmod(x + 2, 4) + (-2), y)) # floor div ck.verify(fld(x, 2) < 3, x < 6) @@ -753,7 +770,7 @@ def test_cmp_simplify(): ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool")) ck.verify(y*y >= 0, tvm.const(1, "bool")) ck.verify(x*6 <= -3, tvm.const(0, "bool")) - ck.verify((y - 1) % 3 == 0, (y + (-1)) % 3 == 0) + ck.verify(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0) def test_logical_simplify():