diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 8112beef7551..ebf8f3c1db4a 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry { int64_t coeff{1}; int64_t base{0}; + Entry() = default; + + Entry(int64_t coeff, int64_t base) { + CHECK_GE(coeff, 0); + this->coeff = coeff; + if (coeff != 0) { + base = base % coeff; + if (base < 0) base += coeff; + } + this->base = base; + } + bool is_const() const { return coeff == 0; } @@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl : if (!override) { CHECK(!var_map_.count(var)); } - Entry e; - e.coeff = info->coeff; - e.base = info->base; - var_map_[var] = e; + var_map_[var] = Entry(info->coeff, info->base); } // Detect useful constraints and use them in the analysis scope. @@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl : PVar coeff, base; // pattern match interesting constraints if (((var % coeff) == base).Match(constraint)) { - Entry entry; - entry.coeff = coeff.Eval()->value; - entry.base = base.Eval()->value; + Entry entry(coeff.Eval()->value, base.Eval()->value); return UpdateByIntersect(var.Eval(), entry); } return nullptr; @@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl : } Entry VisitExpr_(const IntImm* op) final { - Entry ret; - ret.base = op->value; - ret.coeff = 0; - return ret; + return Entry(0, op->value); } Entry VisitExpr_(const UIntImm* op) final { if (op->value < std::numeric_limits::max()) { - Entry ret; - ret.base = static_cast(op->value); - ret.coeff = 0; - return ret; + return Entry(0, static_cast(op->value)); } else { return Everything(); } @@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl : Entry VisitExpr_(const Add* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); - Entry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base + b.base, ret.coeff); - return ret; + int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); + return Entry(coeff, a.base + 
b.base); } Entry VisitExpr_(const Sub* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); - Entry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base - b.base, ret.coeff); - return ret; + int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); + return Entry(coeff, a.base - b.base); } Entry VisitExpr_(const Mul* op) final { @@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl : int64_t pq = a.coeff * b.coeff; int64_t pm = a.coeff * b.base; int64_t qn = a.base * b.coeff; - Entry ret; - ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); - ret.base = BaseSimplify(a.base * b.base, ret.coeff); - return ret; + int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); + return Entry(coeff, a.base * b.base); } Entry DivByConst(const Expr& lhs, @@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl : Entry a = VisitExpr(lhs); CHECK_NE(val, 0); if (a.coeff % val == 0) { - Entry ret; if (a.base == 0) { // a c x / c -> a x - ret.coeff = std::abs(a.coeff / val); - ret.base = 0; - return ret; + return Entry(std::abs(a.coeff / val), 0); } // positive division have a clear rounding mode. // Only handle case where we clearly know we need to round down. if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { - ret.coeff = a.coeff / val; - ret.base = a.base / val; - return ret; + return Entry(a.coeff / val, a.base / val); } } return Everything(); @@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl : } int64_t base0 = a.base % coeff; int64_t base1 = b.base % coeff; - Entry ret; if (base0 == base1) { - ret.coeff = coeff; - ret.base = base0; - return ret; + return Entry(coeff, base0); } else { - ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff); - ret.base = 0; - return ret; + return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0); } } + /*! + * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) + * \param a The first coefficient. + * \param b The second coefficient. 
/*!
 * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
 *
 * The iteration runs on |a| and |b| and the signs are folded back into the
 * returned coefficients, so the returned GCD is never negative.
 *
 * \param a The first coefficient.
 * \param b The second coefficient.
 * \param x The solution of x.
 * \param y The solution of y. When b == 0 any y satisfies the equation;
 *          1 is stored in that case.
 * \return The (non-negative) GCD of a and b; 0 when both are 0.
 *
 * \note Neither a nor b may be the minimum int64_t value, because their
 *       absolute value is taken by negation.
 */
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
  // Invariants kept by the loop:
  //   |a| * old_s + |b| * old_t = old_r
  //   |a| * s     + |b| * t     = r
  int64_t old_r = a >= 0 ? a : -a;
  int64_t r = b >= 0 ? b : -b;
  int64_t old_s = 1, s = 0;
  int64_t old_t = 0, t = 1;
  // With q = old_r / r, subtracting q times the second equation from the
  // first gives a new equation for the remainder old_r - q * r, which is
  // strictly smaller than r, so the loop terminates.
  while (r != 0) {
    const int64_t q = old_r / r;
    const int64_t next_r = old_r - q * r;
    old_r = r;
    r = next_r;
    const int64_t next_s = old_s - q * s;
    old_s = s;
    s = next_s;
    const int64_t next_t = old_t - q * t;
    old_t = t;
    t = next_t;
  }

  // |a| * s = a * (-s) when a < 0, likewise for b.
  *x = a >= 0 ? old_s : -old_s;
  if (b != 0) {
    // Tracking the coefficient directly avoids forming the product a * x,
    // whose magnitude can reach lcm(a, b).
    *y = b > 0 ? old_t : -old_t;
  } else {
    *y = 1;
  }

  return old_r;
}
- */ - static int64_t BaseSimplify(int64_t base, int64_t coeff) { - if (coeff == 0) return base; - base = base % coeff; - if (base < 0) base += coeff; - return base; + int64_t x, y; + int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base; + // z = c1 * p + b1 + // z = c2 * q + b2 + // c1 * x + c2 * y = gcd(c1, c2) + // -> c1 * p - c2 * q = b2 - b1 + // -> p = (b2 - b1) / gcd * x + // -> q = (b2 - b1) / gcd * (-y) + // -> z = LCM(x, y) * k + (c1 * p + b1) + int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y); + int64_t v = b2 - b1; + if (v % gcd == 0) { + x = v / gcd * x; + y = v / gcd * (-y); + int64_t coeff = c1 / gcd * c2; + return Entry(coeff, x * c1 + b1); + } else { + return Nothing(); + } } /*! * \brief Take GCD of a and b. @@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl : * \return Bound that represent everything dtype can represent. */ static Entry Everything() { - Entry ret; - ret.coeff = 1; ret.base = 0; - return ret; + return Entry(1, 0); + } + /*! + * \brief return an empty set + * \return Bound that represent everything dtype can represent. 
+ */ + static Entry Nothing() { + return Entry(0, 1); } }; diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 06ae5197b974..af60bc2152f0 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -117,6 +117,22 @@ def test_constraint_scope(): assert m.coeff == 1 assert m.base == 0 +def test_intersect(): + a = tvm.var("a") + analyzer = tvm.arith.Analyzer() + with analyzer.constraint_scope(a % 4 == 1): + with analyzer.constraint_scope(a % 3 == 1): + m = analyzer.modular_set(a) + assert m.coeff == 12 + assert m.base == 1 + + with analyzer.constraint_scope(a % 3 == 2): + with analyzer.constraint_scope(a % 5 == 3): + with analyzer.constraint_scope(a % 7 == 2): + m = analyzer.modular_set(a) + assert m.coeff == 105 + assert m.base == 23 + if __name__ == "__main__": test_cast() @@ -126,3 +142,4 @@ def test_constraint_scope(): test_min_max_select() test_mix_index() test_constraint_scope() + test_intersect()