Skip to content

Commit

Permalink
[ARITH] Analyzer CanonicalSimplifier (apache#2891)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Apr 11, 2019
1 parent a102547 commit af4a7a3
Show file tree
Hide file tree
Showing 17 changed files with 1,490 additions and 1,254 deletions.
36 changes: 36 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,47 @@ class RewriteSimplifier {
private:
friend class Analyzer;
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief Canonical-form based simplifier.
*/
class CanonicalSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
Expand Down Expand Up @@ -277,6 +311,8 @@ class Analyzer {
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer rewrite simplfy */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
39 changes: 39 additions & 0 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"

namespace tvm {
Expand All @@ -32,6 +33,44 @@ using ::tvm::AttrVisitor;
using ContainerType = NodeName; \
}; \

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};


/*!
* \brief save the node as well as all the node it depends on as json.
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self):
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -144,6 +145,21 @@ def rewrite_simplify(self, expr):
"""
return self._rewrite_simplify(expr)

def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._canonical_simplify(expr)

def bind(self, var, expr):
"""Bind a variable to the expression.
Expand Down
4 changes: 4 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
16 changes: 12 additions & 4 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@ namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this) {
rewrite_simplify(this),
canonical_simplify(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));

Expr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);

this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
Expand Down Expand Up @@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (bd->min_value >= lower_bound) return true;
return false;
}

} // namespace arith
} // namespace tvm
Loading

0 comments on commit af4a7a3

Please sign in to comment.