Skip to content

Commit

Permalink
Fix intersect of modular set
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Mar 27, 2019
1 parent 84cb712 commit 456f84c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 62 deletions.
158 changes: 96 additions & 62 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
int64_t coeff{1};
int64_t base{0};

Entry() = default;

Entry(int64_t coeff, int64_t base) {
CHECK_GE(coeff, 0);
this->coeff = coeff;
if (coeff != 0) {
base = base % coeff;
if (base < 0) base += coeff;
}
this->base = base;
}

bool is_const() const {
return coeff == 0;
}
Expand All @@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
if (!override) {
CHECK(!var_map_.count(var));
}
Entry e;
e.coeff = info->coeff;
e.base = info->base;
var_map_[var] = e;
var_map_[var] = Entry(info->coeff, info->base);
}

// Detect useful constraints and use them in the analysis scope.
Expand All @@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
Entry entry;
entry.coeff = coeff.Eval()->value;
entry.base = base.Eval()->value;
Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
return nullptr;
Expand All @@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
}

Entry VisitExpr_(const IntImm* op) final {
Entry ret;
ret.base = op->value;
ret.coeff = 0;
return ret;
return Entry(0, op->value);
}

Entry VisitExpr_(const UIntImm* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
Entry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
return Entry(0, static_cast<int>(op->value));
} else {
return Everything();
}
Expand All @@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
Entry VisitExpr_(const Add* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}

Entry VisitExpr_(const Sub* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}

Entry VisitExpr_(const Mul* op) final {
Expand All @@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
int64_t pq = a.coeff * b.coeff;
int64_t pm = a.coeff * b.base;
int64_t qn = a.base * b.coeff;
Entry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
return Entry(coeff, a.base * b.base);
}

Entry DivByConst(const Expr& lhs,
Expand All @@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
Entry a = VisitExpr(lhs);
CHECK_NE(val, 0);
if (a.coeff % val == 0) {
Entry ret;
if (a.base == 0) {
// a c x / c -> a x
ret.coeff = std::abs(a.coeff / val);
ret.base = 0;
return ret;
return Entry(std::abs(a.coeff / val), 0);
}
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
if (a.base > 0 && val > 0 &&
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
ret.coeff = a.coeff / val;
ret.base = a.base / val;
return ret;
return Entry(a.coeff / val, a.base / val);
}
}
return Everything();
Expand Down Expand Up @@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
}
int64_t base0 = a.base % coeff;
int64_t base1 = b.base % coeff;
Entry ret;
if (base0 == base1) {
ret.coeff = coeff;
ret.base = base0;
return ret;
return Entry(coeff, base0);
} else {
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
ret.base = 0;
return ret;
return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
}
}
/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t& x, int64_t& y) {
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a >= 0 ? a : -a;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r2 / r1)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r2 - r1 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}

x = a >= 0 ? old_s : -old_s;
if (b != 0) {
y = (old_r - x * a) / b;
} else {
y = 1;
}

return old_r;
}
/*!
* \brief Create interect of two sets.
* \param a The left operand.
* \param b the right operand.
*/
static Entry Intersect(Entry a, Entry b) {
// simple rule for now: pick higher constraints.
// TODO(team-team): Use extended euclidean algorithm.
if (a.coeff == 0) return a;
if (b.coeff == 0) return b;
if (a.coeff >= b.coeff) return a;
return b;
}
/*!
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
* \param base The base value.
* \param coeff The coeff value.
* \return The simplified base.
*/
static int64_t BaseSimplify(int64_t base, int64_t coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
int64_t x, y;
int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
// z = c1 * p + b1
// z = c2 * q + b2
// c1 * x + c2 * y = gcd(c1, c2)
// -> c1 * p - c2 * q = b2 - b1
// -> p = (b2 - b1) / gcd * x
// -> q = (b2 - b1) / gcd * (-y)
// -> z = LCM(x, y) * k + (c1 * p + b1)
int64_t gcd = ExtendedEuclidean(c1, c2, x, y);
int64_t v = b2 - b1;
if (v % gcd == 0) {
x = v / gcd * x;
y = v / gcd * (-y);
int64_t coeff = c1 / gcd * c2;
return Entry(coeff, x * c1 + b1);
} else {
return Nothing();
}
}
/*!
* \brief Take GCD of a and b.
Expand All @@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent.
*/
static Entry Everything() {
Entry ret;
ret.coeff = 1; ret.base = 0;
return ret;
return Entry(1, 0);
}
/*!
* \brief return an empty set
* \return Bound that represent everything dtype can represent.
*/
static Entry Nothing() {
return Entry(0, 1);
}
};

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_arith_modular_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ def test_constraint_scope():
assert m.coeff == 1
assert m.base == 0

def test_intersect():
a = tvm.var("a")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1):
with analyzer.constraint_scope(a % 3 == 1):
m = analyzer.modular_set(a)
assert m.coeff == 12
assert m.base == 1

with analyzer.constraint_scope(a % 3 == 2):
with analyzer.constraint_scope(a % 5 == 3):
with analyzer.constraint_scope(a % 7 == 2):
m = analyzer.modular_set(a)
assert m.coeff == 105
assert m.base == 23


if __name__ == "__main__":
test_cast()
Expand All @@ -126,3 +142,4 @@ def test_constraint_scope():
test_min_max_select()
test_mix_index()
test_constraint_scope()
test_intersect()

0 comments on commit 456f84c

Please sign in to comment.