Skip to content

Commit

Permalink
[ARITH] Constraint-aware ConstIntBound, Enhance CanonicalSimplify (#3132
Browse files Browse the repository at this point in the history
)
  • Loading branch information
tqchen authored May 4, 2019
1 parent 8fb7f82 commit 48c9237
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 17 deletions.
17 changes: 15 additions & 2 deletions src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SplitExprNode>()) {
return GetRef<SplitExpr>(op);
}
if (const auto* op = expr.as<SumExprNode>()) {
if (op->base == 0 && op->args.size() == 1) return op->args[0];
}
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
expr = op->Normalize();
}
Expand Down Expand Up @@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
}
// Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto cbound = parent_->const_int_bound(Normalize(a));
int64_t new_base = psum->base % cval;
if (cbound->min_value >= 0 &&
cbound->min_value - psum->base + new_base >= 0) {
SumExpr sum_expr(std::move(a.node_));
sum_expr.CopyOnWrite()->base = new_base;
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval);
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
Expand Down
77 changes: 74 additions & 3 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -25,6 +25,7 @@
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include "int_op_overflow.h"
#include "pattern_match.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry {
class ConstIntBoundAnalyzer::Impl :
public ExprFunctor<ConstIntBoundAnalyzer::Entry(const Expr&)> {
public:
/*! \brief additional bound info about expr \in bound */
struct BoundInfo {
/*! \brief The expr */
Expr expr;
/*! \brief The additional bound */
Entry bound;

BoundInfo() {}
BoundInfo(Expr expr, Entry bound)
: expr(expr), bound(bound) {
}
};

void Bind(const Var& var, const Range& range) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Expand Down Expand Up @@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl :
static_cast<const ir::BaseExprNode*>(op)->type);
}

Entry VisitExpr(const Expr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
// a linear search over additional info
// assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) {
if (ir::Equal(expr, info.expr)) {
res = Intersect(res, info.bound);
}
}
return res;
}

Entry VisitExpr_(const Cast* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->type);
Expand Down Expand Up @@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl :
}
}

std::function<void()> EnterConstraint(const Expr& constraint) {
std::vector<BoundInfo> info = DetectBoundInfo(constraint);
if (info.size() == 0) return nullptr;
size_t old_size = additional_info_.size();
additional_info_.insert(additional_info_.end(), info.begin(), info.end());
size_t new_size = old_size + info.size();
auto frecover = [old_size, new_size, this]() {
CHECK_EQ(additional_info_.size(), new_size);
additional_info_.resize(old_size);
};
return frecover;
}

private:
// internal variable map
std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
Expand Down Expand Up @@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl :
}
return ret;
}

/*!
* \brief Detect additional constant bound from cond, if any
* \param cond The constraint condition.
* \return List of detected bounds.
*/
static std::vector<BoundInfo> DetectBoundInfo(const Expr& cond) {
PVar<Expr> x, y;
PVar<Integer> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
}
if ((c < x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))};
}
if ((x <= c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))};
}
if ((x < c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
}
if ((x && y).Match(cond)) {
auto ret1 = DetectBoundInfo(x.Eval());
auto ret2 = DetectBoundInfo(y.Eval());
ret1.insert(ret1.end(), ret2.begin(), ret2.end());
return ret1;
}
return {};
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
Expand All @@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) {
return nullptr;
return impl_->EnterConstraint(constraint);
}

ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent)
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
66 changes: 57 additions & 9 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) {

Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Select>();
if (is_zero(op->condition)) {
return op->false_value;
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
true_value = Mutate(op->true_value);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(op->condition)) {
return op->true_value;
if (is_one(cond)) {
return true_value;
}
// normal path
Expr ret;
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
ret = self;
} else {
ret = Select::make(cond, true_value, false_value);
}
op = ret.as<Select>();
// Pattern var to match any expression
PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
Expand All @@ -1213,7 +1231,37 @@ Mutate_(const Select* op, const Expr& self) {

Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
// add condition context to if_then_else
Expr ret;
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
true_value = Mutate(op->args[1]);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
ret = self;
} else {
ret = Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
} else {
ret = IRMutator::Mutate_(op, self);
}
op = ret.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
Expand Down
33 changes: 32 additions & 1 deletion tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self):

def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
assert tvm.ir_pass.Equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected)


def test_mul_sum_simplify():
Expand Down Expand Up @@ -157,7 +157,38 @@ def test_reduce_simplify():
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))


def test_simplify_if_then_else():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
# simplification that takes condition into account.
res = tvm.if_then_else((x * 4 + y) >= 466036,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y)
expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
(((x*4) + y) - 4) % 16,
x), y)
ck.verify(res, expected)
# can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
ck.verify(res, ck.analyzer.canonical_simplify(expected))

res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 > 2, x, 0), 0)
expected = tvm.expr.Select(x >= 10, x, 0)
ck.verify(res, ck.analyzer.canonical_simplify(expected))

res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 < 2, x, 0), 0)
ck.verify(res, 0)


if __name__ == "__main__":
test_simplify_if_then_else()
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
Expand Down

0 comments on commit 48c9237

Please sign in to comment.