diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index f8412dc3666b..b11486d9023a 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -62,6 +62,13 @@ bool HasSideEffect(const Expr& e);
  */
 Stmt ConvertSSA(Stmt stmt);
 
+/*!
+ * \brief Simplify by applying canonical form.
+ * \param stmt The statement to be canonically simplified.
+ * \return Canonicalized statement.
+ */
+Stmt CanonicalSimplify(Stmt stmt);
+
 /*!
  * \brief Substitute the var specified in key->var to be value.
  * \param stmt The source statement to be substituted
diff --git a/python/tvm/build.py b/python/tvm/build.py
index fbed0a33f849..bb03e8395687 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -17,7 +17,8 @@ def build(sch,
           target,
           name="default_function",
           binds=None,
-          record_codes=None):
+          record_codes=None,
+          max_auto_unroll_step=8):
     """Build a function with arguments as signiture.
 
     Parameters
@@ -38,6 +39,9 @@ def build(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.
 
+    max_auto_unroll_step: int
+        Maximum step to perform automatic unrolling
+
     Returns
     -------
     f : Function, or pair of functions
@@ -64,6 +68,8 @@ def build(sch,
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
+    stmt = ir_pass.CanonicalSimplify(stmt)
+    stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
     fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
     fsplits = ir_pass.SplitHostDevice(fapi)
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index df79996e4a6f..ff67ac7a867b 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
 
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
+REGISTER_PASS1(CanonicalSimplify);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(StorageFlatten);
 REGISTER_PASS2(UnrollLoop);
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
new file mode 100644
index 000000000000..2c99094551c5
--- /dev/null
+++ b/src/arithmetic/canonical.cc
@@ -0,0 +1,486 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file canonical.cc
+ * \brief Canonical-form simplification.
+ */
+#include  // NOTE(review): the header name is missing here, presumably lost in extraction; confirm against the upstream file
+#include "./int_set.h"
+#include "./canonical.h"
+#include "./compute_expr.h"
+
+namespace tvm {
+namespace arith {
+using namespace ir;
+
+// Canonical entry for commutative ops: one term (value * scale) of a canonical sum.
+struct ComExprEntry {
+  // the value of the expression.
+  Expr value;
+  // the level of the expression.
+  int level{0};
+  // The integer scale on value
+  int64_t scale{1};
+
+  ComExprEntry() {}
+  ComExprEntry(Expr value, int level)
+      : value(value), level(level) {}
+  inline bool operator<(const ComExprEntry& other) const {  // orders by level first, then by the address of the value node; scale is not compared
+    if (level < other.level) return true;
+    if (level > other.level) return false;
+    return value.get() < other.value.get();
+  }
+};
+
+// canonical expression for commutative expression: base + sum(elem[i].value * elem[i].scale).
+struct ComExprNode {
+  // base constant value.
+  int64_t base{0};
+  // The values to be summed.
+  std::vector elem;  // NOTE(review): element type is missing here, presumably lost in extraction; the uses below treat entries as ComExprEntry
+};
+
+// canonical commutative expression: a shared handle to a ComExprNode.
+struct ComExpr {
+ public:
+  // constructor
+  ComExpr() {}
+  explicit ComExpr(std::shared_ptr ptr) : ptr_(ptr) {}
+  // get member
+  ComExprNode* operator->() const {
+    return ptr_.get();
+  }
+  void reset() {
+    ptr_.reset();
+  }
+  bool defined() const {
+    return ptr_.get() != nullptr;
+  }
+  // comparator: lexicographic over base, element count, then per element (level, value node address, scale); both operands are dereferenced without a defined() check.
+  bool operator<(const ComExpr& b) const {
+    const ComExpr& a = *this;
+    if (a->base < b->base) return true;
+    if (a->base > b->base) return false;
+    if (a->elem.size() < b->elem.size()) return true;
+    if (a->elem.size() > b->elem.size()) return false;
+    for (size_t i = 0; i < a->elem.size(); ++i) {
+      const ComExprEntry& ea = a->elem[i];
+      const ComExprEntry& eb = b->elem[i];
+      if (ea.level < eb.level) return true;
+      if (ea.level > eb.level) return false;
+      if (ea.value.get() < eb.value.get()) return true;
+      if (ea.value.get() > eb.value.get()) return false;
+      if (ea.scale < eb.scale) return true;
+      if (ea.scale > eb.scale) return false;
+    }
+    return false;
+  }
+  // equality: same base and element-wise identical (level, value node address, scale); values are compared by node identity, not structurally.
+  bool operator==(const ComExpr& b) const {
+    const ComExpr& a = *this;
+    if (a->base != b->base) return false;
+    if (a->elem.size() != b->elem.size()) return false;
+    for (size_t i = 0; i < a->elem.size(); ++i) {
+      const ComExprEntry& ea = a->elem[i];
+      const ComExprEntry& eb = b->elem[i];
+      if (ea.level != eb.level) return false;
+      if (ea.value.get() != eb.value.get()) return false;
+      if (ea.scale != eb.scale) return false;
+    }
+    return true;
+  }
+
+ private:
+  std::shared_ptr ptr_;
+};
+
+template  // NOTE(review): the template parameter list (T) is missing here, presumably lost in extraction
+inline Expr Binary_(const T* op,
+                    const Expr& e,
+                    Expr a, Expr b) {
+  if (a.same_as(op->a) && b.same_as(op->b)) {
+    return e;  // operands unchanged: reuse the original node
+  } else {
+    return T::make(a, b);  // rebuild the binary node with the new operands
+  }
+}
+
+template  // NOTE(review): the template parameter list (T) is missing here, presumably lost in extraction
+inline Expr Binary(
+    const T* op, const Expr& e, IRMutator* m) {
+  return Binary_(op, e, m->Mutate(op->a), m->Mutate(op->b));  // mutate both operands with m, then rebuild only if one changed
+}
+
+// internal of canonical engine.
+class Canonical::Internal : public IRMutator {
+ public:
+  // stack entry: per-expression max variable level and side-effect flag.
+  struct StackEntry {
+    int max_level{0};
+    bool has_side_effect{false};
+  };
+  // aggressively canonicalized expression
+  struct CacheEntry {
+    // The canonical value of the expression.
+    Expr value;
+    // The level of the expression.
+    int max_level{0};
+    // whether the expression might have side effect.
+    bool has_side_effect{false};
+    // if defined, the canonical sum form that corresponds to value
+    ComExpr sum;
+    // reset the return entry (only clears sum; value, max_level and has_side_effect are left as they are).
+    void reset() {
+      sum.reset();
+    }
+    // as sum expr: returns sum when defined; otherwise a constant becomes the base, and any other value becomes a single element at max_level.
+    ComExpr AsSum() const {
+      if (sum.defined()) return sum;
+      const int64_t *v1 = as_const_int(value);
+      const uint64_t *v2 = as_const_uint(value);
+      std::shared_ptr n = std::make_shared();
+      if (v1) {
+        n->base = *v1;
+      } else if (v2) {
+        CHECK_LE(*v2,
+                 static_cast(std::numeric_limits::max()));
+        n->base = static_cast(*v2);
+      } else {
+        n->elem.push_back(ComExprEntry(value, max_level));
+      }
+      return ComExpr(n);
+    }
+  };
+  // Set range and level of var; the var is also appended to var_rec_.
+  void SetRange(Var v, Range r, int level) {
+    var_range_[v.get()] = IntSet::range(r);
+    var_level_[v.get()] = level;
+    var_rec_.push_back(v);
+  }
+  // functions
+  Stmt Mutate(Stmt stmt) final {
+    return IRMutator::Mutate(stmt);
+  }
+  Expr MutateExpr_(Expr expr) {  // mutates expr with a fresh stack entry, then publishes that entry's level and side-effect flag in ret_entry_
+    static const FMutateExpr& f = Internal::vtable_expr();
+    stack_.push_back(StackEntry());
+    expr = (f.can_dispatch(expr) ?
+            f(expr, expr, this) : IRMutator::Mutate(expr));  // use the canonical handler when one is registered, else the default mutator
+    // update result of parent automatically during pop
+    if (stack_.size() > 1) {
+      StackEntry& back = stack_[stack_.size() - 1];
+      StackEntry& prev = stack_[stack_.size() - 2];
+      prev.max_level = std::max(prev.max_level, back.max_level);
+      if (back.has_side_effect) prev.has_side_effect = true;
+    }
+    // copy result from stack
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    ret_entry_.max_level = stack_.back().max_level;
+    stack_.pop_back();
+    return expr;
+  }
+  // call produce to get a cache entry.
+  CacheEntry Produce(Expr expr) {
+    // Start from a clean sum so that only a handler which really produced
+    // a canonical sum for expr leaves one behind in ret_entry_.
+    ret_entry_.reset();
+    ret_entry_.value = MutateExpr_(expr);
+    CacheEntry ret = ret_entry_;
+    ret_entry_.reset();
+    return ret;
+  }
+  // Public expression entry point: mutate and drop the sum side channel.
+  Expr Mutate(Expr expr) final {
+    ret_entry_.reset();
+    expr = MutateExpr_(expr);
+    ret_entry_.reset();
+    return expr;
+  }
+
+  // Check whether to do special canonicalization:
+  // only scalar (single lane) signed/unsigned integer types qualify.
+  bool EnableOpt(Type t) const {
+    return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
+  }
+  // Add: merge both operands as canonical sums unless either side
+  // may have side effects.
+  Expr Mutate_(const Add* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    return SumAdd(a, b, +1);
+  }
+  // Sub: same as Add, with the right operand scaled by -1.
+  Expr Mutate_(const Sub* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    return SumAdd(a, b, -1);
+  }
+  // Mul: fold const * const, distribute a constant factor over the
+  // other operand's sum, and leave any other product untouched.
+  Expr Mutate_(const Mul* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    if (is_const(a.value) && is_const(b.value)) {
+      return ComputeExpr<Mul>(a.value, b.value);
+    } else if (is_const(a.value)) {
+      return SumMulConst(b.AsSum(), a.value);
+    } else if (is_const(b.value)) {
+      return SumMulConst(a.AsSum(), b.value);
+    } else {
+      return Binary_(op, e, a.value, b.value);
+    }
+  }
+  // Variable: record the level registered for this variable, if any,
+  // on the current stack entry.
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = var_level_.find(op);
+    if (it != var_level_.end()) {
+      stack_.back().max_level = it->second;
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  // comparison: a < b folds to true when (b - a) can be proved positive
+  // over the known variable ranges.
+  Expr Mutate_(const LT* op, const Expr& e) {
+    if (!EnableOpt(op->a.type())) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    Expr b_sub_a = SumAdd(b, a, -1);
+    // SumAdd leaves the (b - a) sum in ret_entry_. That sum describes the
+    // helper expression, not the comparison returned below, so clear it
+    // before an enclosing Produce can pick it up as this node's sum.
+    ret_entry_.reset();
+    if (EvalSet(b_sub_a, var_range_).can_prove_positive()) {
+      return make_const(op->type, true);
+    } else {
+      return Binary_(op, e, a.value, b.value);
+    }
+  }
+  // Call: an impure call marks the current stack entry as side-effecting.
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (!op->is_pure()) {
+      stack_.back().has_side_effect = true;
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  // For: register the loop variable's range at one level deeper than the
+  // enclosing scope while the loop is being mutated.
+  Stmt Mutate_(const For* op, const Stmt& s) {
+    ++level_counter_;
+    Var loop_var(op->loop_var.node_);
+    this->SetRange(loop_var,
+                   Range::make_with_min_extent(op->min, op->extent),
+                   level_counter_);
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    --level_counter_;
+    return stmt;
+  }
+  // AttrStmt: a "thread_extent" attribute introduces a thread variable
+  // ranging over [0, value); it is registered only the first time it is seen.
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
+    if (op->type_key == "thread_extent") {
+      ++level_counter_;
+      IterVar iv(op->node.node_);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      if (!var_level_.count(iv->var.get())) {
+        this->SetRange(iv->var,
+                       Range::make_with_min_extent(0, op->value),
+                       level_counter_);
+      }
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      --level_counter_;
+      return stmt;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  // Dispatch table consulted by MutateExpr_ for canonical handlers.
+  static FMutateExpr& vtable_expr() {  // NOLINT(*)
+    static FMutateExpr inst; return inst;
+  }
+
+ private:
+  // return entry
+  CacheEntry ret_entry_;
+  // internal information stack
+  std::vector<StackEntry> stack_;
+  // cache sum
+  std::map<ComExpr, CacheEntry> cache_sum_;
+  // range of each var
+  std::unordered_map<const Variable*, IntSet> var_range_;
+  // level of each var
+  std::unordered_map<const Variable*, int> var_level_;
+  // record history vars, to avoid false positive.
+ std::vector var_rec_; + // level counter + int level_counter_{0}; + // subroutine to do produce + Expr SumMulConst(ComExpr a, Expr v) { + int64_t value = 0; + const int64_t *v1 = as_const_int(v); + const uint64_t *v2 = as_const_uint(v); + CHECK(v1 || v2); + if (v1) { + value = *v1; + } else if (v2) { + CHECK_LE(*v2, + static_cast(std::numeric_limits::max())); + value = static_cast(*v2); + } + + if (value == 0) { + return make_zero(v.type()); + } + std::shared_ptr vsum = + std::make_shared(*a.operator->()); + vsum->base *= value; + for (auto& e : vsum->elem) { + e.scale *= value; + } + ret_entry_.max_level = stack_.back().max_level; + ret_entry_.has_side_effect = stack_.back().has_side_effect; + ret_entry_.sum = ComExpr(vsum); + auto it = cache_sum_.find(ret_entry_.sum); + if (it != cache_sum_.end()) { + ret_entry_ = it->second; + } else { + ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); + cache_sum_[ret_entry_.sum] = ret_entry_; + } + return ret_entry_.value; + } + // add two ComExpr together + ComExpr SumAdd_(const ComExpr& suma, + const ComExpr& sumb, + int bscale) { + std::shared_ptr n = std::make_shared(); + n->base = suma->base + sumb->base; + // merge of suma and sumb; + size_t i = 0, j = 0; + while (i < suma->elem.size() && j < sumb->elem.size()) { + const auto& a = suma->elem[i]; + const auto& b = sumb->elem[j]; + if (a.value.same_as(b.value)) { + CHECK_EQ(a.level, b.level); + ComExprEntry e = a; + e.scale = a.scale + b.scale * bscale; + if (e.scale != 0) { + n->elem.push_back(e); + } + ++i; ++j; + } else if (a < b) { + n->elem.push_back(a); + ++i; + } else { + ComExprEntry e = b; + e.scale *= bscale; + n->elem.push_back(e); + ++j; + } + } + for (; i < suma->elem.size(); ++i) { + n->elem.push_back(suma->elem[i]); + } + for (; j < sumb->elem.size(); ++j) { + ComExprEntry e = sumb->elem[j]; + e.scale *= bscale; + n->elem.push_back(e); + } + return ComExpr(n); + } + // subroutine to do produce + Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) { + 
ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale); + ret_entry_.max_level = stack_.back().max_level; + ret_entry_.has_side_effect = stack_.back().has_side_effect; + auto it = cache_sum_.find(ret_entry_.sum); + if (it != cache_sum_.end()) { + ret_entry_ = it->second; + } else { + ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type()); + cache_sum_[ret_entry_.sum] = ret_entry_; + } + ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type()); + cache_sum_[ret_entry_.sum] = ret_entry_; + return ret_entry_.value; + } + // convert sum to expr + Expr Sum2Expr(const ComExpr& com, Type t) { + Expr vsum; + if (com->base != 0) { + vsum = make_const(t, com->base); + } + for (const ComExprEntry& e : com->elem) { + if (e.scale > 0) { + Expr v = e.value; + if (e.scale != 1) { + v = Mul::make(v, make_const(t, e.scale)); + } + if (vsum.defined()) { + vsum = Add::make(vsum, v); + } else { + vsum = v; + } + } + } + for (const ComExprEntry& e : com->elem) { + if (e.scale < 0) { + Expr v = e.value; + if (e.scale != -1) { + v = Mul::make(v, make_const(t, -e.scale)); + } + if (vsum.defined()) { + vsum = Sub::make(vsum, v); + } else { + vsum = Sub::make(make_zero(t), v); + } + } + } + return vsum; + } +}; + +using CInternal = Canonical::Internal; + +#define DISPATCH_EXPR(OP) \ + set_dispatch([](const OP *op, const Expr& e, IRMutator* p) { \ + return static_cast(p)->Mutate_(op, e); }) + +TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr) +.DISPATCH_EXPR(Add) +.DISPATCH_EXPR(Sub) +.DISPATCH_EXPR(Mul) +.DISPATCH_EXPR(LT); + + +Canonical::Canonical() + : ptr_(std::make_shared()) {} + +Expr Canonical::Simplify(Expr expr) { + return ptr_->Mutate(expr); +} + +Stmt Canonical::Simplify(Stmt stmt) { + return ptr_->Mutate(stmt); +} + +void Canonical::SetRange(Var v, Range r, int level) { + ptr_->SetRange(v, r, level); +} +} // namespace arith + +namespace ir { +Stmt CanonicalSimplify(Stmt stmt) { + return arith::Canonical().Simplify(stmt); +} + +} // namespace ir +} // namespace tvm diff 
--git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h
new file mode 100644
index 000000000000..174acc20aebe
--- /dev/null
+++ b/src/arithmetic/canonical.h
@@ -0,0 +1,55 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file canonical.h
+ * \brief Internal canonicalized expression simplification engine.
+ */
+#ifndef TVM_ARITHMETIC_CANONICAL_H_
+#define TVM_ARITHMETIC_CANONICAL_H_
+
+#include  // NOTE(review): header name missing, presumably lost in extraction; confirm against the upstream file
+#include  // NOTE(review): header name missing, presumably lost in extraction; confirm against the upstream file
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief A stateful CanonicalEngine over SSA.
+ *
+ *  Simplify and CSE with canonicalization expressions.
+ *  Each call's result will get cached, so next call will
+ *  simply return the cached result.
+ */
+class Canonical {
+ public:
+  /*! \brief constructor */
+  Canonical();
+  /*!
+   * \brief simplify an expression.
+   * \param expr The expression to be simplified.
+   */
+  Expr Simplify(Expr expr);
+  /*!
+   * \brief simplify stmt.
+   * \param expr The stmt to be simplified.
+   */
+  Stmt Simplify(Stmt expr);
+  /*!
+   * \brief Set the range and level of a variable
+   * \param v The variable
+   * \param r The range of the variable, can be undefined.
+   * \param level The scope level of the variable,
+   *  affect the order of formula in commutative ops.
+   */
+  void SetRange(Var v, Range r, int level);
+
+  class Internal;
+ private:
+  // Internal pointer
+  std::shared_ptr ptr_;
+};
+
+
+}  // namespace arith
+}  // namespace tvm
+
+#endif  // TVM_ARITHMETIC_CANONICAL_H_
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 04b40191de11..d60504f2c51e 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -94,6 +94,11 @@ bool IntSet::is_single_point() const {
   return (s_int && s_int->i.is_single_point());
 }
 
