Skip to content

Commit

Permalink
[ARITH] CanonicalSimplifier, better folding, eliminate store. (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Jul 11, 2019
1 parent f7d2c48 commit dac6883
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,14 @@ class SumExprNode : public CanonicalExprNode {
SplitExpr& rhs = args[j];
if (!lhs->IndexEqual(rhs)) break;
if (lhs->upper_factor < rhs->lower_factor) break;
if (lhs->lower_factor == rhs->upper_factor &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
if (lhs->upper_factor == rhs->upper_factor &&
lhs->lower_factor == rhs->lower_factor) {
// folding same co-efficient.
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
} else if (lhs->lower_factor == rhs->upper_factor &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
// Rules used in the proof:
//
// Rule 1: (x % (c * s)) / c = (x / c) % s
Expand Down
13 changes: 13 additions & 0 deletions src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ class StmtSimplifier : public IRMutator {
}
}

// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
return Evaluate::make(0);
}
}
return stmt;
}

protected:
Analyzer analyzer_;
// variable domain
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ def test_simplify_if_then_else():
ck.verify(res, 0)


def test_complex_cases():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) +
((((((x*128) + y) % 36)*2) + 1)/2))
- (((x*128) + y) % 1296)) + 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1)


if __name__ == "__main__":
test_simplify_if_then_else()
test_div_simplify()
Expand All @@ -195,3 +207,4 @@ def test_simplify_if_then_else():
test_mul_sum_simplify()
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()

0 comments on commit dac6883

Please sign in to comment.