diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 5e557a50cfaf..05aff7a5eee8 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -29,7 +29,6 @@ #include #include #include "pattern_match.h" -#include "modular_set.h" namespace tvm { namespace arith { @@ -338,6 +337,25 @@ class ModularSetAnalyzer::Impl : return Nothing(); } } + /*! + * \brief Take GCD of a and b. + * \param a The first operand. + * \param b The second operand. + * \return The result. + */ + static int64_t ZeroAwareGCD(int64_t a, int64_t b) { + if (a < 0) a = -a; + if (b < 0) b = -b; + if (a < b) std::swap(a, b); + if (b == 0) return a; + // perform GCD (greatest common divisor) + // ax + by = gcd(a, b) z if a != 0, b != 0 + while (a % b != 0) { + a = a % b; + std::swap(a, b); + } + return b; + } /*! * \brief return everything dtype can represent. * \return Bound that represent everything dtype can represent. @@ -377,25 +395,5 @@ ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } -/*! - * \brief Take GCD of a and b. - * \param a The first operand. - * \param b The second operand. - * \return The result. - */ -int64_t ZeroAwareGCD(int64_t a, int64_t b) { - if (a < 0) a = -a; - if (b < 0) b = -b; - if (a < b) std::swap(a, b); - if (b == 0) return a; - // perform GCD (greatest common divisor) - // ax + by = gcd(a, b) z if a != 0, b != 0 - while (a % b != 0) { - a = a % b; - std::swap(a, b); - } - return b; -} - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/modular_set.h b/src/arithmetic/modular_set.h deleted file mode 100644 index aca38bdfd640..000000000000 --- a/src/arithmetic/modular_set.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/arithmetic/modular_set.h - * \brief Modular set analysis - */ -#ifndef TVM_ARITHMETIC_MODULAR_SET_H_ -#define TVM_ARITHMETIC_MODULAR_SET_H_ - -#include - -namespace tvm { -namespace arith { - -/*! - * \brief Take GCD of a and b. - * \param a The first operand. - * \param b The second operand. - * \return The result. - */ -int64_t ZeroAwareGCD(int64_t a, int64_t b); - -} // namespace arith -} // namespace tvm -#endif // TVM_ARITHMETIC_MODULAR_SET_H_ diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 6ca58ab2ec56..a404ba0774f3 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -29,7 +29,6 @@ #include #include "const_fold.h" #include "pattern_match.h" -#include "modular_set.h" #include "rewrite_simplify.h" namespace tvm { @@ -176,16 +175,6 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(y * x + x * z, x * (y + z)); TVM_TRY_REWRITE(x * y + z * x, x * (y + z)); TVM_TRY_REWRITE(y * x + z * x, x * (y + z)); - // Factor out gcd - if ((x * c1 + y * c2).Match(ret)) { - auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value); - if (gcd != 1) { - auto b1 = PConstWithTypeLike>(x, c1.Eval()->value / gcd); - auto b2 = PConstWithTypeLike>(x, c2.Eval()->value / gcd); - auto pgcd = PConstWithTypeLike>(x, gcd); - return ((x * b1 + y * b2) * pgcd).Eval(); - } - } // modular-div simplification // Always pre-condition on positive integer domain @@ -260,16 +249,6 @@ Mutate_(const Sub* op, const Expr& self) { TVM_TRY_REWRITE(y * x - x * z, x * (y - z)); TVM_TRY_REWRITE(x * y - z * x, x * (y - z)); TVM_TRY_REWRITE(y * x - z * x, x * (y - z)); - // Factor out gcd - if ((x * c1 - y * c2).Match(ret)) { - auto gcd = ZeroAwareGCD(c1.Eval()->value, c2.Eval()->value); - if (gcd != 1) { - auto b1 = PConstWithTypeLike>(x, c1.Eval()->value / gcd); - auto b2 = PConstWithTypeLike>(x, c2.Eval()->value / gcd); - auto pgcd = PConstWithTypeLike>(x, gcd); - return ((x * b1 - y * b2) * pgcd).Eval(); - } - } // constant cancelation TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); @@ -316,8 +295,28 @@ Mutate_(const Sub* op, const Expr& self) { c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, ((y - x) % c1 - y), + TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, (y - x) % c1 - y, c1.Eval()->value != 0); + + TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, ((y - x) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index cae18e256549..563cc1ecf55f 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -236,6 +236,13 @@ def test_sub_index_simplify(): ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z) ck.verify(((y - z) / 5) * 5 - y, (z - y) % 5 - z) + ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3) + ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2)) + ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) + ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5) + ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2) + ck.verify(((y - z) / 3) * 6 - y * 2, ((z - y) % 3 - z) * 2) + def test_mul_index_simplify(): ck = RewriteChecker() x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") @@ -466,8 +473,6 @@ def test_cmp_simplify(): ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x)) ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0)) ck.verify(2 * x <= 0, x <= 0) - ck.verify(2 * x - 4 * y <= 0, x + y*(-2) <= 0) - ck.verify(2 * x + 4 * y <= 0, x + y*2 <= 0) ck.verify(x * 2 >= 3, tvm.expr.LE(2, x)) ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))