From 48c92376fb463114209fb0a6414e278d510ce02e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 3 May 2019 21:07:14 -0400 Subject: [PATCH] [ARITH] Constraint-aware ConstIntBound, Enhance CanonicalSimplify (#3132) --- src/arithmetic/canonical_simplify.cc | 17 +++- src/arithmetic/const_int_bound.cc | 77 ++++++++++++++++++- src/arithmetic/modular_set.cc | 4 +- src/arithmetic/rewrite_simplify.cc | 66 +++++++++++++--- .../unittest/test_arith_canonical_simplify.py | 33 +++++++- 5 files changed, 180 insertions(+), 17 deletions(-) diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index d9b528291211..0feb00fc904b 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { return GetRef(op); } + if (const auto* op = expr.as()) { + if (op->base == 0 && op->args.size() == 1) return op->args[0]; + } if (const auto* op = expr.as_derived()) { expr = op->Normalize(); } @@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) { } } } + // Simplify the offset constant if necessary. + // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 + auto cbound = parent_->const_int_bound(Normalize(a)); + int64_t new_base = psum->base % cval; + if (cbound->min_value >= 0 && + cbound->min_value - psum->base + new_base >= 0) { + SumExpr sum_expr(std::move(a.node_)); + sum_expr.CopyOnWrite()->base = new_base; + return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval); + } } else { // if a >= 0 && a < cval, then result == 0 auto cbound = parent_->const_int_bound(Normalize(a)); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index c591e58aa542..bfd06c8ba255 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #include #include #include "int_op_overflow.h" +#include "pattern_match.h" namespace tvm { namespace arith { @@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: + /*! \brief additional bound info about expr \in bound */ + struct BoundInfo { + /*! \brief The expr */ + Expr expr; + /*! \brief The additional bound */ + Entry bound; + + BoundInfo() {} + BoundInfo(Expr expr, Entry bound) + : expr(expr), bound(bound) { + } + }; + void Bind(const Var& var, const Range& range) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); @@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl : static_cast(op)->type); } + Entry VisitExpr(const Expr& expr) final { + Entry res = ExprFunctor::VisitExpr(expr); + // a linear search over additional info + // assume we won't have a lot of conditions + for (const BoundInfo& info : additional_info_) { + if (ir::Equal(expr, info.expr)) { + res = Intersect(res, info.bound); + } + } + return res; + } + Entry VisitExpr_(const Cast* op) final { Entry a = VisitExpr(op->value); Entry b = Everything(op->type); @@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl : } } + std::function EnterConstraint(const Expr& constraint) { + std::vector info = DetectBoundInfo(constraint); + if (info.size() == 0) return nullptr; + size_t old_size = additional_info_.size(); + additional_info_.insert(additional_info_.end(), info.begin(), info.end()); + size_t new_size = old_size + info.size(); + auto frecover = [old_size, new_size, this]() { + CHECK_EQ(additional_info_.size(), new_size); + additional_info_.resize(old_size); + }; + return frecover; + } + private: // internal variable map std::unordered_map var_map_; + // additional bound info + std::vector additional_info_; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; @@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl : } return ret; } + + /*! + * \brief Detect additional constant bound from cond, if any + * \param cond The constraint condition. + * \return List of detected bounds. + */ + static std::vector DetectBoundInfo(const Expr& cond) { + PVar x, y; + PVar c; + // NOTE: canonical form always use <= or < + if ((c <= x).Match(cond)) { + return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; + } + if ((c < x).Match(cond)) { + return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))}; + } + if ((x <= c).Match(cond)) { + return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))}; + } + if ((x < c).Match(cond)) { + return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))}; + } + if ((x && y).Match(cond)) { + auto ret1 = DetectBoundInfo(x.Eval()); + auto ret2 = DetectBoundInfo(y.Eval()); + ret1.insert(ret1.end(), ret2.begin(), ret2.end()); + return ret1; + } + return {}; + } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { @@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { } std::function ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) { - return nullptr; + return impl_->EnterConstraint(constraint); } ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 5958233d6d52..7701e04844fa 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 6098faa44846..58d2b83a223a 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) { Expr RewriteSimplifier::Impl:: Mutate_(const Select* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); - op = ret.as(); // Pattern var to match any expression PVar x, y; TVM_TRY_REWRITE(select(x, y, y), y); @@ -1213,7 +1231,37 @@ Mutate_(const Select* op, const Expr& self) { Expr RewriteSimplifier::Impl:: Mutate_(const Call* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); + // add condition context to if_then_else + Expr ret; + if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { + Expr cond = Mutate(op->args[0]); + Expr true_value, false_value; + { + ConstraintContext constraint(parent_, cond); + true_value = Mutate(op->args[1]); + } + { + ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + false_value = Mutate(op->args[2]); + } + if (is_zero(cond)) { + return false_value; + } + if (is_one(cond)) { + return true_value; + } + if (cond.same_as(op->args[0]) && + true_value.same_as(op->args[1]) && + false_value.same_as(op->args[2])) { + ret = self; + } else { + ret = Call::make(op->type, op->name, + {cond, true_value, false_value}, + op->call_type); + } + } else { + ret = IRMutator::Mutate_(op, self); + } op = ret.as(); if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { return op->args[0]; diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 6af058523cd8..3e69f21fa2b2 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -22,7 +22,7 @@ def __init__(self): def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected) + assert tvm.ir_pass.Equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected) def test_mul_sum_simplify(): @@ -157,7 +157,38 @@ def test_reduce_simplify(): ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k)) +def test_simplify_if_then_else(): + ck = CanonicalChecker() + x = tvm.var("x") + y = tvm.var("y") + # simplification that takes condition into account. + res = tvm.if_then_else((x * 4 + y) >= 466036, + tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), + (((((x*4) + y) - 466036) % 24528) -24512) % 16, + x), y) + expected = tvm.if_then_else( + tvm.expr.LE(466036, (x * 4 + y)), + tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)), + (((x*4) + y) - 4) % 16, + x), y) + ck.verify(res, expected) + # can only simplify if condition + res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3) + expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3) + ck.verify(res, ck.analyzer.canonical_simplify(expected)) + + res = tvm.expr.Select(x >= 10, + tvm.if_then_else(x / 3 > 2, x, 0), 0) + expected = tvm.expr.Select(x >= 10, x, 0) + ck.verify(res, ck.analyzer.canonical_simplify(expected)) + + res = tvm.expr.Select(x >= 10, + tvm.if_then_else(x / 3 < 2, x, 0), 0) + ck.verify(res, 0) + + if __name__ == "__main__": + test_simplify_if_then_else() test_div_simplify() test_reduce_simplify() test_reduce_combiner_simplify()