diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index d023f8f1cf7e..07004937f621 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -218,6 +218,7 @@ class RewriteSimplifier { private: friend class Analyzer; friend class ConstraintContext; + friend class CanonicalSimplifier; explicit RewriteSimplifier(Analyzer* parent); ~RewriteSimplifier(); class Impl; @@ -225,6 +226,39 @@ class RewriteSimplifier { Impl* impl_; }; +/*! + * \brief Canonical-form based simplifier. + */ +class CanonicalSimplifier { + public: + /*! + * \brief analyze the expr + * \param expr The expression of interest. + * \return the result of the analysis. + */ + Expr operator()(const Expr& expr); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_expr + * \param override Whether do we allow override of existing information. + */ + void Update(const Var& var, + const Expr& new_expr, + bool override = false); + + private: + friend class Analyzer; + friend class ConstraintContext; + explicit CanonicalSimplifier(Analyzer* parent); + ~CanonicalSimplifier(); + class Impl; + /*! \brief Internal impl */ + Impl* impl_; +}; + /*! * \brief A RAII constraint context. * @@ -277,6 +311,8 @@ class Analyzer { ModularSetAnalyzer modular_set; /*! \brief sub-analyzer rewrite simplfy */ RewriteSimplifier rewrite_simplify; + /*! \brief sub-analyzer rewrite simplfy */ + CanonicalSimplifier canonical_simplify; /*! \brief constructor */ Analyzer(); /*! diff --git a/include/tvm/base.h b/include/tvm/base.h index 7104688aa169..863bde52e2a5 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "runtime/registry.h" namespace tvm { @@ -32,6 +33,44 @@ using ::tvm::AttrVisitor; using ContainerType = NodeName; \ }; \ +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. + * + * \code + * + * MyCOWNodeRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TypeName() {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + inline NodeName* CopyOnWrite() { \ + CHECK(node_ != nullptr); \ + if (!node_.unique()) { \ + NodePtr n = make_node(*(operator->())); \ + NodePtr(std::move(n)).swap(node_); \ + } \ + return static_cast(node_.get()); \ + } \ + using ContainerType = NodeName; \ + }; + /*! * \brief save the node as well as all the node it depends on as json. diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 3981a4815aeb..85560afd0694 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -97,6 +97,7 @@ def __init__(self): self._bind = _mod("bind") self._modular_set = _mod("modular_set") self._rewrite_simplify = _mod("rewrite_simplify") + self._canonical_simplify = _mod("canonical_simplify") self._enter_constraint_context = _mod("enter_constraint_context") def const_int_bound(self, expr): @@ -144,6 +145,21 @@ def rewrite_simplify(self, expr): """ return self._rewrite_simplify(expr) + def canonical_simplify(self, expr): + """Simplify expression via canonicalization. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + Returns + ------- + result : Expr + The result. + """ + return self._canonical_simplify(expr) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index cc7d814617a9..f4b673533d8f 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "canonical_simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->canonical_simplify(args[0]); + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { auto& sptr = args[1].node_sptr(); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 81195eba2747..da30dc2a3a7b 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -11,14 +11,21 @@ namespace arith { Analyzer::Analyzer() : const_int_bound(this), modular_set(this), - rewrite_simplify(this) { + rewrite_simplify(this), + canonical_simplify(this) { } void Analyzer::Bind(const VarExpr& v, const Expr& expr) { Var var(v.node_); - this->const_int_bound.Update(var, this->const_int_bound(expr)); - this->modular_set.Update(var, this->modular_set(expr)); - this->rewrite_simplify.Update(var, this->rewrite_simplify(expr)); + + Expr new_expr = expr; + new_expr = this->canonical_simplify(new_expr); + new_expr = this->rewrite_simplify(new_expr); + + this->const_int_bound.Update(var, this->const_int_bound(new_expr)); + this->modular_set.Update(var, this->modular_set(new_expr)); + this->rewrite_simplify.Update(var, new_expr); + this->canonical_simplify.Update(var, new_expr); } void Analyzer::Bind(const VarExpr& v, const Range& range) { @@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { if (bd->min_value >= lower_bound) return true; return false; } + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc deleted file mode 100644 index 77c44f184e0f..000000000000 --- a/src/arithmetic/canonical.cc +++ /dev/null @@ -1,938 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file canonical.cc - * \brief Canonicalize simplification. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "canonical.h" -#include "compute_expr.h" -#include "arithmetic/Simplify.h" - -namespace tvm { -namespace arith { -using namespace ir; - -// Canonical entry for communicative ops. -struct ComExprEntry { - // the value of the expression. - Expr value; - // the level of the expression. - int level{0}; - // The integer scale on value - int64_t scale{1}; - - ComExprEntry() {} - ComExprEntry(Expr value, int level) - : value(value), level(level) {} - inline bool operator<(const ComExprEntry& other) const { - if (level < other.level) return true; - if (level > other.level) return false; - // compare top operator of entries and sort on that if possible (fast check) - if (value.type_index() < other.value.type_index()) return true; - if (value.type_index() > other.value.type_index()) return false; - // if none of the above distinguishes the terms, compare the expression tree of the entries. - // This is a slower check. - int compare_result = Compare(value, other.value); - if (compare_result < 0) return true; - if (compare_result > 0) return false; - // it's a problem if we see identical entries at this point. They should've been merged earlier. - LOG(WARNING) << "we should not have identical entries at this point"; - return false; - } -}; - -// canonical expression for communicative expression. -struct ComExprNode : public NodeBase { - // base constant value. - int64_t base{0}; - // The values to be sumed. - std::vector elem; -}; - -// canonical communicative expression -struct ComExpr { - public: - // constructor - ComExpr() {} - explicit ComExpr(NodePtr ptr) : ptr_(ptr) {} - // get member - ComExprNode* operator->() const { - return ptr_.get(); - } - void reset() { - ptr_.reset(); - } - bool defined() const { - return ptr_.get() != nullptr; - } - // comparator - bool operator<(const ComExpr& b) const { - const ComExpr& a = *this; - if (a->base < b->base) return true; - if (a->base > b->base) return false; - if (a->elem.size() < b->elem.size()) return true; - if (a->elem.size() > b->elem.size()) return false; - for (size_t i = 0; i < a->elem.size(); ++i) { - const ComExprEntry& ea = a->elem[i]; - const ComExprEntry& eb = b->elem[i]; - if (ea.level < eb.level) return true; - if (ea.level > eb.level) return false; - if (ea.value.get() < eb.value.get()) return true; - if (ea.value.get() > eb.value.get()) return false; - if (ea.scale < eb.scale) return true; - if (ea.scale > eb.scale) return false; - } - return false; - } - // equality - bool operator==(const ComExpr& b) const { - const ComExpr& a = *this; - if (a->base != b->base) return false; - if (a->elem.size() != b->elem.size()) return false; - for (size_t i = 0; i < a->elem.size(); ++i) { - const ComExprEntry& ea = a->elem[i]; - const ComExprEntry& eb = b->elem[i]; - if (ea.level != eb.level) return false; - if (ea.value.get() != eb.value.get()) return false; - if (ea.scale != eb.scale) return false; - } - return true; - } - - private: - NodePtr ptr_; -}; - -// binary comparison op. -struct BinaryExpr { - int kind; - Expr lhs, rhs; - // comparator - bool operator<(const BinaryExpr& b) const { - if (kind < b.kind) return true; - if (kind > b.kind) return false; - if (lhs.get() < b.lhs.get()) return true; - if (lhs.get() > b.lhs.get()) return false; - return rhs.get() < b.rhs.get(); - } - // equality - bool operator==(const BinaryExpr& b) const { - return kind == b.kind && - lhs.same_as(b.lhs) && - rhs.same_as(b.rhs); - } -}; - - -template -inline Expr Binary_(const T* op, - const Expr& e, - Expr a, Expr b) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return e; - } else { - return T::make(a, b); - } -} - -// internal of canonical engine. -class Canonical::Internal : public IRMutator { - public: - explicit Internal(Map vrange) { - for (auto kv : vrange) { - SetRange(kv.first, kv.second, 0); - } - } - // stack entry. - struct StackEntry { - int max_level{0}; - bool has_side_effect{false}; - }; - // aggressively canonicalized expression - struct CacheEntry { - // The canonical value of the expression. - Expr value; - // The level of the expression. - int max_level{0}; - // whether the expression might have side effect. - bool has_side_effect{false}; - // if not null, corresponds to to sum - ComExpr sum; - // reset the return entry. - void reset() { - sum.reset(); - } - // as sum expr - ComExpr AsSum() const { - if (sum.defined()) return sum; - const int64_t *v1 = as_const_int(value); - const uint64_t *v2 = as_const_uint(value); - auto n = make_node(); - if (v1) { - n->base = *v1; - } else if (v2) { - CHECK_LE(*v2, - static_cast(std::numeric_limits::max())); - n->base = static_cast(*v2); - } else { - n->elem.push_back(ComExprEntry(value, max_level)); - } - return ComExpr(n); - } - }; - // Set range and level of var. - void SetRange(Var v, Range r, int level) { - var_range_[v.get()] = IntSet::range(r); - var_level_[v.get()] = level; - var_rec_.push_back(v); - } - // functions - Stmt Mutate(Stmt stmt) final { - stmt = IRMutator::Mutate(stmt); - return stmt; - } - Expr MutateExpr_(Expr expr) { - stack_.push_back(StackEntry()); - expr = IRMutator::Mutate(expr); - // update result of parent automatically during pop - if (stack_.size() > 1) { - StackEntry& back = stack_[stack_.size() - 1]; - StackEntry& prev = stack_[stack_.size() - 2]; - prev.max_level = std::max(prev.max_level, back.max_level); - if (back.has_side_effect) prev.has_side_effect = true; - } - // copy result from stack - ret_entry_.has_side_effect = stack_.back().has_side_effect; - ret_entry_.max_level = stack_.back().max_level; - stack_.pop_back(); - CHECK(expr.defined()); - if (const IntImm* op = expr.as()) { - return Mutate_(op, expr); - } - return expr; - } - // call produce to get a cache entry. - CacheEntry Produce(Expr expr) { - ret_entry_.reset(); - ret_entry_.value = MutateExpr_(expr); - CacheEntry ret = ret_entry_; - ret_entry_.reset(); - return ret; - } - Expr Mutate(Expr expr) final { - ret_entry_.reset(); - expr = MutateExpr_(expr); - ret_entry_.reset(); - return expr; - } - - // Check whether do special canonicalization. - bool EnableOpt(Type t) const { - return (t.lanes() == 1 && (t.is_int() || t.is_uint())); - } - // Max - Expr Mutate_(const Max* op, const Expr& e) final { - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return Binary(op, e); - } - // Min - Expr Mutate_(const Min* op, const Expr& e) final { - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return Binary(op, e); - } - // Add - Expr Mutate_(const Add* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return SumAdd(a, b, +1); - } - // Sub - Expr Mutate_(const Sub* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return SumAdd(a, b, -1); - } - // Mul - Expr Mutate_(const Mul* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr(a.value, b.value); - } else if (is_const(a.value)) { - return SumMulConst(b.AsSum(), a.value); - } else if (is_const(b.value)) { - return SumMulConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - // Variable - Expr Mutate_(const Variable* op, const Expr& e) final { - auto it = var_level_.find(op); - if (it != var_level_.end()) { - stack_.back().max_level = it->second; - } - return IRMutator::Mutate_(op, e); - } - // comparison - Expr Mutate_(const LT* op, const Expr& e) { - if (!EnableOpt(op->a.type())) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - Expr b_sub_a = SumAdd(b, a, -1); - if (EvalSet(b_sub_a, var_range_).can_prove_positive()) { - return make_const(op->type, true); - } else { - return Binary_(op, e, a.value, b.value); - } - } - // IntImm - Expr Mutate_(const IntImm* op, const Expr& e) final { - if (op->type != Int(32)) return e; - auto it = cache_intimm_.find(op->value); - if (it != cache_intimm_.end()) { - return it->second; - } else { - cache_intimm_[op->value] = e; - return e; - } - } - // Div operator - Expr Mutate_(const Div* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr
(a.value, b.value); - } else if (is_const(b.value)) { - return SumDivConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - // Mod operator - Expr Mutate_(const Mod* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr(a.value, b.value); - } else if (is_const(b.value)) { - return SumModConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - - Expr Mutate_(const And* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); - op = expr.as(); - if (is_one(op->a)) return op->b; - if (is_one(op->b)) return op->a; - return expr; - } - // Call - Expr Mutate_(const Call* op, const Expr& e) final { - if (!op->is_pure()) { - stack_.back().has_side_effect = true; - } - Expr expr = IRMutator::Mutate_(op, e); - op = expr.as(); - if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { - return op->args[0]; - } else { - return expr; - } - } - // For - Stmt Mutate_(const For* op, const Stmt& s) { - ++level_counter_; - Var loop_var(op->loop_var.node_); - this->SetRange(loop_var, - Range::make_by_min_extent(op->min, op->extent), - level_counter_); - Stmt stmt = IRMutator::Mutate_(op, s); - --level_counter_; - return stmt; - } - // IfThenElse - Stmt Mutate_(const IfThenElse* op, const Stmt& s) { - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - if (is_one(op->condition)) return op->then_case; - return stmt; - } - // AttrStmt - Stmt Mutate_(const AttrStmt* op, const Stmt& s) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { - ++level_counter_; - IterVar iv(op->node.node_); - CHECK_NE(iv->thread_tag.length(), 0U); - if (!var_level_.count(iv->var.get())) { - this->SetRange(iv->var, - Range::make_by_min_extent(0, op->value), - level_counter_); - } - Stmt stmt = IRMutator::Mutate_(op, s); - --level_counter_; - return stmt; - } else { - return IRMutator::Mutate_(op, s); - } - } - // The simplify statement. - static FMutateExpr& vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; - } - - private: - template - Expr Binary(const T* op, Expr e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); - BinaryExpr key{static_cast(T::_type_info), a, b}; - auto it = cache_binary_.find(key); - if (it != cache_binary_.end()) { - return it->second; - } else { - Expr ret = Binary_(op, e, a, b); - cache_binary_[key] = ret; - return ret; - } - } - // return entry - CacheEntry ret_entry_; - // internal information stack - std::vector stack_; - // cache sum - std::map cache_sum_; - // cache of normal binary op - std::map cache_binary_; - // cache of int constant - std::unordered_map cache_intimm_; - // range of each var - std::unordered_map var_range_; - // level of each var - std::unordered_map var_level_; - // record history vars, to avoid false positive. - std::vector var_rec_; - // level counter - int level_counter_{0}; - // get constant int value - int64_t GetConstIntValue(const Expr& v) { - int64_t value = 0; - const int64_t *v1 = as_const_int(v); - const uint64_t *v2 = as_const_uint(v); - CHECK(v1 || v2); - if (v1) { - value = *v1; - } else if (v2) { - CHECK_LE(*v2, - static_cast(std::numeric_limits::max())); - value = static_cast(*v2); - } - return value; - } - // Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0 - // (in Euclidean division) - // returns pair (q, r) if such detection is successful - // returns empty vector otherwise. - // Assumes that coeff is a constant integer - std::vector TryLinearEquation(const ComExpr& a, - const Expr& coeff) { - Type type = coeff.type(); - int64_t value = GetConstIntValue(coeff); - CHECK_NE(value, 0); - if (value < 0) return {}; - // Given that denominator (value variable) is positive, truncated division - // (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if - // numerator is non-negative or numerator is divisible by denominator (i.e., value) - IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_); - bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative(); - // Try to separate terms of a into ones that can be proven to be - // divisible by coeff and ones that are not - // We will build q and r from divisible and non_divisible respectively - auto divisible = make_node(); - auto non_divisible = make_node(); - if (a->base % value == 0) { - divisible->base = a->base; - } else { - non_divisible->base = a->base; - } - for (const auto& e : a->elem) { - if (e.scale % value == 0) { - divisible->elem.push_back(e); - } else { - non_divisible->elem.push_back(e); - } - } - bool non_divisible_is_simplified = false; - int64_t div_result; - Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type); - // if non_divisible part consists of only an integer and numerator is non-negative, - // we can simply divide it by coeff - if (is_const(non_divisible_res)) { - int64_t non_divisible_const = GetConstIntValue(non_divisible_res); - if (numerator_is_non_neg || non_divisible_const == 0) { - non_divisible_is_simplified = true; - // We need to do an Euclidean division here because (a*b + c)/b == a + c/b - // holds true only if division is Euclidean - div_result = HalideIR::Internal::div_imp(non_divisible_const , value); - } - } else { - // If we can prove that non_divisible part lies within [0, coeff), then - // non_divisible itself will be our r - IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_); - if (non_divisible_set.min().type() == type && - non_divisible_set.max().type() == type) { - if ( (non_divisible_set.is_single_point() && - can_prove(non_divisible_set.point_value() == 0)) || - (numerator_is_non_neg && - can_prove(non_divisible_set.min() >= make_zero(type)) && - can_prove(non_divisible_set.max() < coeff)) ) { - non_divisible_is_simplified = true; - div_result = 0; - } - } - } - if (non_divisible_is_simplified) { - non_divisible->base -= div_result * value; - divisible->base /= value; - divisible->base += div_result; - for (auto& e : divisible->elem) { - e.scale /= value; - } - return {ComExpr(divisible), ComExpr(non_divisible)}; - } else { - return {}; - } - } - // subroutine to do produce a % v - Expr SumModConst(ComExpr a, Expr v) { - std::vector pair = TryLinearEquation(a, v); - if (pair.size() == 0) { - int64_t value = GetConstIntValue(v); - auto n = make_node(); - // FIXME(derisavi) : The following can be done only for Euclidean division/mod. - // Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one, - // that is, if and only if a and v are - // both negative or both positive or a is divisible by v. - // Extend the code to handle cases where the above condition is not satisfied, i.e., - // a and v are of different signs and a is not divisible by v. - n->base = a->base % value; - for (auto e : a->elem) { - if (e.scale % value == 0) continue; - e.scale = e.scale % value; - n->elem.push_back(e); - } - Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; - if (const Mod* mod = ret.as()) { - return Binary(mod, ret); - } else { - // Sometimes the result is a constant, this may happen when value is -1 - CHECK(is_const(ret)) << "CanonicalSimplify: " - << Sum2Expr(ComExpr(n), v.type()) << " % " << v << " is " << ret - << " which is neither Mod, nor a constant"; - return ret; - } - } - ret_entry_.sum = pair[1]; - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // subroutine to do produce a % v - Expr SumDivConst(ComExpr a, Expr v) { - std::vector pair = TryLinearEquation(a, v); - if (pair.size() == 0) { - Expr ret = Sum2Expr(a, v.type()) / v; - return Binary(ret.as
(), ret); - } - ret_entry_.sum = pair[0]; - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // subroutine to do produce - Expr SumMulConst(ComExpr a, Expr v) { - int64_t value = GetConstIntValue(v); - if (value == 0) { - return make_zero(v.type()); - } - auto vsum = make_node(*a.operator->()); - vsum->base *= value; - for (auto& e : vsum->elem) { - e.scale *= value; - } - ret_entry_.sum = ComExpr(vsum); - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // add two ComExpr together - ComExpr SumAdd_(const ComExpr& suma, - const ComExpr& sumb, - int bscale) { - auto n = make_node(); - n->base = suma->base + sumb->base * bscale; - // merge of suma and sumb; - size_t i = 0, j = 0; - while (i < suma->elem.size() && j < sumb->elem.size()) { - const auto& a = suma->elem[i]; - const auto& b = sumb->elem[j]; - if (a.value.same_as(b.value) && a.level == b.level) { - ComExprEntry e = a; - e.scale = a.scale + b.scale * bscale; - if (e.scale != 0) { - n->elem.push_back(e); - } - ++i; ++j; - } else if (a < b) { - n->elem.push_back(a); - ++i; - } else { - ComExprEntry e = b; - e.scale *= bscale; - n->elem.push_back(e); - ++j; - } - } - for (; i < suma->elem.size(); ++i) { - n->elem.push_back(suma->elem[i]); - } - for (; j < sumb->elem.size(); ++j) { - ComExprEntry e = sumb->elem[j]; - e.scale *= bscale; - n->elem.push_back(e); - } - return ComExpr(n); - } - // subroutine to do produce - Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) { - ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale); - CHECK_NE(stack_.size(), 0U); - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // convert sum to expr - Expr Sum2Expr(const ComExpr& com, Type t) { - Expr vsum; - if (com->base > 0) { - vsum = make_const(t, com->base); - } - for (const ComExprEntry& e : com->elem) { - if (e.scale > 0) { - Expr v = e.value; - if (e.scale != 1) { - v = Mul::make(v, make_const(t, e.scale)); - } - if (vsum.defined()) { - vsum = Add::make(vsum, v); - } else { - vsum = v; - } - } - } - if (com->base < 0) { - if (vsum.defined()) { - vsum = Sub::make(vsum, make_const(t, -com->base)); - } else { - vsum = make_const(t, com->base); - } - } - for (const ComExprEntry& e : com->elem) { - if (e.scale < 0) { - Expr v = e.value; - if (e.scale != -1) { - v = Mul::make(v, make_const(t, -e.scale)); - } - if (vsum.defined()) { - vsum = Sub::make(vsum, v); - } else { - vsum = Sub::make(make_zero(t), v); - } - } - } - if (vsum.defined()) { - return vsum; - } else { - return make_zero(t); - } - } -}; - -using CInternal = Canonical::Internal; - -Canonical::Canonical(Map vrange) - : ptr_(std::make_shared(vrange)) {} - -Expr Canonical::Simplify(Expr expr) { - return ptr_->Mutate(expr); -} - -Stmt Canonical::Simplify(Stmt stmt) { - return ptr_->Mutate(stmt); -} - -void Canonical::SetRange(Var v, Range r, int level) { - ptr_->SetRange(v, r, level); -} -} // namespace arith - -namespace ir { - -Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - return arith::Canonical(vrange).Simplify(stmt); -} - -Expr CanonicalSimplify(Expr expr, Map vrange) { - return arith::Canonical(vrange).Simplify(expr); -} - -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; - for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - - -/*! - * \brief Simplify just the combiner of the given reduce node. - * - * This function applies Simplify to the components of the top reduction's - * combiner, but not to the source or condition of the reduction. - * It also removes all components which are not used to - * compute the resulting value (the value_index-th value). - * - * If \p expr is not a reduction node, it is left unchanged. - * - * \param expr The expression to be simplifed. - * \return Simplified expression. - */ -Expr SimplifyCombiner(const Expr& expr, const Map& vrange = Map()) { - const Reduce* op = expr.as(); - if (!op) { - return expr; - } - - // First simplify the results - Array simplified_result; - for (const auto& res : op->combiner->result) { - simplified_result.push_back(Simplify(res, vrange)); - } - - // Which components to keep - std::vector used(op->combiner->result.size(), false); - - // This function recursively marks the used components starting from - // the index idx - std::function mark_used; - mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { - // if the idx-th component was marked as used before, do nothing - if (used[idx]) return; - used[idx] = true; - - // check if the idx-th result expr uses some lhs or rhs variables - // and recursively mark the corresponding components - for (size_t i = 0; i < simplified_result.size(); ++i) - if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) - mark_used(i); - } - }; - - // mark all used components starting from the value_index - mark_used(op->value_index); - - // components which have side effects should also be preserved - for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || - HasSideEffect(op->combiner->result[i])) { - mark_used(i); - } - } - - int new_value_index = op->value_index; - Array new_result; - Array new_identity; - Array new_lhs; - Array new_rhs; - Array new_source; - - // new stuff is old stuff which is used - for (size_t i = 0; i < used.size(); ++i) { - if (used[i]) { - // We simplify the result and identity, but not the source - new_result.push_back(simplified_result[i]); - new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange)); - new_lhs.push_back(op->combiner->lhs[i]); - new_rhs.push_back(op->combiner->rhs[i]); - new_source.push_back(op->source[i]); - } else if (static_cast(i) < op->value_index) { - // value_index should also be adjusted - new_value_index--; - } - } - - CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index); -} - -/*! - * \brief Remove a single reduction over empty axis. - * - * If \p e is a reduction node and its axis is empty, replace it with its source, - * otherwise return \p e unchanged. - * - * \param e The expression to be transformed. - * \return The transformed expression. - */ -Expr RemoveEmptyReduction(const Expr& e) { - const Reduce* r = e.as(); - if (r && r->axis.empty()) { - // Note that here we assume that the identity element is indeed identity. Without this - // assumption we would have to perform a single iteration of the loop, i.e. use - // `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]` - // instead of `r->source[r->value_index]`. The former may be more difficult to simplify. - return Select::make(r->condition, - r->source[r->value_index], - r->combiner->identity_element[r->value_index]); - } - return e; -} - -Expr Simplify(Expr a, Map vrange) { - // We should not pass an expression having a non-HalideIR op to - // Halide::Internal::simplify. Reduce op is the only such op at this time - // and it only appears as the top op in an expression. So we strip it - // first and send the sub-expressions to the simplifier. - if (const Reduce* r = a.as()) { - // If axis is empty, we can remove the reduce op completely. - if (r->axis.empty()) - return Simplify_(RemoveEmptyReduction(a), vrange); - - // Simplify the combiner of the reduction - a = SimplifyCombiner(a, vrange); - r = a.as(); - - // If axis is not empty then we add the information about ranges to vrange - for (const IterVar& iv : r->axis) { - if (vrange.count(iv->var)) { - Range existing_range = vrange[iv->var]; - CHECK(Equal(existing_range->min, iv->dom->min) && - Equal(existing_range->extent, iv->dom->extent)) - << "Simplify was given vrange stating that the range of the reduction var " - << iv << " is " << existing_range << ". This is probably a mistake."; - } - vrange.Set(iv->var, iv->dom); - } - - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } - } - return Simplify_(a, vrange); -} - -Stmt Simplify(Stmt a, Map vrange) { - return Simplify_(a, vrange); -} -} // namespace ir -} // namespace tvm diff --git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h deleted file mode 100644 index a02dbeef7e3a..000000000000 --- a/src/arithmetic/canonical.h +++ /dev/null @@ -1,56 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file canonical.h - * \brief Internal canonicalized expression simplification engine. - */ -#ifndef TVM_ARITHMETIC_CANONICAL_H_ -#define TVM_ARITHMETIC_CANONICAL_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -/*! - * \brief A stateful CanonicalEngine over SSA. - * - * Simplify and CSE with canonicalization expressions. - * Each call's result will get cached, so next call will - * simply return the cached result. - */ -class Canonical { - public: - /*! \brief constructor */ - explicit Canonical(Map var_range); - /*! - * \brief simplify expression e. - * \param expr The expression to be simplified. - */ - Expr Simplify(Expr expr); - /*! - * \brief simplify stmt. - * \param stmt The stmt to be simplified. - */ - Stmt Simplify(Stmt expr); - /*! - * \brief Set range and level variable - * \param v The variable - * \param r The range of the variable, can be undefined. - * \param level The scope level of the variable, - * affect the order of formula in communicative ops. - */ - void SetRange(Var v, Range r, int level); - - class Internal; - private: - // Internal pointer - std::shared_ptr ptr_; -}; - - -} // namespace arith -} // namespace tvm - -#endif // TVM_ARITHMETIC_CANONICAL_H_ diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc new file mode 100644 index 000000000000..ed4ebe975bfd --- /dev/null +++ b/src/arithmetic/canonical_simplify.cc @@ -0,0 +1,884 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file canonical_simplify.cc + * \brief Canonical form based simplification. + */ +#include +#include +#include +#include "const_fold.h" +#include "pattern_match.h" +#include "rewrite_simplify.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +class SumExpr; +class SplitExpr; + +/*! + * \brief Base class of all temporary expression introduced + * for canonicalization. + */ +class CanonicalExprNode : public BaseExprNode { + public: + /*! + * \brief Return the normal Expr that is equivalent to self. + * \note Can mutate the internal data structure. + * \return The normal expression. + */ + virtual Expr Normalize() const = 0; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) final { + } + void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final { + LOG(FATAL) << "not supported"; + } + IRNodeType type_info() const final { + return IRNodeType::ExtensionExpr; + } + + static constexpr const char* _type_key = "arith.CanonicalExpr"; + TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode); +}; + +/*! + * \brief Internal "Split normal form" of expression. + * + * This is a special expression that represents + * a scaled value derived from a split of an index. + * + * result = ((index % upper_factor) / lower_factor) * scale + */ +class SplitExprNode : public CanonicalExprNode { + public: + /*! \brief The base index expression. */ + Expr index; + /*! \brief The division factor ratio. */ + int64_t lower_factor{1}; + /*! + * \brief The upper factor. + * invariance: (upper_factor == kPosInf || upper_factor % lower_factor == 0) + */ + int64_t upper_factor{kPosInf}; + /*! \brief scale to the expression. */ + int64_t scale{1}; + + /*! \brief verify that this is a valid entry. */ + void Verify() const { + CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); + } + + Expr NormalizeWithScale(int64_t sscale) const { + Expr res = this->index; + Type dtype = this->type; + if (this->scale == 0) { + return make_const(dtype, 0); + } + if (this->upper_factor != SplitExprNode::kPosInf) { + res = res % make_const(dtype, this->upper_factor); + } + if (this->lower_factor != 1) { + res = res / make_const(dtype, this->lower_factor); + } + sscale *= this->scale; + if (sscale != 1) { + CHECK(!dtype.is_uint() || sscale > 0); + res = res * make_const(dtype, sscale); + } + return res; + } + + Expr Normalize() const final { + return NormalizeWithScale(1); + } + + void MulToSelf(int64_t scale) { + this->scale *= scale; + } + + inline bool IndexEqual(const SplitExpr& other) const; + + /*! \brief positive infty */ + static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static constexpr const char* _type_key = "arith.SplitExpr"; + TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode); +}; + +TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode); + +inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { + if (index.same_as(other->index)) return true; + return ir::Equal(index, other->index); +} + +/*! + * \brief Normal form that represents sum of expressions. + * + * result = sum(args) + base. + */ +class SumExprNode : public CanonicalExprNode { + public: + /*! + * \brief arguments to be summed up. + * + * args are divided into segments with the same index. + * within each segment, the SplitExpr is ordered in descending order of lower_factor. + */ + std::vector args; + /*! \brief Base value in the summation. */ + int64_t base{0}; + /*! + * \brief Return the normal Expr that is equivalent to self. + * \return The normal expression. + */ + Expr Normalize() const final { + // quick path 1. + if (this->args.size() == 0) { + return make_const(this->type, this->base); + } + return Normalize_(this->type, + SimplifySplitExprs(args), + base); + } + /*! + * \brief Whether self is divisible by scale. + * \param scale The scale to be applied. + */ + bool DivisibleBy(int64_t scale) { + if (base % scale != 0) return false; + for (size_t i = 0; i < this->args.size(); ++i) { + if (args[i]->scale % scale != 0) return false; + } + return true; + } + /*! + * \brief mul scale to self. + * \param scale The scale to be applied. + */ + void MulToSelf(int64_t scale) { + this->base *= scale; + for (size_t i = 0; i < this->args.size(); ++i) { + args[i].CopyOnWrite()->scale *= scale; + } + } + /*! + * \brief divide by scale. + * \param scale The scale to be applied. + */ + void DivideBy(int64_t scale) { + CHECK_EQ(this->base % scale, 0); + this->base /= scale; + for (size_t i = 0; i < this->args.size(); ++i) { + CHECK_EQ(args[i]->scale % scale, 0); + args[i].CopyOnWrite()->scale /= scale; + } + } + /*! + * \brief add constant value to self. + * \param value to be added. + */ + void AddToSelf(int64_t value) { + this->base += value; + } + /*! + * \brief self += other * scale; + * \param other The expression to be added. + * \param scale The additional scale on value. + */ + void AddToSelf(SplitExpr other, int64_t scale) { + if (other->scale == 0) return; + // We need to maintain the segment invariance: + // Same index are stored close to each other. + // sorted from big lower_factor to small one. + size_t start = 0; + for (; start < args.size(); ++start) { + if (args[start]->IndexEqual(other)) break; + } + for (size_t j = start; j < args.size(); ++j) { + if (!args[j]->IndexEqual(other) || + other->lower_factor > args[j]->lower_factor) { + other.CopyOnWrite()->scale *= scale; + this->args.insert(this->args.begin() + j, other); + return; + } + if (other->lower_factor == args[j]->lower_factor && + other->upper_factor == args[j]->upper_factor) { + args[j].CopyOnWrite()->scale += other->scale * scale; + return; + } + } + // Insert other in the end. + other.CopyOnWrite()->scale *= scale; + this->args.emplace_back(std::move(other)); + } + + void AddToSelf(const SumExpr& other, int64_t scale); + + static constexpr const char* _type_key = "arith.SumExpr"; + TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode); + + private: + /*! + * \brief Simplify the args by merging SplitExprs + * \param args The original list of arguments. + * \return simplified version. + */ + static std::vector + SimplifySplitExprs(std::vector args) { + // NOTE: This algorithm relies on the factor that args are divided into segments + // and each segment is sorted in descending order of lower_factor. + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->scale == 0) continue; + for (size_t j = i + 1; j < args.size(); ++j) { + SplitExpr& lhs = args[i]; + SplitExpr& rhs = args[j]; + if (!lhs->IndexEqual(rhs)) break; + if (lhs->upper_factor < rhs->lower_factor) break; + if (lhs->lower_factor == rhs->upper_factor && + lhs->scale % rhs->scale == 0 && + lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) { + // Rules used in the proof: + // + // Rule 1: (x % (c * s)) / c = (x / c) % s + // Proof: + // x can always be decomposed into p * c * s + q * c + r + // where 0 <= q * c + r < c * s and 0 <= r < c. + // Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q + // rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q + // Thus, lhs = rhs + // + // The above proof is for the floordiv. + // The same rule also holds for trucdiv(division rule in C). + // Because both sides only involve mul, div and mod, + // we can take abs of x, c and s, apply the floordiv proof, + // and finally add the sign back. + // + // Rule 2: (x / s) * s + x % s = x (true for both truc and floor div) + // + // General merge condition and proof: + // - x = lhs->index % lhs->upper_factor + // - s = lhs->scale / rhs->scale + // - c = rhs->lower_factor + // + // (x / (c * s)) * s + (x % (c * s)) / c + // => ((x / c) / s) * s + ((x / c) % s) + // => (x / c) + // + // Examples: + // + // (z / 6) * 6 + ((z % 6) / 3) * 3 + // => ((z / 6) * 2 + (z % 6) / 3) * 3 + // => (z / 3) * 3 + // note: x = z, c = 3, s = 2 + // + // ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3 + // => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3 + // => ((z % 12) / 3) * 3 + // note: x = z % 12, c = 3, s = 2 + // note also the invariance lhs->upper_factor % lhs->lower_factor == 0 + // + SplitExprNode* merged = rhs.CopyOnWrite(); + merged->upper_factor = lhs->upper_factor; + // reset args[i] to be zero. + lhs.CopyOnWrite()->scale = 0; + break; + } + } + } + // sort by the entry + // Here we simply sort by descending order of scales. + // For now, we do not compare by index because that comparison + // can be runtime dependent and create inderminism. + // we do not sort by index for now because it can be costly + // to deep compare Exprs, and address of Vars can be runtime dependent. + // + auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) { + // order by scale first + if (lhs->scale > rhs->scale) return true; + if (lhs->scale < rhs->scale) return false; + // then order by factor + if (lhs->lower_factor > rhs->lower_factor) return true; + if (lhs->lower_factor < rhs->lower_factor) return false; + // then order by upper factor + if (lhs->upper_factor > rhs->upper_factor) return true; + if (lhs->upper_factor < rhs->upper_factor) return false; + // tie. + // TODO(tvm-team) We might consider index as the last comparison point, + // after we make deep comparator more derministic. + // Specifically, we can consider comparing names of vars and break ties with address. + return false; + }; + std::stable_sort(args.begin(), args.end(), fcompare); + return args; + } + static Expr Normalize_(Type dtype, + const std::vector& args, + int64_t base) { + // Positive scales first + Expr res = make_const(dtype, 0); + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->scale > 0) { + res = res + args[i]->Normalize(); + } + } + if (base > 0) { + res = res + make_const(dtype, base); + } + // negative scales follows using sub. + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->scale < 0) { + res = res - args[i]->NormalizeWithScale(-1); + } + } + if (base < 0) { + res = res - make_const(dtype, -base); + } + return res; + } +}; + +TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode); + +void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { + // NOTE: it is rare to have a balanced long expression, + // linear scan is fine for our case. + for (size_t i = 0; i < other->args.size(); ++i) { + this->AddToSelf(other->args[i], scale); + } + this->AddToSelf(other->base * scale); +} + +// Sub-class RewriteSimplifier::Impl to take benefit of +// rewriter for condition simplification etc. +class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { + public: + using Rewriter = RewriteSimplifier::Impl; + + explicit Impl(Analyzer* parent) + : Rewriter(parent) {} + + + Expr CanonicalSimplify(Expr expr) { + expr = Mutate(expr); + return expr; + } + + // override the original mutate function. + Expr Mutate(Expr expr) final { + expr = IRMutator::Mutate(expr); + return Normalize(expr); + } + + // Normal mutation without normalization. + Expr CanonicalMutate(Expr expr) { + return IRMutator::Mutate(expr); + } + + using Rewriter::Mutate_; + Expr Mutate_(const Add* op, const Expr& self) final; + Expr Mutate_(const Sub* op, const Expr& self) final; + Expr Mutate_(const Mul* op, const Expr& self) final; + Expr Mutate_(const Div* op, const Expr& self) final; + Expr Mutate_(const Mod* op, const Expr& self) final; + Expr Mutate_(const Reduce* op, const Expr& self) final; + + private: + /*! + * \brief compute lhs / cval + * \param lhs The left operand. + * \param cval The constant value. + * \return The result expression; + */ + SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval); + /*! + * \brief compute lhs % cval + * \param lhs The left operand. + * \param cval The constant value. + * \return The result expression; + */ + SplitExpr SplitModConst(SplitExpr lhs, int64_t cval); + /*! + * \brief Detect if psum = q * coeff + r such that (q >= 0 && r >= 0) + * \param psum The sum expression. + * \param coeff The co-efficient. + * \param out_divisible The result divisible component. + * \param out_non_divisible The non-divisible component. + * \return Whether detection is successful. + */ + bool TryLinearEquation(const SumExprNode* psum, + int64_t coeff, + SumExpr* out_divisible, + SumExpr* out_non_divisible); + /*! + * \brief Normalize expr to normal expr. + * \param expr The input expression. + * \return Normalized expr. + */ + Expr Normalize(Expr expr) { + if (const auto* op = expr.as_derived()) { + return op->Normalize(); + } else { + return expr; + } + } + /*! + * \brief Create a SplitExpr from expr. + * \param expr The input expr. + * \return The transformed SplitExpr. + */ + SplitExpr ToSplitExpr(Expr expr) { + if (const auto* op = expr.as()) { + return GetRef(op); + } + if (const auto* op = expr.as_derived()) { + expr = op->Normalize(); + } + NodePtr n = make_node(); + n->type = expr.type(); + n->index = std::move(expr); + return SplitExpr(n); + } + /*! + * \brief Create a SumExpr from expr. + * \param expr The input expr. + * \return The transformed SumExpr. + */ + SumExpr ToSumExpr(Expr expr) { + if (const auto* op = expr.as()) { + return GetRef(op); + } + NodePtr n = make_node(); + n->type = expr.type(); + if (const auto* op = expr.as()) { + n->base = op->value; + return SumExpr(n); + } else { + n->args.emplace_back(ToSplitExpr(expr)); + return SumExpr(n); + } + } + // Simplify the combiner used in reduce. + Expr SimplifyReduceCombiner(const Reduce* op); +}; + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Add* op, const Expr& self) { + if (!IsIndexType(op->type)) { + return Rewriter::Mutate_(op, self); + } + // normalize + Expr a = this->CanonicalMutate(op->a); + Expr b = this->CanonicalMutate(op->b); + + // const folding + Expr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // canonical form simplification. + SumExpr ret = ToSumExpr(std::move(a)); + + if (const auto* op = b.as()) { + ret.CopyOnWrite()->AddToSelf(op->value); + } else if (const auto* op = b.as()) { + ret.CopyOnWrite()->AddToSelf(GetRef(op), 1); + } else { + ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1); + } + return ret; +} + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Sub* op, const Expr& self) { + if (!IsIndexType(op->type)) { + return Rewriter::Mutate_(op, self); + } + // normalize + Expr a = this->CanonicalMutate(op->a); + Expr b = this->CanonicalMutate(op->b); + + // const folding + Expr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // canonical form simplification. + SumExpr ret = ToSumExpr(std::move(a)); + + if (const auto* op = b.as()) { + ret.CopyOnWrite()->AddToSelf(-op->value); + } else if (const auto* op = b.as()) { + ret.CopyOnWrite()->AddToSelf(GetRef(op), -1); + } else { + ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1); + } + return ret; +} + + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Mul* op, const Expr& self) { + if (!IsIndexType(op->type)) { + return Rewriter::Mutate_(op, self); + } + // normalize + Expr a = this->CanonicalMutate(op->a); + Expr b = this->CanonicalMutate(op->b); + + // const folding + Expr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // x * c + if (a.as()) { + std::swap(a, b); + } + if (const auto* bconst = b.as()) { + if (a.as()) { + SumExpr ret(std::move(a.node_)); + ret.CopyOnWrite()->MulToSelf(bconst->value); + return ret; + } else { + SplitExpr ret = ToSplitExpr(std::move(a)); + ret.CopyOnWrite()->MulToSelf(bconst->value); + return ret; + } + } + + // normal path. + a = Normalize(a); + b = Normalize(b); + if (op->a.same_as(a) && op->b.same_as(b)) { + return self; + } else { + return Mul::make(a, b); + } +} + + +bool CanonicalSimplifier::Impl:: +TryLinearEquation(const SumExprNode* psum, + int64_t coeff, + SumExpr* out_divisible, + SumExpr* out_non_divisible) { + auto divisible = make_node(); + auto non_divisible = make_node(); + divisible->type = psum->type; + non_divisible->type = psum->type; + + if (psum->base % coeff == 0) { + divisible->base = psum->base; + } else { + non_divisible->base = psum->base; + } + for (const auto& e : psum->args) { + if (e->scale % coeff == 0) { + divisible->args.push_back(e); + } else { + non_divisible->args.push_back(e); + } + } + *out_divisible = SumExpr(divisible); + *out_non_divisible = SumExpr(non_divisible); + + if (non_divisible->base == 0 && non_divisible->args.size() == 0) { + return true; + } + if (parent_->CanProveGreaterEqual(divisible->Normalize(), 0) && + parent_->CanProveGreaterEqual(non_divisible->Normalize(), 0)) { + return true; + } else { + return false; + } +} + +SplitExpr CanonicalSimplifier::Impl:: +SplitDivConst(SplitExpr lhs, int64_t cval) { + if (lhs->scale % cval == 0) { + lhs.CopyOnWrite()->scale /= cval; + return lhs; + } + + if (cval % lhs->scale == 0) { + int64_t scaled_cval = cval / lhs->scale; + if (lhs->upper_factor == SplitExprNode::kPosInf || + lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) { + // directly fold division. + lhs.CopyOnWrite()->scale = 1; + lhs.CopyOnWrite()->lower_factor *= scaled_cval; + lhs->Verify(); + return lhs; + } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { + // (x % c1) / c2 => 0 when c2 >= c1 + return ToSplitExpr(make_zero(lhs.type())); + } else { + // move the upper_factor modular into index. + lhs.CopyOnWrite()->index = + lhs->index % make_const(lhs.type(), lhs->upper_factor); + lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; + lhs.CopyOnWrite()->scale = 1; + lhs.CopyOnWrite()->lower_factor *= scaled_cval; + lhs->Verify(); + return lhs; + } + } + // directly return the split with cval == 1 + lhs = ToSplitExpr(Normalize(lhs)); + CHECK_EQ(lhs->scale, 1); + lhs.CopyOnWrite()->lower_factor *= cval; + return lhs; +} + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Div* op, const Expr& self) { + if (!IsIndexType(op->type)) { + return Rewriter::Mutate_(op, self); + } + Expr a = this->CanonicalMutate(op->a); + Expr b = this->CanonicalMutate(op->b); + + // const folding + Expr const_res = TryConstFold
(a, b); + if (const_res.defined()) return const_res; + PVar c1; + // x / c1 + if (c1.Match(b) && c1.Eval()->value > 0) { + int64_t cval = c1.Eval()->value; + if (cval == 1) return a; + + if (const auto* psum = a.as()) { + SumExpr lhs, extra; + if (TryLinearEquation(psum, cval, &lhs, &extra)) { + lhs.CopyOnWrite()->DivideBy(cval); + Expr temp = Normalize(extra); + if (const auto* pconst = temp.as()) { + lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); + } else { + // if extra <= cval, it means the extra can be eliminated. + if (TryCompare(temp, cval) != kLT) { + lhs.CopyOnWrite()->AddToSelf( + SplitDivConst(ToSplitExpr(temp), cval), 1); + } + } + return lhs; + } + } else { + // if a >= 0 && a < cval, then result == 0 + auto cbound = parent_->const_int_bound(Normalize(a)); + if (cbound->min_value >= 0 && cbound->max_value < cval) { + return make_zero(a.type()); + } + } + return SplitDivConst(ToSplitExpr(std::move(a)), cval); + } + // normal path + a = Normalize(a); + b = Normalize(b); + if (op->a.same_as(a) && op->b.same_as(b)) { + return self; + } else { + return Div::make(a, b); + } +} + +SplitExpr CanonicalSimplifier::Impl:: +SplitModConst(SplitExpr lhs, int64_t cval) { + if (lhs->scale % cval == 0) { + lhs.CopyOnWrite()->scale = 0; + return lhs; + } + if (cval % lhs->scale == 0) { + // (x * c1) % (c2 * c1) => (x % c2) * c1 + int64_t scaled_cval = cval / lhs->scale; + // (x / c1) % c2 => (x % (c1 * c2)) / c2 + int64_t new_upper_factor = lhs->lower_factor * scaled_cval; + // try to see if we can reduce the existing upper modular. + if (lhs->upper_factor == SplitExprNode::kPosInf || + lhs->upper_factor % new_upper_factor == 0) { + lhs.CopyOnWrite()->upper_factor = new_upper_factor; + lhs->Verify(); + return lhs; + } else if (new_upper_factor % lhs->upper_factor == 0) { + // (x % 2) % 4 => x % 2 + return lhs; + } + } + // Normalize the value. + lhs = ToSplitExpr(Normalize(lhs)); + CHECK_EQ(lhs->scale, 1); + CHECK_EQ(lhs->lower_factor, 1); + lhs.CopyOnWrite()->upper_factor = cval; + return lhs; +} + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Mod* op, const Expr& self) { + if (!IsIndexType(op->type)) { + return Rewriter::Mutate_(op, self); + } + // normalize + Expr a = this->CanonicalMutate(op->a); + Expr b = this->CanonicalMutate(op->b); + + // const folding + Expr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + PVar c1; + // x % c1 + if (c1.Match(b) && c1.Eval()->value > 0) { + int64_t cval = c1.Eval()->value; + if (const auto* psum = a.as()) { + SumExpr lhs, extra; + if (TryLinearEquation(psum, cval, &lhs, &extra)) { + Expr temp = Normalize(extra); + if (temp.as()) { + return temp % c1.Eval(); + } else { + // If temp < cval && temp >=0 then can remove the mod. + if (TryCompare(temp, cval) == kLT) { + return temp; + } else { + return SplitModConst(ToSplitExpr(temp), cval); + } + } + } + } else { + // if a >= 0 && a < cval, then result == 0 + auto cbound = parent_->const_int_bound(Normalize(a)); + if (cbound->min_value >= 0 && cbound->max_value < cval) { + return a; + } + } + return SplitModConst(ToSplitExpr(std::move(a)), cval); + } + // normal path + a = Normalize(a); + b = Normalize(b); + if (op->a.same_as(a) && op->b.same_as(b)) { + return self; + } else { + return Mod::make(a, b); + } +} + +// Simplify reduce expression. +Expr CanonicalSimplifier::Impl:: +SimplifyReduceCombiner(const Reduce* op) { + // First simplify the results + Array simplified_result; + for (const auto& res : op->combiner->result) { + Expr new_res = Mutate(res); + simplified_result.push_back(new_res); + } + + // Which components to keep + std::vector used(op->combiner->result.size(), false); + + // This function recursively marks the used components starting from + // the index idx + std::function mark_used; + mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { + // if the idx-th component was marked as used before, do nothing + if (used[idx]) return; + used[idx] = true; + + // check if the idx-th result expr uses some lhs or rhs variables + // and recursively mark the corresponding components + for (size_t i = 0; i < simplified_result.size(); ++i) + if (!used[i]) { + if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || + ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + mark_used(i); + } + }; + + // mark all used components starting from the value_index + mark_used(op->value_index); + + // components which have side effects should also be preserved + for (size_t i = 0; i < used.size(); ++i) { + if (HasSideEffect(op->source[i]) || + HasSideEffect(op->combiner->identity_element[i]) || + HasSideEffect(op->combiner->result[i])) { + mark_used(i); + } + } + + int new_value_index = op->value_index; + Array new_result; + Array new_identity; + Array new_lhs; + Array new_rhs; + Array new_source; + + // new stuff is old stuff which is used + for (size_t i = 0; i < used.size(); ++i) { + if (used[i]) { + // We simplify the result and identity, but not the source + new_result.push_back(simplified_result[i]); + new_identity.push_back(Mutate(op->combiner->identity_element[i])); + new_lhs.push_back(op->combiner->lhs[i]); + new_rhs.push_back(op->combiner->rhs[i]); + new_source.push_back(op->source[i]); + } else if (static_cast(i) < op->value_index) { + // value_index should also be adjusted + new_value_index--; + } + } + + CommReducer new_combiner = + CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + return Reduce::make( + new_combiner, new_source, op->axis, op->condition, new_value_index); +} + +Expr CanonicalSimplifier::Impl:: +Mutate_(const Reduce* op, const Expr& self) { + // Setup the domain information before simplification. + for (const IterVar& iv : op->axis) { + parent_->Bind(iv->var, iv->dom); + } + // Recursively call simplification when necessary. + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + // already been simplified by const reduction axis removal + if (op == nullptr) return ret; + if (op->axis.empty()) { + // Note that here we assume that the identity element is indeed identity. Without this + // assumption we would have to perform a single iteration of the loop, i.e. use + // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` + // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. + return Mutate( + Select::make(op->condition, + op->source[op->value_index], + op->combiner->identity_element[op->value_index])); + } + // combiner simplification. + ret = SimplifyReduceCombiner(op); + return ret; +} + +Expr CanonicalSimplifier::operator()(const Expr& expr) { + return impl_->CanonicalSimplify(expr); +} + +void CanonicalSimplifier::Update(const Var& var, + const Expr& info, + bool override) { + impl_->Update(var, info, override); +} + + +CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +CanonicalSimplifier::~CanonicalSimplifier() { + delete impl_; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 4c247c8a7b59..e5d9760abb9e 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -7,6 +7,8 @@ #define TVM_ARITHMETIC_CONST_FOLD_H_ #include +#include +#include #include namespace tvm { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index c83be8933b55..07846a4145d3 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -37,6 +37,10 @@ struct ConstIntBoundAnalyzer::Entry { bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } + + bool operator==(const Entry& other) const { + return min_value == other.min_value && max_value == other.max_value; + } }; class ConstIntBoundAnalyzer::Impl : @@ -55,7 +59,11 @@ class ConstIntBoundAnalyzer::Impl : const Entry& info, bool override) { if (!override) { - CHECK(!var_map_.count(var)); + auto it = var_map_.find(var); + if (it != var_map_.end()) { + CHECK(it->second == info) + << "var \'" << var << "\' already updated."; + } } var_map_[var] = info; } diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index f031e094d84a..134911821fc0 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -7,8 +7,10 @@ #include #include #include +#include #include "const_fold.h" #include "pattern_match.h" +#include "rewrite_simplify.h" namespace tvm { namespace arith { @@ -39,134 +41,55 @@ using namespace ir; return RecursiveRewrite((ResExpr).Eval()); \ } - // NOTE for developers: // // We mainly focus on index expression simplification. // Besides the RewriteSimplifier, some cases can be better // handled by CanonicalSimplifier. // -class RewriteSimplifier::Impl : public IRMutator { - public: - explicit Impl(Analyzer* parent) - : parent_(parent) {} - - void Update(const Var& var, - const Expr& info, - bool override) { - if (!override) { - CHECK(!var_map_.count(var)); + +// try to prove x equals val +RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: +TryCompare(const Expr& x, int64_t val) { + Expr diff = Mutate(x); + if (const auto* ptr = diff.as()) { + if (ptr->value == val) { + return kEQ; + } else if (ptr->value > val) { + return kGT; + } else if (ptr->value < val) { + return kLT; } - var_map_[var] = info; } - - // Run simplification in post order - Expr PostOrderSimplify(Expr expr, int max_iter = 2) { - for (int i = 0; i < max_iter; ++i) { - Expr new_expr = this->Mutate(expr); - if (new_expr.same_as(expr)) return expr; - expr = new_expr; + if (val == 0) { + ModularSet dmod = parent_->modular_set(diff); + if (dmod->base != 0) { + return kNE; } - return expr; } - - Expr Mutate_(const Add* op, const Expr& self) final; - Expr Mutate_(const Sub* op, const Expr& self) final; - Expr Mutate_(const Mul* op, const Expr& self) final; - Expr Mutate_(const Div* op, const Expr& self) final; - Expr Mutate_(const Mod* op, const Expr& self) final; - Expr Mutate_(const Min* op, const Expr& self) final; - Expr Mutate_(const Max* op, const Expr& self) final; - Expr Mutate_(const EQ* op, const Expr& self) final; - Expr Mutate_(const NE* op, const Expr& self) final; - Expr Mutate_(const LT* op, const Expr& self) final; - Expr Mutate_(const LE* op, const Expr& self) final; - Expr Mutate_(const GT* op, const Expr& self) final; - Expr Mutate_(const GE* op, const Expr& self) final; - Expr Mutate_(const And* op, const Expr& self) final; - Expr Mutate_(const Or* op, const Expr& self) final; - Expr Mutate_(const Not* op, const Expr& self) final; - Expr Mutate_(const Select* op, const Expr& self) final; - Expr Mutate_(const Ramp* op, const Expr& self) final; - - private: - /*! \brief internal structure for comparison. */ - enum CompareResult { - kUnknown, - kEQ, - kGT, - kLT, - kGE, - kLE, - kNE - }; - // reference to the main analyzer - Analyzer* parent_; - // counter to record recursive rewrite depth. - int recur_depth_{0}; - // internal variable map - std::unordered_map var_map_; - // maximum number of recursion allowed during a single pass. - static const constexpr int kMaxRecurDepth = 5; - // Whether x >= val - bool CanProveGreaterEqual(const Expr& x, int64_t val) { - return parent_->CanProveGreaterEqual(x, val); + ConstIntBound dbound = parent_->const_int_bound(diff); + if (dbound->min_value > val) { + return kGT; } - // Whether x == val - bool CanProveEqual(const Expr& x, int64_t val) { - // TODO(tqchen) refer back to super-analyzer. - return TryCompare(x, val) == kEQ; + if (dbound->max_value < val) { + return kLT; } - // try to prove x equals val - CompareResult TryCompare(const Expr& x, int64_t val) { - Expr diff = Mutate(x); - if (const auto* ptr = diff.as()) { - if (ptr->value == val) { - return kEQ; - } else if (ptr->value > val) { - return kGT; - } else if (ptr->value < val) { - return kLT; - } - } - if (val == 0) { - ModularSet dmod = parent_->modular_set(diff); - if (dmod->base != 0) { - return kNE; - } - } - ConstIntBound dbound = parent_->const_int_bound(diff); - if (dbound->min_value > val) { - return kGT; - } - if (dbound->max_value < val) { - return kLT; - } - if (dbound->min_value >= val) { - return kGE; - } - if (dbound->max_value <= val) { - return kLE; - } - return kUnknown; + if (dbound->min_value >= val) { + return kGE; } - - // Recursive rewrite x - // we limit maximum depth of recursive rewrite allowed to - // avoid infinite loop - Expr RecursiveRewrite(const Expr& x) { - if (recur_depth_ >= kMaxRecurDepth) return x; - ++recur_depth_; - Expr res = Mutate(x); - --recur_depth_; - return res; + if (dbound->max_value <= val) { + return kLE; } + return kUnknown; +} - template - PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { - return PConstWithTypeLike(pattern.derived(), 0); +void RewriteSimplifier::Impl:: +Update(const Var& var, const Expr& info, bool override) { + if (!override) { + CHECK(!var_map_.count(var)); } -}; + var_map_[var] = info; +} Expr RewriteSimplifier::Impl:: Mutate_(const Add* op, const Expr& self) { @@ -1253,16 +1176,6 @@ Mutate_(const Or* op, const Expr& self) { return ret; } -Expr RewriteSimplifier::Impl:: -Mutate_(const Ramp* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); - op = ret.as(); - if (is_zero(op->stride)) { - return Broadcast::make(op->base, op->lanes); - } - return ret; -} - Expr RewriteSimplifier::Impl:: Mutate_(const Select* op, const Expr& self) { Expr ret = IRMutator::Mutate_(op, self); @@ -1275,13 +1188,30 @@ Mutate_(const Select* op, const Expr& self) { } // Pattern var to match any expression PVar x, y; - TVM_TRY_REWRITE(select(x, y, y), y); return ret; } +Expr RewriteSimplifier::Impl:: +Mutate_(const Call* op, const Expr& self) { + Expr ret = IRMutator::Mutate_(op, self); + op = ret.as(); + if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { + return op->args[0]; + } + return ret; +} + Expr RewriteSimplifier::operator()(const Expr& expr) { - return impl_->PostOrderSimplify(expr); + // Run simplification in post order + Expr res = expr; + int max_iter = 2; + for (int i = 0; i < max_iter; ++i) { + Expr new_expr = impl_->Mutate(res); + if (new_expr.same_as(res)) return res; + res = new_expr; + } + return res; } void RewriteSimplifier::Update(const Var& var, @@ -1290,7 +1220,6 @@ void RewriteSimplifier::Update(const Var& var, impl_->Update(var, info, override); } - RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) { } diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h new file mode 100644 index 000000000000..e3435fe9b197 --- /dev/null +++ b/src/arithmetic/rewrite_simplify.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file rewrite_simplify.h + * \brief Rewrite-rule based simplification. + */ +#ifndef TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_ +#define TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_ + +#include +#include +#include +#include +#include "const_fold.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +/*! + * \brief Rewrite-based simplifier. + * + * This class can be inheritated for other simplifiers. + */ +class RewriteSimplifier::Impl : public IRMutator { + public: + explicit Impl(Analyzer* parent) + : parent_(parent) {} + + void Update(const Var& var, const Expr& info, bool override); + Expr Mutate_(const Add* op, const Expr& self) override; + Expr Mutate_(const Sub* op, const Expr& self) override; + Expr Mutate_(const Mul* op, const Expr& self) override; + Expr Mutate_(const Div* op, const Expr& self) override; + Expr Mutate_(const Mod* op, const Expr& self) override; + Expr Mutate_(const Min* op, const Expr& self) override; + Expr Mutate_(const Max* op, const Expr& self) override; + Expr Mutate_(const EQ* op, const Expr& self) override; + Expr Mutate_(const NE* op, const Expr& self) override; + Expr Mutate_(const LT* op, const Expr& self) override; + Expr Mutate_(const LE* op, const Expr& self) override; + Expr Mutate_(const GT* op, const Expr& self) override; + Expr Mutate_(const GE* op, const Expr& self) override; + Expr Mutate_(const And* op, const Expr& self) override; + Expr Mutate_(const Or* op, const Expr& self) override; + Expr Mutate_(const Not* op, const Expr& self) override; + Expr Mutate_(const Select* op, const Expr& self) override; + Expr Mutate_(const Call* op, const Expr& self) override; + + protected: + /*! \brief internal structure for comparison. */ + enum CompareResult { + kUnknown, + kEQ, + kGT, + kGE, + kLT, + kLE, + kNE + }; + // reference to the main analyzer + Analyzer* parent_; + // counter to record recursive rewrite depth. + int recur_depth_{0}; + // internal variable map + std::unordered_map var_map_; + // maximum number of recursion allowed during a single pass. + static const constexpr int kMaxRecurDepth = 5; + + /*! + * \brief try to compare x against val. + * \param x The expression to be evaluated. + * \param val The constant value. + * \return comparison result. + */ + CompareResult TryCompare(const Expr& x, int64_t val); + + private: + // Whether x >= val + bool CanProveGreaterEqual(const Expr& x, int64_t val) { + return parent_->CanProveGreaterEqual(x, val); + } + // Whether x == val + bool CanProveEqual(const Expr& x, int64_t val) { + // TODO(tqchen) refer back to super-analyzer. + return TryCompare(x, val) == kEQ; + } + + // Recursive rewrite x + // we limit maximum depth of recursive rewrite allowed to + // avoid infinite loop + Expr RecursiveRewrite(const Expr& x) { + if (recur_depth_ >= kMaxRecurDepth) return x; + ++recur_depth_; + Expr res = Mutate(x); + --recur_depth_; + return res; + } + + template + PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { + return PConstWithTypeLike(pattern.derived(), 0); + } +}; + + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_ diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc new file mode 100644 index 000000000000..17d95bd9263f --- /dev/null +++ b/src/arithmetic/stmt_simplify.cc @@ -0,0 +1,167 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file stmt_simplify.cc + * \brief Statement simplifier based on analyzer + */ +#include +#include +#include +#include +#include +#include +#include "arithmetic/Simplify.h" + +namespace tvm { +namespace arith { +// statement simplifier +using namespace ir; + +class StmtSimplifier : public IRMutator { + public: + Stmt Mutate_(const For* op, const Stmt& s) final { + Var loop_var(op->loop_var.node_); + analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); + return IRMutator::Mutate_(op, s); + } + + // IfThenElse + Stmt Mutate_(const IfThenElse* op, const Stmt& s) { + Expr condition = this->Mutate(op->condition); + Stmt then_case, else_case; + { + ConstraintContext ctx(&analyzer_, condition); + then_case = this->Mutate(op->then_case); + } + if (op->else_case.defined()) { + ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition))); + else_case = this->Mutate(op->else_case); + } + if (is_one(condition)) return then_case; + if (is_zero(condition)) { + if (else_case.defined()) { + return else_case; + } + return Evaluate::make(0); + } + + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return s; + } else { + return IfThenElse::make(condition, then_case, else_case); + } + } + + // AttrStmt + Stmt Mutate_(const AttrStmt* op, const Stmt& s) { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + IterVar iv(op->node.node_); + CHECK_NE(iv->thread_tag.length(), 0U); + if (!var_dom_.count(iv->var.get())) { + Range dom = Range::make_by_min_extent(0, op->value); + var_dom_[iv->var.get()] = dom; + analyzer_.Bind(iv->var, dom); + } + Stmt stmt = IRMutator::Mutate_(op, s); + return stmt; + } else { + return IRMutator::Mutate_(op, s); + } + } + + // AssertStmt + Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { + Expr condition = this->Mutate(op->condition); + Expr message = this->Mutate(op->message); + ConstraintContext ctx(&analyzer_, condition); + Stmt body = this->Mutate(op->body); + + if (condition.same_as(op->condition) && + message.same_as(op->message) && + body.same_as(op->body)) { + return s; + } else { + return AssertStmt::make(condition, message, body); + } + } + + protected: + Analyzer analyzer_; + // variable domain + std::unordered_map var_dom_; +}; + + +class CanonicalStmtSimplifier : public StmtSimplifier { + public: + using StmtSimplifier::Mutate; + Expr Mutate(Expr expr) final { + return analyzer_.canonical_simplify(expr); + } + + Stmt CanonicalSimplify(Stmt stmt, Map vrange) { + for (auto kv : vrange) { + analyzer_.Bind(kv.first, kv.second); + } + return Mutate(stmt); + } +}; + +} // namespace arith + +namespace ir { + +Stmt CanonicalSimplify(Stmt stmt, Map vrange) { + return arith::CanonicalStmtSimplifier().CanonicalSimplify( + stmt, vrange); +} + +Expr CanonicalSimplify(Expr expr, Map vrange) { + arith::Analyzer analyzer; + for (auto kv : vrange) { + analyzer.Bind(kv.first, kv.second); + } + return analyzer.canonical_simplify(expr); +} + +template +T Simplify_(T a, Map vrange) { + using namespace HalideIR::Internal; + Scope rscope; + for (auto kv : vrange) { + Range r = kv.second; + rscope.push( + kv.first.get(), + Interval(r->min, + simplify(r->min + r->extent - make_const(r->min.type(), 1)))); + } + return HalideIR::Internal::simplify(a, true, rscope); +} + + +Expr Simplify(Expr a, Map vrange) { + // Simplify top level reduce. + if (const Reduce* r = a.as()) { + Array new_source; + for (auto& e : r->source) { + new_source.push_back(Simplify_(e, vrange)); + } + Expr new_condition = Simplify_(r->condition, vrange); + if (r->source.same_as(new_source) && + r->condition.same_as(new_condition)) { + return a; + } else { + return Reduce::make( + r->combiner, new_source, r->axis, new_condition, r->value_index); + } + } + return Simplify_(a, vrange); +} + +Stmt Simplify(Stmt a, Map vrange) { + return Simplify_(a, vrange); +} +} // namespace ir +} // namespace tvm diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 2f700ed9112d..05f942f5428b 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -78,7 +78,7 @@ class ThreadAllreduceBuilder final : public IRMutator { Expr Mutate_(const Load* op, const Expr& e) final { auto it = load_remap_.find(op->buffer_var.get()); if (it != load_remap_.end()) { - CHECK(is_zero(op->index)); + CHECK(is_zero(op->index)) << e; return it->second; } else { return IRMutator::Mutate_(op, e); diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py new file mode 100644 index 000000000000..3631789ce307 --- /dev/null +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -0,0 +1,150 @@ +import tvm + +class CanonicalChecker: + def __init__(self): + self.analyzer = tvm.arith.Analyzer() + + def verify(self, data, expected): + res = self.analyzer.canonical_simplify(data) + assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected) + + +def test_mul_sum_simplify(): + ck = CanonicalChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + ck.verify(2 + (3 * x + z + y + 1) * 4 + x, + x * 13 + z * 4 + y * 4 +6) + ck.verify((x + y + x + y * 3) / 2, y * 2 + x) + ck.verify((x + y + x + y * 3) % 2, 0) + ck.verify(x * 3 - 4 * x + 1, 1 - x) + ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2) + + +def test_split_index_simplify(): + ck = CanonicalChecker() + x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + ck.verify((x/3) *3 + x % 3, x) + ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x) + + # split div const + ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4) + ck.verify((x % 2) / 8, 0) + ck.verify((x % 2) / 7, 0) + ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6) + + # split mod const + ck.verify((x * 8) % 16, (x % 2) * 8) + ck.verify((x * 8) % 2, 0) + + # simplify then fold + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000)) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000)) + ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y) + # complex fold + ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y) + + + +def test_div_simplify(): + ck = CanonicalChecker() + x = tvm.var("x") + ck.verify((16+48*x)/16, x*3 + 1) + # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0 + # (17+48*x)/16 != 1+3*x + ck.verify((17+48*x)/16, (x * 48 + 17) / 16) + # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10)) + ck.verify((17+48*x)/16, x * 3 + 1) + # Trying expressions that are not simplifiable for any values of the variables + ck.verify((17+47*x)/16, (x * 47 + 17) / 16) + + +def test_canonical_mixed(): + ck = CanonicalChecker() + x = tvm.var("x") + z = tvm.const(3, "int32") + ck.verify(x / (z*z) - x / (z*z), 0) + ck.verify(x / (z+z) - x / (z+z), 0) + + +def test_reduce_combiner_simplify(): + ck = CanonicalChecker() + dummy = tvm.var('dummy') + comm_reducer = tvm.comm_reducer + prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) + + sum_or_prod = comm_reducer( + lambda x, y: tvm.expr.Select(dummy < 0, + x + y, x*y), + lambda t0: tvm.expr.Select(dummy < 0, + tvm.const(0, t0), tvm.const(1, t0))) + sum_and_prod = comm_reducer( + lambda x, y: (x[0] + y[0], + x[1]*y[1]), + lambda t0, t1: (tvm.const(0, t0), + tvm.const(5, t0) - tvm.const(4, t0))) + some_reducer1 = comm_reducer( + lambda x, y: (x[0] + y[0], + x[0] + y[0] + x[1] + y[1], + x[0]*y[2] + y[0]*x[2], + x[1] + y[2], + 4.0), + lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), + tvm.const(1, t1), + tvm.const(2, t2), + tvm.const(3, t3), + tvm.const(4, t4))) + + k = tvm.reduce_axis((0, 10), name="k") + A = tvm.placeholder((10,), name='A') + # Test that SimplifyCombiner makes use of vranges + ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4)) + ck.verify(sum_or_prod(A[k], k), tvm.sum(A[k], k)) + ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True) + ck.verify(sum_or_prod(A[k], k), prod(A[k], k)) + ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True) + ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], tvm.sum(A[k], k)) + ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k)) + + reference_simplified_sources = [[A[0]], + [A[0], A[1]], + [A[0], A[2]], + [A[0], A[1], A[2], A[3]], + [A[4]]] + for j in range(5): + # Here we use the j-th component of the result, so only it and the components it + # depends on are left. + simplified = ck.analyzer.canonical_simplify( + some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]) + + # Check that the remaining components are the expected ones. + for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]): + assert tvm.ir_pass.Equal(lhs, rhs) + + # Test that components with side effects are not removed + side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0) + ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], + sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) + ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], + tvm.sum(side_effect(A[k]), k)) + + +def test_reduce_simplify(): + ck = CanonicalChecker() + k = tvm.reduce_axis((0, 10), name="k") + j = tvm.reduce_axis((-5, 3), name="j") + A = tvm.placeholder((10,), name='A') + ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]), + tvm.sum(k + j, [k, j])) + ck.verify(tvm.sum(A[3], []), A[3]) + # The rule below is not typical, removed for now + ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k)) + + +if __name__ == "__main__": + test_div_simplify() + test_reduce_simplify() + test_reduce_combiner_simplify() + test_mul_sum_simplify() + test_split_index_simplify() + test_canonical_mixed() diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index 71818708fbf6..a327650fd045 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -21,7 +21,6 @@ def test_simplify(): assert zz.a == x and zz.b.value == 4 n = tvm.var('n') - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0, "int32")) assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32")) assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n) tvm.ir_pass.CanonicalSimplify(n / (-1)) @@ -29,36 +28,16 @@ def test_simplify(): # assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)), # tvm.ir_pass.CanonicalSimplify(-n)) -def test_simplify_div(): - x = tvm.var('x') - assert tvm.ir_pass.CanonicalSimplify((16+48*x)/16 - (1 + (x*3))).value == 0 - # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0 - # (17+48*x)/16 != 1+3*x - r = tvm.ir_pass.CanonicalSimplify((17+48*x)/16) - assert r.b.value == 16 - assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 48*x)).value == 0 - # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified - assert tvm.ir_pass.CanonicalSimplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0 - - # Trying expressions that are not simplifiable for any values of the variables - r = tvm.ir_pass.CanonicalSimplify((17+47*x)/16, {x: tvm.Range(0,10)}) - assert r.b.value == 16 - assert tvm.ir_pass.CanonicalSimplify(r.a - (17+47*x)).value == 0 - - r = tvm.ir_pass.CanonicalSimplify((8*x - 17)/8, {x : tvm.Range(4,10)}) - assert tvm.ir_pass.CanonicalSimplify(r - (x-3)).value == 0 - def test_simplify_mod(): - """Not yet working, mock design""" ib = tvm.ir_builder.create() n = tvm.var('n') - j = tvm.var('j') A = ib.pointer("float32", name="A") - with ib.for_range(0, 16, name="i") as i: - A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16] + with ib.for_range(0, 10, name="j") as j: + with ib.for_range(0, 16, name="i") as i: + A[i] = A[(j * 32 + i+1) % 16] body = ib.get() stmt = tvm.ir_pass.CanonicalSimplify(body) - diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) + diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) assert diff.value == 0 # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 index = tvm.ir_pass.CanonicalSimplify( @@ -95,8 +74,8 @@ def test_modular(): y: tvm.Range(i32_const(0), i32_const(2)), x: tvm.Range(i32_const(0), i32_const(14))} idx = ry * 16 + rx + y * 16 + x - z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap) z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap) + z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap) assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 @@ -117,10 +96,9 @@ def test_const_propagation(): if __name__ == "__main__": - test_simplify_div() - test_simplify_mod() test_modular() test_simplify() test_mul() test_simplify_minmax() test_const_propagation() + test_simplify_mod() diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index 939a08f5b8c2..e59528e875a0 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -35,109 +35,8 @@ def test_bound(): ret = tvm.ir_pass.Simplify(m % 10, vrange) assert ret == m -def test_canonical(): - x = tvm.var("x") - z = tvm.const(3, "int32") - ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z)) - assert(tvm.ir_pass.Equal(ret, 0)) - - ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z)) - assert(tvm.ir_pass.Equal(ret, 0)) - - #make sure terms are ordered based on their top operators (e.g., / always precedes %) - ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3) - ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3) - assert(tvm.ir_pass.Equal(ret1, ret2)) - - #when top operators match, compare string representation of terms - ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3) - ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4) - assert (tvm.ir_pass.Equal(ret1, ret2)) - - -def test_simplify_combiner(): - dummy = tvm.var('dummy') - - prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) - - sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0, - x + y, x*y), - lambda t0: tvm.expr.Select(dummy < 0, - tvm.const(0, t0), tvm.const(1, t0))) - - sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0], - x[1]*y[1]), - lambda t0, t1: (tvm.const(0, t0), - tvm.const(5, t0) - tvm.const(4, t0))) - - sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0], - x[1]*y[1] + 0*x[0] + y[0] - y[0]), - lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0), - tvm.const(1, t1))) - - some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0], - x[0] + y[0] + x[1] + y[1], - x[0]*y[2] + y[0]*x[2], - x[1] + y[2], - 4.0), - lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), - tvm.const(1, t1), - tvm.const(2, t2), - tvm.const(3, t3), - tvm.const(4, t4))) - - k = tvm.reduce_axis((0, 10), name="k") - A = tvm.placeholder((10,), name='A') - - # Test that SimplifyCombiner makes use of vranges - vrange = {dummy: tvm.Range(-10, -5)} - assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k)) - vrange = {dummy: tvm.Range(5, 10)} - assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k)) - - assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k)) - assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k)) - - assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k)) - assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k)) - - reference_simplified_sources = [[A[0]], - [A[0], A[1]], - [A[0], A[2]], - [A[0], A[1], A[2], A[3]], - [A[4]]] - for j in range(5): - # Here we use the j-th component of the result, so only it and the components it - # depends on are left. - simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]) - - # Check that the remaining components are the expected ones. - for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]): - assert Equal(lhs, rhs) - - # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0) - assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]), - sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) - assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]), - tvm.sum(side_effect(A[k]), k)) - - -def test_simplify_reduce(): - k = tvm.reduce_axis((0, 10), name="k") - j = tvm.reduce_axis((-5, 3), name="j") - A = tvm.placeholder((10,), name='A') - - assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k)) - assert Equal(Simplify(tvm.sum(A[3], [])), A[3]) - assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])), - tvm.sum(k + j, [k, j])) - if __name__ == "__main__": test_bound() test_basic() test_simplify() - test_canonical() - test_simplify_combiner() - test_simplify_reduce()