+bool IntSet::can_prove_positive() const {  // true only for an interval set whose simplified lower bound is a positive constant
+  const IntervalSet* s_int = (*this).as();
+  return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
+}
+
 Expr IntSet::point_value() const {
   const IntervalSet* s_int = (*this).as();
   CHECK(s_int && s_int->i.is_single_point());
@@ -358,6 +363,9 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
 // Evaluator to evalute the epxression.
 class IntSetEvaluator {
  public:
+  explicit IntSetEvaluator(const std::unordered_map& dom_map)
+      : dom_map(dom_map) {}
+
   inline IntSet Eval(Expr expr) {
     static const FType& f = vtable();
     if (f.can_dispatch(expr)) {
@@ -373,7 +381,7 @@ class IntSetEvaluator {
     static FType inst; return inst;
   }
 
-  std::unordered_map dom_map;
+  const std::unordered_map& dom_map;  // held by reference, not copied: the caller's map must outlive this evaluator
 };
 
 inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
@@ -424,21 +432,29 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
 .set_dispatch(Binary)
 .set_dispatch(Binary);
 
+
+IntSet EvalSet(Expr e,
+               const std::unordered_map& dom_map) {
+  return IntSetEvaluator(dom_map).Eval(e);  // the temporary evaluator only borrows dom_map for this call
+}
+
 IntSet EvalSet(Expr e,
                const Map& dom_map) {
-  IntSetEvaluator m;
+  std::unordered_map dmap;
   for (auto kv : dom_map) {
-    m.dom_map[kv.first->var.as()] = kv.second;
+    dmap[kv.first->var.as()] = kv.second;
   }
+  IntSetEvaluator m(dmap);
   return m.Eval(e);
 }
 
 IntSet EvalSet(Range r,
                const Map& dom_map) {
-  IntSetEvaluator m;
+  std::unordered_map dmap;
   for (auto kv : dom_map) {
-    m.dom_map[kv.first->var.as()] = kv.second;
+    dmap[kv.first->var.as()] = kv.second;
   }
+  IntSetEvaluator m(dmap);
   IntSet min_set = m.Eval(r->min);
   IntSet ext_set = m.Eval(r->extent).cover_interval();
   const Interval& ei = ext_set.as()->i;
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
index 80c2fae79146..979d138af9e2 100644
--- a/src/arithmetic/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -44,6 +44,8 @@ class IntSet : public NodeRef {
   bool is_everything() const;
   /*! \return Whether the set is a single point */
   bool is_single_point() const;
+  /*! \return Whether the set is proved to be bigger than 0 */
+  bool can_prove_positive() const;
   /*!
    * \brief The single point value, call only if is_single_point is true
    * \return The point value.
@@ -88,6 +90,8 @@ struct IntSetNode : public Node {
  */
 IntSet EvalSet(Expr e,
                const Map& dom_map);
+IntSet EvalSet(Expr e,  // overload of the EvalSet above that takes the domain as a std::unordered_map
+               const std::unordered_map& dom_map);
 
 /*!
  * \brief Find an symbolic integer set that contains is union over
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index b4957a3d543e..9098200bdf27 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -45,7 +45,7 @@ MakeNVRTC(Array funcs) {
   std::ostringstream os;
   os << "typedef int int32_t;\n"
      << "typedef unsigned unt32_t;\n";
-  bool output_ssa = true;
+  bool output_ssa = false;
   for (LoweredFunc f : funcs) {
     os << CodeGenCUDA().Compile(f, output_ssa);
     os << '\n';
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 3d54a66a8251..bafb56deb656 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -57,7 +57,7 @@ MakeOpenCL(Array funcs) {
   std::ostringstream os;
   os << "typedef int int32_t;\n"
      << "typedef unsigned unt32_t;\n";
-  bool output_ssa = true;
+  bool output_ssa = false;
   for (LoweredFunc f : funcs) {
     os << CodeGenOpenCL().Compile(f, output_ssa);
     os << '\n';
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index ac5c5c2c4b66..8b63d8c08e4c 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -3,9 +3,9 @@
 
 def test_gemm():
     # graph
-    nn = 1235
+    nn = 1024
     n = tvm.Var('n')
-    #n = tvm.convert(nn)
+    n = tvm.convert(nn)
     m = n
     l = n
     A = tvm.placeholder((n, l), name='A')
@@ -52,12 +52,14 @@ def test_gemm():
     _, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
     _, xi = s[BB].split(xi, outer=thread_x)
 
+    max_auto_unroll_step = 0
     # lowering test
     s.normalize()
 
     def check_device(target):
         codes = []
-        f = tvm.build(s, [A, B, C], target, record_codes=codes)
+        f = tvm.build(s, [A, B, C], target, record_codes=codes,
+                      max_auto_unroll_step=max_auto_unroll_step)
         for c in codes[1:]:
             print(c)
         if target == "cuda":
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
new file mode 100644
index 000000000000..9002b9686675
--- /dev/null
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -0,0 +1,26 @@
+import tvm
+import numpy
+
+def test_simplify():
+    """Not yet working, mock design"""
+    dtype = 'int64'
+    n = tvm.Var('n')
+    Ab = tvm.Buffer((n, ), dtype)
+    i = tvm.Var('i')
+    j = tvm.Var('j')
+    # for i in 0 to n-1:
+    stmt = tvm.make.For(
+        i, 2, n, 0, 0,
+        tvm.make.For(j, 0, n, 0, 0,
+                     tvm.make.IfThenElse(
+                         tvm.make.LT(i + 2, n),
+                         tvm.make.Store(Ab.data,
+                                        tvm.make.Load(dtype, Ab.data, i + 4) + 1,
+                                        (j + 1) * 4 - 4 * j + i),
+                         None)))
+    print(stmt)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    print(stmt)
+
+if __name__ == "__main__":
+    test_simplify()