Skip to content

Commit

Permalink
[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (apache#2722)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Mar 12, 2019
1 parent f49c002 commit e12a8dc
Show file tree
Hide file tree
Showing 9 changed files with 1,016 additions and 8 deletions.
35 changes: 35 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ class ModularSetAnalyzer {
Impl* impl_;
};

/*!
* \brief Rewrite-rule based simplifier.
*/
class RewriteSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
Expand Down Expand Up @@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self):
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -128,6 +129,21 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._rewrite_simplify(expr)

def bind(self, var, expr):
"""Bind a variable to the expression.
Expand Down
4 changes: 4 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
11 changes: 9 additions & 2 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>

namespace tvm {
namespace arith {

Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this) {
modular_set(this),
rewrite_simplify(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
Var var(v.node_);
this->const_int_bound.Bind(var, range);
// skip modular_set
// skip rewrite simplify
}

ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
Expand All @@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}

bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
auto bd = this->const_int_bound(expr);
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
return false;
}
Expand Down
4 changes: 3 additions & 1 deletion src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a, Expr b);
inline Expr TryConstFold(Expr a, Expr b) {
return Expr();
}

/*!
* \brief Try to run unary compute with constant folding.
Expand Down
51 changes: 46 additions & 5 deletions src/arithmetic/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -242,20 +243,60 @@ class PBinaryExpr :
}

Expr Eval() const {
return NodeType::make(a_.Eval(), b_.Eval());
Expr lhs = a_.Eval();
Expr rhs = b_.Eval();
Expr ret = TryConstFold<NodeType>(lhs, rhs);
if (ret.defined()) return ret;
return NodeType::make(lhs, rhs);
}

private:
typename TA::Nested a_;
typename TB::Nested b_;
};

template<typename TA>
class PConstWithTypeLike :
public Pattern<PConstWithTypeLike<TA> > {
public:
PConstWithTypeLike(const TA& ref, int64_t value)
: ref_(ref), value_(value) {}

void InitMatch_() const {}

bool Match_(const NodeRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
return ptr->value == value_;
} else {
return false;
}
}

Expr Eval() const {
return make_const(ref_.Eval().type(), value_);
}

private:
typename TA::Nested ref_;
int64_t value_;
};


#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}

// arithmetic expressions
Expand Down
Loading

0 comments on commit e12a8dc

Please sign in to comment.