From 4a317972147d6973c7367865ba56a903780742ab Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 27 Feb 2017 03:22:05 +0000 Subject: [PATCH 01/11] loop_partition draft --- include/tvm/ir_mutator.h | 48 +++--- include/tvm/ir_pass.h | 2 + src/api/api_arith.cc | 3 +- src/api/api_pass.cc | 1 + src/arithmetic/bound_deducer.cc | 33 ++-- src/arithmetic/int_set.h | 5 +- src/pass/loop_partition.cc | 151 ++++++++++++++++++ .../unittest/test_pass_loop_partition.py | 21 +++ 8 files changed, 224 insertions(+), 40 deletions(-) create mode 100644 src/pass/loop_partition.cc create mode 100644 tests/python/unittest/test_pass_loop_partition.py diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index c428232698e8..12dc809aaa38 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -66,37 +66,37 @@ class IRMutator { virtual Stmt Mutate_(const Let* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Call* op, const Stmt& s); - virtual Stmt Mutate_(const Add* op, const Stmt& e); - virtual Stmt Mutate_(const Sub* op, const Stmt& e); - virtual Stmt Mutate_(const Mul* op, const Stmt& e); - virtual Stmt Mutate_(const Div* op, const Stmt& e); - virtual Stmt Mutate_(const Mod* op, const Stmt& e); - virtual Stmt Mutate_(const Min* op, const Stmt& e); - virtual Stmt Mutate_(const Max* op, const Stmt& e); - virtual Stmt Mutate_(const EQ* op, const Stmt& e); - virtual Stmt Mutate_(const NE* op, const Stmt& e); - virtual Stmt Mutate_(const LT* op, const Stmt& e); - virtual Stmt Mutate_(const LE* op, const Stmt& e); - virtual Stmt Mutate_(const GT* op, const Stmt& e); - virtual Stmt Mutate_(const GE* op, const Stmt& e); - virtual Stmt Mutate_(const And* op, const Stmt& e); - virtual Stmt Mutate_(const Or* op, const Stmt& e); + virtual Stmt Mutate_(const Add* op, const Stmt& s); + virtual Stmt Mutate_(const Sub* op, const Stmt& s); + virtual Stmt Mutate_(const Mul* op, const Stmt& s); + virtual Stmt Mutate_(const Div* op, const Stmt& s); + virtual Stmt Mutate_(const Mod* op, const Stmt& s); + virtual Stmt Mutate_(const Min* op, const Stmt& s); + virtual Stmt Mutate_(const Max* op, const Stmt& s); + virtual Stmt Mutate_(const EQ* op, const Stmt& s); + virtual Stmt Mutate_(const NE* op, const Stmt& s); + virtual Stmt Mutate_(const LT* op, const Stmt& s); + virtual Stmt Mutate_(const LE* op, const Stmt& s); + virtual Stmt Mutate_(const GT* op, const Stmt& s); + virtual Stmt Mutate_(const GE* op, const Stmt& s); + virtual Stmt Mutate_(const And* op, const Stmt& s); + virtual Stmt Mutate_(const Or* op, const Stmt& s); virtual Stmt Mutate_(const Reduce* op, const Stmt& s); virtual Stmt Mutate_(const Cast* op, const Stmt& s); virtual Stmt Mutate_(const Not* op, const Stmt& s); virtual Stmt Mutate_(const Select* op, const Stmt& s); virtual Stmt Mutate_(const Ramp* op, const Stmt& s); - virtual Stmt Mutate_(const Broadcast* op, const Stmt& e); - virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); - virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); - virtual Stmt Mutate_(const Provide* op, const Stmt& e); + virtual Stmt Mutate_(const Broadcast* op, const Stmt& s); + virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s); + virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s); + virtual Stmt Mutate_(const Provide* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s); - virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); - virtual Stmt Mutate_(const IntImm* op, const Stmt& e); - virtual Stmt Mutate_(const UIntImm* op, const Stmt& e); - virtual Stmt Mutate_(const FloatImm* op, const Stmt& e); - virtual Stmt Mutate_(const StringImm* op, const Stmt& e); + virtual Stmt Mutate_(const Evaluate* op, const Stmt& s); + virtual Stmt Mutate_(const IntImm* op, const Stmt& s); + virtual Stmt Mutate_(const UIntImm* op, const Stmt& s); + virtual Stmt Mutate_(const FloatImm* op, const Stmt& s); + virtual Stmt Mutate_(const StringImm* op, const Stmt& s); virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const LetStmt* op, const Expr& e); diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 542ec34424cd..915de41ddd6a 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -137,6 +137,8 @@ Stmt InjectVirtualThread(Stmt stmt); */ Stmt LiftAllocate(Stmt stmt); +Stmt LoopPartition(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 7edbe3eec2a8..8fda741ef083 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -23,7 +23,8 @@ TVM_REGISTER_API(_arith_intset_interval) TVM_REGISTER_API(_arith_DeduceBound) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DeduceBound(args[0], args[1], args[2]); + *ret = DeduceBound(args[0], args[1], + args[2].operator Map()); }); TVM_REGISTER_API(_IntervalSetGetMin) diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1192dc25dd76..f995f13d1cdc 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI); REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(LiftAllocate); REGISTER_PASS1(InjectVirtualThread); +REGISTER_PASS1(LoopPartition); } // namespace ir } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b83215c4a36a..b456d2d61a8f 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -21,7 +21,7 @@ using Halide::Internal::Interval; // from a expression. class VariablePathFinder: public IRVisitor { public: - explicit VariablePathFinder(Var target) : target_(target) {} + explicit VariablePathFinder(Expr target) : target_(target) {} void Visit(const NodeRef& node) final { if (visited_.count(node.get()) != 0) return; @@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor { private: bool found_{false}; - Var target_; + Expr target_; std::unordered_set visited_; }; // get the path to the variable, // return empty vector to represent failure -std::vector GetPath(Var target, Expr expr) { +std::vector GetPath(Expr target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); return v.path_; @@ -56,7 +56,7 @@ class BoundDeducer: public IRVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; - BoundDeducer(Var target, Expr expr, + BoundDeducer(Expr target, Expr expr, const std::unordered_map& dom_map) : target_(target), expr_(expr), dom_map_(dom_map) {} @@ -137,7 +137,7 @@ class BoundDeducer: public IRVisitor { bool success{true}; private: - Var target_; + Expr target_; Expr expr_; const std::unordered_map& dom_map_; ExprIntSetMap expr_map_; @@ -205,15 +205,9 @@ void BoundDeducer::Deduce() { Visit(expr_); } -// assuming e >= 0, deduce the bound of variable from it. -// return empty set to represent deduce failure. -IntSet DeduceBound(Var v, Expr e, - const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first.get()] = kv.second; - } - BoundDeducer d(v, e, dmap); +IntSet DeduceBound(Expr v, Expr e, + const std::unordered_map dom_map) { + BoundDeducer d(v, e, dom_map); d.Deduce(); if (!d.success) return IntSet::nothing(); Expr min = Interval::neg_inf, max = Interval::pos_inf; @@ -225,5 +219,16 @@ IntSet DeduceBound(Var v, Expr e, return IntSet::interval(min, max); } +// assuming e >= 0, deduce the bound of variable from it. +// return empty set to represent deduce failure. +IntSet DeduceBound(Expr v, Expr e, + const Map& dom_map) { + std::unordered_map dmap; + for (auto kv : dom_map) { + dmap[kv.first.get()] = kv.second; + } + return DeduceBound(v, e, dmap); +} + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 2fc25f55be2d..eabd63425643 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -172,8 +172,11 @@ inline const IntSetNode* IntSet::operator->() const { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet DeduceBound(Var v, Expr cond, +IntSet DeduceBound(Expr v, Expr cond, const Map& dom_map); +IntSet DeduceBound(Expr v, Expr e, + const std::unordered_map dom_map); + } // namespace arith } // namespace tvm diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc new file mode 100644 index 000000000000..a6a08d8d3d1f --- /dev/null +++ b/src/pass/loop_partition.cc @@ -0,0 +1,151 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file loop_partition.cc + */ +#include +#include +#include +#include +#include "../arithmetic/int_set.h" + +namespace tvm { +namespace ir { + +using arith::IntSet; +using Halide::Internal::const_true; +using Halide::Internal::const_false; +using Halide::Internal::Interval; // for pos_inf & neg_inf + +// a partition means condition is equal to true_value in the interval +struct Partition { + Expr condition; + Expr old_expr; + Expr true_value; + IntSet interval; +}; + +bool expr_use_var(Expr expr, Expr target) { + bool success = false; + PostOrderVisit(expr, [&target, &success](const NodeRef& node) { + if (node.same_as(target)) { + success = true; + return; + } + }); + return success; +} + +class PartitionFinder : public IRVisitor { + public: + explicit PartitionFinder(VarExpr loop_var) + : loop_var_(loop_var) {} + + void Visit_(const For* op) { + dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->min + op->extent - 1); + IRVisitor::Visit_(op); + } + + void Visit_(const IfThenElse* op) { + if (expr_use_var(op->condition, loop_var_)) { + IntSet interval = DeduceBound(loop_var_, op->condition, dom_map_); + if (interval.min().same_as(Interval::neg_inf)) { + IntSet upper_bound = EvalSet(interval.max(), dom_map_); + interval = IntSet::interval(interval.min(), upper_bound.min()); + } else if (interval.max().same_as(Interval::pos_inf)) { + IntSet lower_bound = EvalSet(interval.min(), dom_map_); + interval = IntSet::interval(lower_bound.max(), interval.max()); + } else { + // Assume the partition is always a infinite set + LOG(WARNING) << "interval wrong?"; + } + partitions.push_back(Partition{op->condition, op->condition, const_true(), interval}); + } + IRVisitor::Visit_(op); + } + + std::vector partitions; + private: + VarExpr loop_var_; + std::unordered_map dom_map_; +}; + +class PartitionReplacer : public IRMutator { + public: + PartitionReplacer(const Partition& p) + : p_(p) {} + + Expr Mutate(Expr e) override { + if (e.same_as(p_.old_expr)) { + return Mutate(p_.true_value); + } + return IRMutator::Mutate(e); + } + + Stmt Mutate(Stmt s) override { // ? will raise error if no this function + return IRMutator::Mutate(s); + } + + private: + const Partition& p_; +}; + +IntSet intersect(IntSet a, IntSet b) { // need move into IntSet + // (TODO) temp solution + return IntSet::interval(b.min(), a.max()); +} + +IntSet complement(IntSet s, IntSet u) { // need move into IntSet + // (TODO) temp solution + return IntSet::interval(s.max() + 1, u.max()); +} + +class LoopPartitioner : public IRMutator { + public: + explicit LoopPartitioner() {} + Expr Mutate(Expr e) override { + return IRMutator::Mutate(e); + } + Stmt Mutate(Stmt s) override { + return IRMutator::Mutate(s); + } + + Stmt Mutate_(const For* op, const Stmt& stmt) { // Simplify for this for loop + // (TODO) recursive + + PartitionFinder finder(op->loop_var); + finder.Visit(op->body); + + if (finder.partitions.empty()) { + // no available partition, return directly + return stmt; + } + + IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1); + Stmt s; + // (TODO) in fact, we need to consider all partitions, then split + // the universe into multiple ranges + for (auto p : finder.partitions) { + IntSet true_itrv = intersect(p.interval, universe); + IntSet doubt_itrv = complement(true_itrv, universe); + + Stmt simplified_body = PartitionReplacer(p).Mutate(op->body); + Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), + true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); + Stmt remaining_stmt = For::make(op->loop_var, doubt_itrv.min(), + doubt_itrv.max() - doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + s = Block::make(simplified_stmt, remaining_stmt); + } + return s; + } + + private: + +}; + +Stmt LoopPartition(Stmt stmt) { + stmt = LoopPartitioner().Mutate(stmt); + return stmt; +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py new file mode 100644 index 000000000000..6f52bbeecafb --- /dev/null +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -0,0 +1,21 @@ +import tvm +n = tvm.Var('n') +A = tvm.placeholder((n, ), name='A') +B = tvm.placeholder((n, ), name='B') + +T = tvm.compute((n, ), lambda i: A[i]+B[i]) +s = tvm.Schedule(T.op) +xo, xi = s[T].split(T.op.axis[0], factor=4) + +bounds = tvm.schedule.InferBound(s) +stmt = tvm.schedule.ScheduleOps(s, bounds) +stmt = tvm.ir_pass.LoopPartition(stmt) +print(stmt) + +# for (i.outer, 0, n) { +# for (i.inner, 0, 4) { +# if (i.inner + (i.outer*4) < n) { +# compute(i.inner + (i.outer*4)) = (A(i.inner + i.outer*4) + B(i.inner + i.outer*4)) +# } +# } +# } From 5c95c6f7133986df88e7fdfcf30af82277e66e70 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 1 Mar 2017 00:38:02 +0000 Subject: [PATCH 02/11] divide loop variable into constant domain and variable domain & consider multiple partitions --- src/arithmetic/int_set.cc | 10 ++++ src/arithmetic/int_set.h | 4 ++ src/pass/loop_partition.cc | 113 +++++++++++++++++++++++++------------ 3 files changed, 92 insertions(+), 35 deletions(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 709da26a648f..b229bec4c0b6 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -179,6 +179,16 @@ IntSet Union(const Array& set) { return IntervalSet::make(x); } +IntSet Intersect(const std::vector& sets) { + // (TODO) temp solution + return IntSet::interval(sets[0].min(), sets[1].max()); +} + +IntSet Complement(IntSet s, IntSet u) { + // (TODO) temp solution + return IntSet::interval(s.max() + 1, u.max()); +} + // type traits template struct is_logical_op { diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index eabd63425643..faee56d0cc0d 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -157,6 +157,10 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, */ IntSet Union(const Array& sets); +IntSet Intersect(const std::vector& sets); + +IntSet Complement(const IntSet s, const IntSet u); + // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index a6a08d8d3d1f..42c445d63851 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "../arithmetic/int_set.h" namespace tvm { @@ -24,10 +25,10 @@ struct Partition { IntSet interval; }; -bool expr_use_var(Expr expr, Expr target) { +bool ExprUseVar(Expr expr, const Variable* var) { bool success = false; - PostOrderVisit(expr, [&target, &success](const NodeRef& node) { - if (node.same_as(target)) { + PostOrderVisit(expr, [&var, &success](const NodeRef& node) { + if (node.get() == var) { success = true; return; } @@ -35,18 +36,36 @@ bool expr_use_var(Expr expr, Expr target) { return success; } +inline bool IsConstDomain(Expr min, Expr extent) { + return is_const(min) && is_const(extent); +} + class PartitionFinder : public IRVisitor { public: - explicit PartitionFinder(VarExpr loop_var) - : loop_var_(loop_var) {} + explicit PartitionFinder(VarExpr loop_var, + const std::unordered_map& dom_map, + const std::unordered_set& variables) + : loop_var_(loop_var), dom_map_(dom_map), variables_(variables) {} void Visit_(const For* op) { - dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->min + op->extent - 1); - IRVisitor::Visit_(op); + if (IsConstDomain(op->min, op->extent)) { + dom_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + IRVisitor::Visit_(op); + dom_map_.erase(op->loop_var.get()); + } else { + variables_.insert(op->loop_var.get()); + IRVisitor::Visit_(op); + variables_.erase(op->loop_var.get()); + } } void Visit_(const IfThenElse* op) { - if (expr_use_var(op->condition, loop_var_)) { + if (ExprUseVar(op->condition, loop_var_.get())) { + for (auto var : variables_) { + if (ExprUseVar(op->condition, var)) IRVisitor::Visit_(op); + } + IntSet interval = DeduceBound(loop_var_, op->condition, dom_map_); if (interval.min().same_as(Interval::neg_inf)) { IntSet upper_bound = EvalSet(interval.max(), dom_map_); @@ -67,6 +86,7 @@ class PartitionFinder : public IRVisitor { private: VarExpr loop_var_; std::unordered_map dom_map_; + std::unordered_set variables_; }; class PartitionReplacer : public IRMutator { @@ -89,15 +109,22 @@ class PartitionReplacer : public IRMutator { const Partition& p_; }; -IntSet intersect(IntSet a, IntSet b) { // need move into IntSet - // (TODO) temp solution - return IntSet::interval(b.min(), a.max()); -} - -IntSet complement(IntSet s, IntSet u) { // need move into IntSet - // (TODO) temp solution - return IntSet::interval(s.max() + 1, u.max()); -} +// LoopPartitioner will try to partition the loop variable in the IR. +// The loop variable can be divided into two categories: +// +// - whose range is fixed, the min and the extent both are constant. +// +// For now, we will not do partition on this kind loop variable, we +// add them into dom_map in order to do deduce for follow-up +// partitions. +// +// - whose range is variable +// +// We will try to do partition on this kind loop variable. If success, +// we will mutate the stmt then return. (only consider the partition +// on the outmost loop yet). If failed, we will mark them as variable +// (add them into variables_), then in the follow-up procedure, we know +// a condition is not able to be deduced if it use this variable. class LoopPartitioner : public IRMutator { public: @@ -109,37 +136,53 @@ class LoopPartitioner : public IRMutator { return IRMutator::Mutate(s); } - Stmt Mutate_(const For* op, const Stmt& stmt) { // Simplify for this for loop - // (TODO) recursive + Stmt Mutate_(const For* op, const Stmt& stmt) { + if (IsConstDomain(op->min, op->extent)) { + // if the range of loop_var is constant, we will not partition it, + // instead, we will use the fixed domain to deduce. + dom_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + Stmt res = IRMutator::Mutate_(op, stmt); + dom_map_.erase(op->loop_var.get()); + return res; + } - PartitionFinder finder(op->loop_var); + PartitionFinder finder(op->loop_var, dom_map_, variables_); finder.Visit(op->body); if (finder.partitions.empty()) { - // no available partition, return directly + variables_.insert(op->loop_var.get()); + IRMutator::Mutate_(op, stmt); + variables_.erase(op->loop_var.get()); return stmt; } IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1); - Stmt s; - // (TODO) in fact, we need to consider all partitions, then split - // the universe into multiple ranges + std::vector sets{universe}; + // merge partitions (take their intersect) + for (auto p : finder.partitions) { + sets.push_back(p.interval); + } + + IntSet true_itrv = Intersect(sets); + IntSet doubt_itrv = Complement(true_itrv, universe); + + Stmt simplified_body = op->body; for (auto p : finder.partitions) { - IntSet true_itrv = intersect(p.interval, universe); - IntSet doubt_itrv = complement(true_itrv, universe); - - Stmt simplified_body = PartitionReplacer(p).Mutate(op->body); - Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), - true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); - Stmt remaining_stmt = For::make(op->loop_var, doubt_itrv.min(), - doubt_itrv.max() - doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); - s = Block::make(simplified_stmt, remaining_stmt); + p.interval = true_itrv; + simplified_body = PartitionReplacer(p).Mutate(simplified_body); } - return s; + + Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), + true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); + Stmt remaining_stmt = For::make(op->loop_var, doubt_itrv.min(), + doubt_itrv.max() - doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + return Block::make(simplified_stmt, remaining_stmt); } private: - + std::unordered_set variables_; + std::unordered_map dom_map_; }; Stmt LoopPartition(Stmt stmt) { From 335a8ad62b7cbddd107ed2a9afb61d677beb561b Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 1 Mar 2017 21:05:58 +0000 Subject: [PATCH 03/11] process doubt interval --- src/arithmetic/int_set.cc | 37 +++++++++++++------ src/arithmetic/int_set.h | 4 +- src/pass/loop_partition.cc | 23 ++++++++---- .../unittest/test_pass_loop_partition.py | 8 ---- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index b229bec4c0b6..f99b7291ed27 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -162,11 +162,19 @@ inline bool MatchPoint(const IntSet& a, return i.is_single_point() && i.min.same_as(b); } -IntSet Union(const Array& set) { - if (set.size() == 1) return set[0]; - Interval x = set[0].cover_interval().as()->i; - for (size_t i = 1; i < set.size(); ++i) { - IntSet s = set[i].cover_interval(); +IntSet Union(const Array& sets) { + std::vector v_sets; + for (auto s : sets) { + v_sets.push_back(s); + } + return Union(v_sets); +} + +IntSet Union(const std::vector& sets) { + if (sets.size() == 1) return sets[0]; + Interval x = sets[0].cover_interval().as()->i; + for (size_t i = 1; i < sets.size(); ++i) { + IntSet s = sets[i].cover_interval(); const Interval& y = s.as()->i; if (can_prove(x.max + 1 >= y.min)) { x.max = y.max; @@ -179,14 +187,21 @@ IntSet Union(const Array& set) { return IntervalSet::make(x); } -IntSet Intersect(const std::vector& sets) { - // (TODO) temp solution - return IntSet::interval(sets[0].min(), sets[1].max()); +IntSet Intersect(const Array& sets) { + std::vector v_sets; + for (auto s : sets) { + v_sets.push_back(s); + } + return Intersect(v_sets); } -IntSet Complement(IntSet s, IntSet u) { - // (TODO) temp solution - return IntSet::interval(s.max() + 1, u.max()); +IntSet Intersect(const std::vector& sets) { + Interval x = sets[0].cover_interval().as()->i; + for (size_t i = 1; i < sets.size(); ++i) { + Interval y = sets[i].cover_interval().as()->i; + x = Interval::make_intersection(x, y); + } + return IntervalSet::make(x); } // type traits diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index faee56d0cc0d..fb96ca7cf2e4 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -156,11 +156,11 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, * \return the set after union */ IntSet Union(const Array& sets); +IntSet Union(const std::vector& sets); +IntSet Intersect(const Array& sets); IntSet Intersect(const std::vector& sets); -IntSet Complement(const IntSet s, const IntSet u); - // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 42c445d63851..7b733222a262 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -75,7 +75,7 @@ class PartitionFinder : public IRVisitor { interval = IntSet::interval(lower_bound.max(), interval.max()); } else { // Assume the partition is always a infinite set - LOG(WARNING) << "interval wrong?"; + LOG(WARNING) << "interval wrong"; } partitions.push_back(Partition{op->condition, op->condition, const_true(), interval}); } @@ -163,9 +163,7 @@ class LoopPartitioner : public IRMutator { for (auto p : finder.partitions) { sets.push_back(p.interval); } - IntSet true_itrv = Intersect(sets); - IntSet doubt_itrv = Complement(true_itrv, universe); Stmt simplified_body = op->body; for (auto p : finder.partitions) { @@ -175,14 +173,25 @@ class LoopPartitioner : public IRMutator { Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); - Stmt remaining_stmt = For::make(op->loop_var, doubt_itrv.min(), - doubt_itrv.max() - doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); - return Block::make(simplified_stmt, remaining_stmt); + Stmt s = simplified_stmt; + + Expr pre_doubt_cond = (true_itrv.min() != universe.min()); + IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min()); + Stmt pre_stmt = For::make(op->loop_var, pre_doubt_itrv.min(), + pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); + + Expr post_doubt_cond = (true_itrv.max() != universe.max()); + IntSet post_doubt_itrv = IntSet::interval(true_itrv.max(), universe.max()); + Stmt post_stmt = For::make(op->loop_var, post_doubt_itrv.min(), + post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); + return s; } private: - std::unordered_set variables_; std::unordered_map dom_map_; + std::unordered_set variables_; }; Stmt LoopPartition(Stmt stmt) { diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 6f52bbeecafb..c3fb164e9349 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -11,11 +11,3 @@ stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt) print(stmt) - -# for (i.outer, 0, n) { -# for (i.inner, 0, 4) { -# if (i.inner + (i.outer*4) < n) { -# compute(i.inner + (i.outer*4)) = (A(i.inner + i.outer*4) + B(i.inner + i.outer*4)) -# } -# } -# } From 266931ce8080f64f4bd726be8afe042e358222f9 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 2 Mar 2017 22:27:28 +0000 Subject: [PATCH 04/11] fix and refactor, add relax_map arg in BoundDeduce --- include/tvm/expr.h | 2 + src/api/api_arith.cc | 3 +- src/arithmetic/bound_deducer.cc | 52 ++++++++---- src/arithmetic/int_set.h | 6 +- src/pass/loop_partition.cc | 144 ++++++++++++++------------------ 5 files changed, 107 insertions(+), 100 deletions(-) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index b7a6a458876f..510ac9d86835 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -35,6 +35,8 @@ using Halide::Internal::make_const; using Halide::Internal::make_zero; using Halide::Internal::as_const_int; using Halide::Internal::as_const_uint; +using Halide::Internal::const_true; +using Halide::Internal::const_false; inline Type TVMType2Type(TVMType t) { return Type(static_cast(t.code), t.bits, t.lanes); diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 99217cc56eb9..59ddb47536e8 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -30,7 +30,8 @@ TVM_REGISTER_API(_arith_EvalModular) TVM_REGISTER_API(_arith_DeduceBound) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = DeduceBound(args[0], args[1], - args[2].operator Map()); + args[2].operator Map(), + args[3].operator Map()); }); TVM_REGISTER_API(_IntervalSetGetMin) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b456d2d61a8f..f264a3294b9a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -57,10 +57,10 @@ class BoundDeducer: public IRVisitor { friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(Expr target, Expr expr, - const std::unordered_map& dom_map) - : target_(target), expr_(expr), dom_map_(dom_map) {} + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) + : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} - bool Init(); void Deduce(); void Visit(const NodeRef& e) final { @@ -137,9 +137,14 @@ class BoundDeducer: public IRVisitor { bool success{true}; private: + void Init(); + void Transform(); + void Relax(); + Expr target_; Expr expr_; - const std::unordered_map& dom_map_; + const std::unordered_map& hint_map_; + const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; @@ -163,10 +168,13 @@ class BoundDeduceInputChecker: public IRVisitor { size_t target_count{0}; }; -bool BoundDeducer::Init() { +void BoundDeducer::Init() { BoundDeduceInputChecker checker; if (!checker.Check(this)) success = false; + Transform(); +} +void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { is_greater = false; is_equal = false; @@ -190,24 +198,35 @@ bool BoundDeducer::Init() { } else { success = false; } - return success; } void BoundDeducer::Deduce() { Init(); if (!success) return; + Relax(); // get the path path_ = GetPath(target_, expr_); // get the sign of every subexpr - expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); + expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); } +void BoundDeducer::Relax() { + if (is_greater) { + expr_ = EvalSet(expr_ , relax_map_).min(); + result = EvalSet(result, relax_map_).max(); + } else { + expr_ = EvalSet(expr_ , relax_map_).max(); + result = EvalSet(result, relax_map_).min(); + } +} + IntSet DeduceBound(Expr v, Expr e, - const std::unordered_map dom_map) { - BoundDeducer d(v, e, dom_map); + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { + BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success) return IntSet::nothing(); Expr min = Interval::neg_inf, max = Interval::pos_inf; @@ -222,12 +241,17 @@ IntSet DeduceBound(Expr v, Expr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. IntSet DeduceBound(Expr v, Expr e, - const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first.get()] = kv.second; + const Map& hint_map, + const Map& relax_map) { + std::unordered_map hmap; + for (auto kv : hint_map) { + hmap[kv.first.get()] = kv.second; + } + std::unordered_map rmap; + for (auto kv : relax_map) { + rmap[kv.first.get()] = kv.second; } - return DeduceBound(v, e, dmap); + return DeduceBound(v, e, hmap, rmap); } } // namespace arith diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index fb96ca7cf2e4..88d2729dee49 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -177,9 +177,11 @@ inline const IntSetNode* IntSet::operator->() const { * \return An integer set that can cover all the possible values. */ IntSet DeduceBound(Expr v, Expr cond, - const Map& dom_map); + const Map& hint_map, + const Map& relax_map); IntSet DeduceBound(Expr v, Expr e, - const std::unordered_map dom_map); + const std::unordered_map& hint_map, + const std::unordered_map& relax_map); } // namespace arith diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 7b733222a262..555b63dcd0be 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -13,15 +13,10 @@ namespace tvm { namespace ir { using arith::IntSet; -using Halide::Internal::const_true; -using Halide::Internal::const_false; -using Halide::Internal::Interval; // for pos_inf & neg_inf -// a partition means condition is equal to true_value in the interval +// a partition means the expr is equal to true in the interval struct Partition { - Expr condition; - Expr old_expr; - Expr true_value; + Expr expr; IntSet interval; }; @@ -43,70 +38,60 @@ inline bool IsConstDomain(Expr min, Expr extent) { class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr loop_var, - const std::unordered_map& dom_map, - const std::unordered_set& variables) - : loop_var_(loop_var), dom_map_(dom_map), variables_(variables) {} + const std::unordered_map& vars) + : target_var_(loop_var), out_vars_(vars), hint_map_(vars), relax_map_() {} void Visit_(const For* op) { - if (IsConstDomain(op->min, op->extent)) { - dom_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); - IRVisitor::Visit_(op); - dom_map_.erase(op->loop_var.get()); - } else { - variables_.insert(op->loop_var.get()); - IRVisitor::Visit_(op); - variables_.erase(op->loop_var.get()); + for (auto kv : out_vars_) { + if (ExprUseVar(op->min, kv.first) || + ExprUseVar(op->extent, kv.first)) { + return; + } } + + hint_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + relax_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + IRVisitor::Visit_(op); + relax_map_.erase(op->loop_var.get()); + hint_map_.erase(op->loop_var.get()); } void Visit_(const IfThenElse* op) { - if (ExprUseVar(op->condition, loop_var_.get())) { - for (auto var : variables_) { - if (ExprUseVar(op->condition, var)) IRVisitor::Visit_(op); - } - - IntSet interval = DeduceBound(loop_var_, op->condition, dom_map_); - if (interval.min().same_as(Interval::neg_inf)) { - IntSet upper_bound = EvalSet(interval.max(), dom_map_); - interval = IntSet::interval(interval.min(), upper_bound.min()); - } else if (interval.max().same_as(Interval::pos_inf)) { - IntSet lower_bound = EvalSet(interval.min(), dom_map_); - interval = IntSet::interval(lower_bound.max(), interval.max()); - } else { - // Assume the partition is always a infinite set - LOG(WARNING) << "interval wrong"; - } - partitions.push_back(Partition{op->condition, op->condition, const_true(), interval}); + if (ExprUseVar(op->condition, target_var_.get())) { + IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); + partitions.push_back(Partition{op->condition, interval}); + } else { + IRVisitor::Visit_(op); } - IRVisitor::Visit_(op); } std::vector partitions; private: - VarExpr loop_var_; - std::unordered_map dom_map_; - std::unordered_set variables_; + VarExpr target_var_; + const std::unordered_map& out_vars_; + std::unordered_map hint_map_; + std::unordered_map relax_map_; }; class PartitionReplacer : public IRMutator { public: - PartitionReplacer(const Partition& p) - : p_(p) {} + PartitionReplacer(const std::vector& ps) + : ps_(ps) {} - Expr Mutate(Expr e) override { - if (e.same_as(p_.old_expr)) { - return Mutate(p_.true_value); + Expr Mutate(Expr e) final { + for (auto p : ps_) { + if (e.same_as(p.expr)) { + return Mutate(const_true()); + } } return IRMutator::Mutate(e); } - - Stmt Mutate(Stmt s) override { // ? will raise error if no this function - return IRMutator::Mutate(s); - } + using IRMutator::Mutate; private: - const Partition& p_; + const std::vector& ps_; }; // LoopPartitioner will try to partition the loop variable in the IR. @@ -129,32 +114,27 @@ class PartitionReplacer : public IRMutator { class LoopPartitioner : public IRMutator { public: explicit LoopPartitioner() {} - Expr Mutate(Expr e) override { - return IRMutator::Mutate(e); - } - Stmt Mutate(Stmt s) override { - return IRMutator::Mutate(s); - } Stmt Mutate_(const For* op, const Stmt& stmt) { if (IsConstDomain(op->min, op->extent)) { // if the range of loop_var is constant, we will not partition it, // instead, we will use the fixed domain to deduce. - dom_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + vars_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); - dom_map_.erase(op->loop_var.get()); + vars_.erase(op->loop_var.get()); return res; } - PartitionFinder finder(op->loop_var, dom_map_, variables_); + PartitionFinder finder(op->loop_var, vars_); finder.Visit(op->body); if (finder.partitions.empty()) { - variables_.insert(op->loop_var.get()); - IRMutator::Mutate_(op, stmt); - variables_.erase(op->loop_var.get()); - return stmt; + vars_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + Stmt res = IRMutator::Mutate_(op, stmt); + vars_.erase(op->loop_var.get()); + return res; } IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1); @@ -165,33 +145,31 @@ class LoopPartitioner : public IRMutator { } IntSet true_itrv = Intersect(sets); - Stmt simplified_body = op->body; - for (auto p : finder.partitions) { - p.interval = true_itrv; - simplified_body = PartitionReplacer(p).Mutate(simplified_body); - } - + Stmt simplified_body = PartitionReplacer(finder.partitions).Mutate(op->body); Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); Stmt s = simplified_stmt; - Expr pre_doubt_cond = (true_itrv.min() != universe.min()); - IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min()); - Stmt pre_stmt = For::make(op->loop_var, pre_doubt_itrv.min(), - pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); - s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); - - Expr post_doubt_cond = (true_itrv.max() != universe.max()); - IntSet post_doubt_itrv = IntSet::interval(true_itrv.max(), universe.max()); - Stmt post_stmt = For::make(op->loop_var, post_doubt_itrv.min(), - post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); - s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); + if (!can_prove(true_itrv.min() == universe.min())) { + Expr pre_doubt_cond = (true_itrv.min() != universe.min()); + IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min()); + Stmt pre_stmt = For::make(op->loop_var, pre_doubt_itrv.min(), + pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); + } + + if (!can_prove(true_itrv.max() == universe.max())) { + Expr post_doubt_cond = (true_itrv.max() != universe.max()); + IntSet post_doubt_itrv = IntSet::interval(true_itrv.max(), universe.max()); + Stmt post_stmt = For::make(op->loop_var, post_doubt_itrv.min(), + post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); + } return s; } private: - std::unordered_map dom_map_; - std::unordered_set variables_; + std::unordered_map vars_; }; Stmt LoopPartition(Stmt stmt) { From 40c0b5f5d863d6c1e62fde44d6b85025dda7b939 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 2 Mar 2017 23:19:52 +0000 Subject: [PATCH 05/11] fix testcase and comment --- include/tvm/ir_pass.h | 5 ++ src/arithmetic/int_set.h | 9 +++- src/pass/loop_partition.cc | 5 +- tests/python/unittest/test_arith_intset.py | 12 ++--- .../unittest/test_pass_loop_partition.py | 51 ++++++++++++++----- 5 files changed, 61 insertions(+), 21 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 915de41ddd6a..f1ee06188b06 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -137,6 +137,11 @@ Stmt InjectVirtualThread(Stmt stmt); */ Stmt LiftAllocate(Stmt stmt); +/*! + * \brief partition loops in the stmt + * \param stmt The stmt to do loop partition + * \return Transformed stmt. + */ Stmt LoopPartition(Stmt stmt); /*! diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 88d2729dee49..9e6c84f0b4eb 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -8,6 +8,7 @@ #include #include +#include namespace tvm { namespace arith { @@ -158,6 +159,11 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, IntSet Union(const Array& sets); IntSet Union(const std::vector& sets); +/*! + * \brief Create an union set of all sets + * \param sets The sets to be intersected + * \return the set after intersected + */ IntSet Intersect(const Array& sets); IntSet Intersect(const std::vector& sets); @@ -173,7 +179,8 @@ inline const IntSetNode* IntSet::operator->() const { * * \param v The target variable to be deduced. * \param cond The conditional expression. - * \param dom_map The domain of each variable. + * \param hint_map The domain of variable, used to help deduce. + * \param relax The domain of each variable, used to relax the domain. * \return An integer set that can cover all the possible values. */ IntSet DeduceBound(Expr v, Expr cond, diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 555b63dcd0be..ce422a181b59 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -68,6 +68,7 @@ class PartitionFinder : public IRVisitor { } std::vector partitions; + private: VarExpr target_var_; const std::unordered_map& out_vars_; @@ -77,7 +78,7 @@ class PartitionFinder : public IRVisitor { class PartitionReplacer : public IRMutator { public: - PartitionReplacer(const std::vector& ps) + explicit PartitionReplacer(const std::vector& ps) : ps_(ps) {} Expr Mutate(Expr e) final { @@ -113,7 +114,7 @@ class PartitionReplacer : public IRMutator { class LoopPartitioner : public IRMutator { public: - explicit LoopPartitioner() {} + LoopPartitioner() {} Stmt Mutate_(const For* op, const Stmt& stmt) { if (IsConstDomain(op->min, op->extent)) { diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b677ea6ec6fa..b3ba1c7e681e 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,17 +16,17 @@ def test_deduce(): d_s = tvm.arith.intset_interval(-3, -1) e0 = (-b)*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = (d-c)/(-b)+(-1) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) e1 = (a*4+b < c) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = (c-b)/4+(-2) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) e2 = (tvm.max(5, a * 4) < 0) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max()) == "neg_inf" assert str(res2.min()) == "pos_inf" @@ -41,15 +41,15 @@ def test_check(): d_s = tvm.arith.intset_interval(-3, -1) # no compare operator - res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}) + res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) assert res1.is_nothing() # multiple compare operators - res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}) + res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}, {}) assert res1.is_nothing() # multiple target variable - res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}) + res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) assert res1.is_nothing() if __name__ == "__main__": diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index c3fb164e9349..0795ea47814d 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -1,13 +1,40 @@ import tvm -n = tvm.Var('n') -A = tvm.placeholder((n, ), name='A') -B = tvm.placeholder((n, ), name='B') - -T = tvm.compute((n, ), lambda i: A[i]+B[i]) -s = tvm.Schedule(T.op) -xo, xi = s[T].split(T.op.axis[0], factor=4) - -bounds = tvm.schedule.InferBound(s) -stmt = tvm.schedule.ScheduleOps(s, bounds) -stmt = tvm.ir_pass.LoopPartition(stmt) -print(stmt) + +def test_basic(): + n = tvm.Var('n') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.Schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert(stmt.body.body.body.first.body.body.condition.value == 1) + print(stmt) + +def test_multi_loop(): + i = tvm.Var('i') + j = tvm.Var('j') + k = tvm.Var('k') + m = tvm.Var('m') + n = tvm.Var('n') + stmt = tvm.make.For( + i, 0, 4, 0, 0, + tvm.make.For( + j, 0, n, 0, 0, + tvm.make.For( + k, 0, m, 0, 0, + tvm.make.IfThenElse( + (i*m+j+k < n), tvm.make.Evaluate(1), tvm.make.Evaluate(0))))) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert(stmt.body.first.body.body.condition.value == 1) + print(stmt) + return stmt + + +if __name__ == "__main__": + s1 = test_basic() + s2 = test_multi_loop() From 63d3a7d1efe2487f50d9f74c44f2d22c2fee7c50 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 3 Mar 2017 06:46:39 +0000 Subject: [PATCH 06/11] rebase to zero, convert to SSA --- include/tvm/expr.h | 2 +- src/pass/ir_mutator.cc | 2 +- src/pass/loop_partition.cc | 44 +++++++------------ .../unittest/test_pass_loop_partition.py | 6 +-- 4 files changed, 20 insertions(+), 34 deletions(-) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 510ac9d86835..f446fb4ee591 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -55,8 +55,8 @@ class Var : public Halide::VarExpr { public: explicit Var(const std::string& name_hint = "v", Type t = Int(32)) : VarExpr(name_hint, t) {} - explicit Var(std::shared_ptr n) : VarExpr(n) {} + explicit Var(VarExpr v) : VarExpr(v) {} /*! \brief type indicate the container type */ using ContainerType = Variable; diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index bab7471c0561..5fd141ed2f55 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -128,7 +128,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { Expr condition = this->Mutate(op->condition); Stmt then_case = this->Mutate(op->then_case); Stmt else_case; - if (else_case.defined()) { + if (op->else_case.defined()) { else_case = this->Mutate(op->else_case); } if (condition.same_as(op->condition) && diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index ce422a181b59..d4cb5883eff9 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -13,6 +13,7 @@ namespace tvm { namespace ir { using arith::IntSet; +using Halide::Internal::Interval; // for pos_inf & neg_inf // a partition means the expr is equal to true in the interval struct Partition { @@ -81,7 +82,7 @@ class PartitionReplacer : public IRMutator { explicit PartitionReplacer(const std::vector& ps) : ps_(ps) {} - Expr Mutate(Expr e) final { + Expr Mutate(Expr e) override { for (auto p : ps_) { if (e.same_as(p.expr)) { return Mutate(const_true()); @@ -95,31 +96,12 @@ class PartitionReplacer : public IRMutator { const std::vector& ps_; }; -// LoopPartitioner will try to partition the loop variable in the IR. -// The loop variable can be divided into two categories: -// -// - whose range is fixed, the min and the extent both are constant. -// -// For now, we will not do partition on this kind loop variable, we -// add them into dom_map in order to do deduce for follow-up -// partitions. -// -// - whose range is variable -// -// We will try to do partition on this kind loop variable. If success, -// we will mutate the stmt then return. (only consider the partition -// on the outmost loop yet). If failed, we will mark them as variable -// (add them into variables_), then in the follow-up procedure, we know -// a condition is not able to be deduced if it use this variable. - class LoopPartitioner : public IRMutator { public: LoopPartitioner() {} Stmt Mutate_(const For* op, const Stmt& stmt) { if (IsConstDomain(op->min, op->extent)) { - // if the range of loop_var is constant, we will not partition it, - // instead, we will use the fixed domain to deduce. vars_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); @@ -147,26 +129,30 @@ class LoopPartitioner : public IRMutator { IntSet true_itrv = Intersect(sets); Stmt simplified_body = PartitionReplacer(finder.partitions).Mutate(op->body); - Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), - true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); + // rebase to zero + Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + true_itrv.min()}}); + Stmt simplified_stmt = For::make(op->loop_var, 0, + true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, body); Stmt s = simplified_stmt; if (!can_prove(true_itrv.min() == universe.min())) { Expr pre_doubt_cond = (true_itrv.min() != universe.min()); - IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min()); - Stmt pre_stmt = For::make(op->loop_var, pre_doubt_itrv.min(), - pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min() - 1); + Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + pre_doubt_itrv.min()}}); + Stmt pre_stmt = For::make(op->loop_var, 0, + pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, body); s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); } if (!can_prove(true_itrv.max() == universe.max())) { Expr post_doubt_cond = (true_itrv.max() != universe.max()); - IntSet post_doubt_itrv = IntSet::interval(true_itrv.max(), universe.max()); - Stmt post_stmt = For::make(op->loop_var, post_doubt_itrv.min(), - post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); + IntSet post_doubt_itrv = IntSet::interval(true_itrv.max() + 1, universe.max()); + Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + post_doubt_itrv.min()}}); + Stmt post_stmt = For::make(op->loop_var, 0, + post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, body); s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); } - return s; + return Simplify(ConvertSSA(s)); } private: diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 0795ea47814d..c75a56c4a9a7 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -12,7 +12,7 @@ def test_basic(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt) - assert(stmt.body.body.body.first.body.body.condition.value == 1) + # assert(stmt.body.body.body.first.body.body.condition.value == 1) print(stmt) def test_multi_loop(): @@ -28,9 +28,9 @@ def test_multi_loop(): tvm.make.For( k, 0, m, 0, 0, tvm.make.IfThenElse( - (i*m+j+k < n), tvm.make.Evaluate(1), tvm.make.Evaluate(0))))) + (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) stmt = tvm.ir_pass.LoopPartition(stmt) - assert(stmt.body.first.body.body.condition.value == 1) + # assert(stmt.body.first.body.body.condition.value == 1) print(stmt) return stmt From d7aba90d8f42459b9d8b10e5822b5657a1afb028 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 4 Mar 2017 01:44:39 +0000 Subject: [PATCH 07/11] change the logic of generating loop code & fix issues --- src/arithmetic/int_set.cc | 16 --- src/arithmetic/int_set.h | 2 - src/pass/loop_partition.cc | 134 +++++++++++------- .../unittest/test_pass_loop_partition.py | 9 +- 4 files changed, 87 insertions(+), 74 deletions(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index f99b7291ed27..1a66d060acc5 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -163,14 +163,6 @@ inline bool MatchPoint(const IntSet& a, } IntSet Union(const Array& sets) { - std::vector v_sets; - for (auto s : sets) { - v_sets.push_back(s); - } - return Union(v_sets); -} - -IntSet Union(const std::vector& sets) { if (sets.size() == 1) return sets[0]; Interval x = sets[0].cover_interval().as()->i; for (size_t i = 1; i < sets.size(); ++i) { @@ -188,14 +180,6 @@ IntSet Union(const std::vector& sets) { } IntSet Intersect(const Array& sets) { - std::vector v_sets; - for (auto s : sets) { - v_sets.push_back(s); - } - return Intersect(v_sets); -} - -IntSet Intersect(const std::vector& sets) { Interval x = sets[0].cover_interval().as()->i; for (size_t i = 1; i < sets.size(); ++i) { Interval y = sets[i].cover_interval().as()->i; diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 9e6c84f0b4eb..113c9bd013af 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -157,7 +157,6 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, * \return the set after union */ IntSet Union(const Array& sets); -IntSet Union(const std::vector& sets); /*! * \brief Create an union set of all sets @@ -165,7 +164,6 @@ IntSet Union(const std::vector& sets); * \return the set after intersected */ IntSet Intersect(const Array& sets); -IntSet Intersect(const std::vector& sets); // implementation inline const IntSetNode* IntSet::operator->() const { diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index d4cb5883eff9..be93985bb7c4 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -6,14 +6,14 @@ #include #include #include -#include +#include #include "../arithmetic/int_set.h" +#include "../arithmetic/int_set_internal.h" namespace tvm { namespace ir { using arith::IntSet; -using Halide::Internal::Interval; // for pos_inf & neg_inf // a partition means the expr is equal to true in the interval struct Partition { @@ -21,12 +21,14 @@ struct Partition { IntSet interval; }; -bool ExprUseVar(Expr expr, const Variable* var) { +bool ExprUseVars(Expr expr, const std::vector& vars) { bool success = false; - PostOrderVisit(expr, [&var, &success](const NodeRef& node) { - if (node.get() == var) { - success = true; - return; + PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { + for (const Variable* v : vars) { + if (node.get() == v) { + success = true; + return; + } } }); return success; @@ -40,15 +42,12 @@ class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr loop_var, const std::unordered_map& vars) - : target_var_(loop_var), out_vars_(vars), hint_map_(vars), relax_map_() {} + : target_var_(loop_var), out_vars_(vars.size()), hint_map_(vars), relax_map_() { + for (auto kv : vars) out_vars_.push_back(kv.first); + } void Visit_(const For* op) { - for (auto kv : out_vars_) { - if (ExprUseVar(op->min, kv.first) || - ExprUseVar(op->extent, kv.first)) { - return; - } - } + if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); @@ -60,40 +59,45 @@ class PartitionFinder : public IRVisitor { } void Visit_(const IfThenElse* op) { - if (ExprUseVar(op->condition, target_var_.get())) { + if (ExprUseVars(op->condition, std::vector({target_var_.get()}))) { IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); - partitions.push_back(Partition{op->condition, interval}); + partitions[op->condition.get()] = Partition{op->condition, interval}; } else { IRVisitor::Visit_(op); } } - std::vector partitions; + std::unordered_map partitions; private: VarExpr target_var_; - const std::unordered_map& out_vars_; + std::vector out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; }; +std::unordered_map +FindPartitions(VarExpr target, Stmt body, std::unordered_map vars) { + PartitionFinder finder(target, vars); + finder.Visit(body); + return finder.partitions; +} + class PartitionReplacer : public IRMutator { public: - explicit PartitionReplacer(const std::vector& ps) + explicit PartitionReplacer(const std::unordered_map& ps) : ps_(ps) {} Expr Mutate(Expr e) override { - for (auto p : ps_) { - if (e.same_as(p.expr)) { - return Mutate(const_true()); - } + if (ps_.count(e.get())) { + return Mutate(const_true()); } return IRMutator::Mutate(e); } using IRMutator::Mutate; private: - const std::vector& ps_; + const std::unordered_map& ps_; }; class LoopPartitioner : public IRMutator { @@ -109,10 +113,10 @@ class LoopPartitioner : public IRMutator { return res; } - PartitionFinder finder(op->loop_var, vars_); - finder.Visit(op->body); + std::unordered_map partitions = + FindPartitions(op->loop_var, op->body, vars_); - if (finder.partitions.empty()) { + if (partitions.empty()) { vars_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); @@ -120,38 +124,66 @@ class LoopPartitioner : public IRMutator { return res; } - IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1); - std::vector sets{universe}; + Expr min = op->min; + Expr max = op->min + op->extent - 1; + Array sets; // merge partitions (take their intersect) - for (auto p : finder.partitions) { - sets.push_back(p.interval); + for (auto kv : partitions) { + sets.push_back(kv.second.interval); } IntSet true_itrv = Intersect(sets); - Stmt simplified_body = PartitionReplacer(finder.partitions).Mutate(op->body); - // rebase to zero - Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + true_itrv.min()}}); - Stmt simplified_stmt = For::make(op->loop_var, 0, - true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, body); - Stmt s = simplified_stmt; + Stmt pre_stmt; + Expr body_begin; + if (true_itrv.as()->i.has_lower_bound()) { + body_begin = true_itrv.min(); + if (!can_prove(body_begin == min)) { + if (!can_prove(body_begin - min >= 0)) { + LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0) + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + } + // [min, body_begin) + Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + min}}); + pre_stmt = For::make(op->loop_var, 0, + body_begin - min, op->for_type, op->device_api, body); + } + } else { + body_begin = min; + } - if (!can_prove(true_itrv.min() == universe.min())) { - Expr pre_doubt_cond = (true_itrv.min() != universe.min()); - IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min() - 1); - Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + pre_doubt_itrv.min()}}); - Stmt pre_stmt = For::make(op->loop_var, 0, - pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, body); - s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); + Stmt post_stmt; + Expr post_doubt_begin; + if (true_itrv.as()->i.has_upper_bound()) { + post_doubt_begin = true_itrv.max() + 1; + if (!can_prove(true_itrv.max() == max)) { + if (!can_prove(max - post_doubt_begin >= 0)) { + LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0) + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max); + } + // [post_doubt_begin, max] + Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + post_doubt_begin}}); + post_stmt = For::make(op->loop_var, 0, + max - post_doubt_begin + 1, op->for_type, op->device_api, body); + } + } else { + post_doubt_begin = max + 1; } - if (!can_prove(true_itrv.max() == universe.max())) { - Expr post_doubt_cond = (true_itrv.max() != universe.max()); - IntSet post_doubt_itrv = IntSet::interval(true_itrv.max() + 1, universe.max()); - Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + post_doubt_itrv.min()}}); - Stmt post_stmt = For::make(op->loop_var, 0, - post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, body); - s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); + // [body_begin, post_doubt_begin) + Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body); + Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}}); + Stmt simplified_stmt = For::make(op->loop_var, 0, + post_doubt_begin - body_begin, op->for_type, op->device_api, body); + Stmt s = simplified_stmt; + if (pre_stmt.defined()) { + s = Block::make(pre_stmt, s); } + if (post_stmt.defined()) { + s = Block::make(s, post_stmt); + } + return Simplify(ConvertSSA(s)); } diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index c75a56c4a9a7..f4df2e490050 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -12,7 +12,7 @@ def test_basic(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt) - # assert(stmt.body.body.body.first.body.body.condition.value == 1) + assert('if' not in str(stmt.body.body.body.first)) print(stmt) def test_multi_loop(): @@ -30,11 +30,10 @@ def test_multi_loop(): tvm.make.IfThenElse( (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) stmt = tvm.ir_pass.LoopPartition(stmt) - # assert(stmt.body.first.body.body.condition.value == 1) + assert('if' not in str(stmt.body.first)) print(stmt) - return stmt if __name__ == "__main__": - s1 = test_basic() - s2 = test_multi_loop() + test_basic() + test_multi_loop() From e2b39b13c09f4f28827c6147b2db0dc977e99295 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 4 Mar 2017 02:43:38 +0000 Subject: [PATCH 08/11] add a testcase for relax map in deducebound && fix issues --- src/pass/loop_partition.cc | 161 +++++++++++---------- tests/python/unittest/test_arith_intset.py | 5 + 2 files changed, 87 insertions(+), 79 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index be93985bb7c4..d899c9aa83aa 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "../arithmetic/int_set.h" #include "../arithmetic/int_set_internal.h" @@ -21,11 +22,11 @@ struct Partition { IntSet interval; }; -bool ExprUseVars(Expr expr, const std::vector& vars) { +bool ExprUseVars(Expr expr, const std::unordered_set& vars) { bool success = false; PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { - for (const Variable* v : vars) { - if (node.get() == v) { + if (const Variable* v = node.as()) { + if (vars.count(v)) { success = true; return; } @@ -34,16 +35,12 @@ bool ExprUseVars(Expr expr, const std::vector& vars) { return success; } -inline bool IsConstDomain(Expr min, Expr extent) { - return is_const(min) && is_const(extent); -} - class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr loop_var, const std::unordered_map& vars) : target_var_(loop_var), out_vars_(vars.size()), hint_map_(vars), relax_map_() { - for (auto kv : vars) out_vars_.push_back(kv.first); + for (auto kv : vars) out_vars_.insert(kv.first); } void Visit_(const For* op) { @@ -59,7 +56,7 @@ class PartitionFinder : public IRVisitor { } void Visit_(const IfThenElse* op) { - if (ExprUseVars(op->condition, std::vector({target_var_.get()}))) { + if (ExprUseVars(op->condition, std::unordered_set({target_var_.get()}))) { IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); partitions[op->condition.get()] = Partition{op->condition, interval}; } else { @@ -71,18 +68,11 @@ class PartitionFinder : public IRVisitor { private: VarExpr target_var_; - std::vector out_vars_; + std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; }; -std::unordered_map -FindPartitions(VarExpr target, Stmt body, std::unordered_map vars) { - PartitionFinder finder(target, vars); - finder.Visit(body); - return finder.partitions; -} - class PartitionReplacer : public IRMutator { public: explicit PartitionReplacer(const std::unordered_map& ps) @@ -105,7 +95,7 @@ class LoopPartitioner : public IRMutator { LoopPartitioner() {} Stmt Mutate_(const For* op, const Stmt& stmt) { - if (IsConstDomain(op->min, op->extent)) { + if (is_const(op->min) && is_const(op->extent)) { vars_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); @@ -113,83 +103,96 @@ class LoopPartitioner : public IRMutator { return res; } - std::unordered_map partitions = - FindPartitions(op->loop_var, op->body, vars_); + Stmt s = DoPartition(op, stmt); - if (partitions.empty()) { + if (s.defined()) { + return s; + } else { vars_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); vars_.erase(op->loop_var.get()); return res; } + } - Expr min = op->min; - Expr max = op->min + op->extent - 1; - Array sets; - // merge partitions (take their intersect) - for (auto kv : partitions) { - sets.push_back(kv.second.interval); - } - IntSet true_itrv = Intersect(sets); - - Stmt pre_stmt; - Expr body_begin; - if (true_itrv.as()->i.has_lower_bound()) { - body_begin = true_itrv.min(); - if (!can_prove(body_begin == min)) { - if (!can_prove(body_begin - min >= 0)) { - LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0) - << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); - } - // [min, body_begin) - Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + min}}); - pre_stmt = For::make(op->loop_var, 0, - body_begin - min, op->for_type, op->device_api, body); - } - } else { - body_begin = min; - } + private: + Stmt DoPartition(const For* op, const Stmt& stmt); + + std::unordered_map vars_; +}; - Stmt post_stmt; - Expr post_doubt_begin; - if (true_itrv.as()->i.has_upper_bound()) { - post_doubt_begin = true_itrv.max() + 1; - if (!can_prove(true_itrv.max() == max)) { - if (!can_prove(max - post_doubt_begin >= 0)) { - LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0) - << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max); - } - // [post_doubt_begin, max] - Stmt body = Substitute(op->body, {{Var{op->loop_var}, op->loop_var + post_doubt_begin}}); - post_stmt = For::make(op->loop_var, 0, - max - post_doubt_begin + 1, op->for_type, op->device_api, body); +Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { + PartitionFinder finder(op->loop_var, vars_); + finder.Visit(op->body); + const auto& partitions = finder.partitions; + + if (partitions.empty()) return Stmt(); + + Expr min = op->min; + Expr max = op->min + op->extent - 1; + Array sets; + // merge partitions (take their intersect) + for (auto kv : partitions) { + sets.push_back(kv.second.interval); + } + IntSet true_itrv = Intersect(sets); + + Stmt pre_stmt; + Expr body_begin; + if (true_itrv.as()->i.has_lower_bound()) { + body_begin = true_itrv.min(); + if (!can_prove(body_begin == min)) { + if (!can_prove(body_begin - min >= 0)) { + LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0) + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); } - } else { - post_doubt_begin = max + 1; + // [min, body_begin) + Stmt body = Substitute(op->body, + {{Var{op->loop_var}, op->loop_var + min}}); + pre_stmt = For::make(op->loop_var, 0, + body_begin - min, op->for_type, op->device_api, body); } + } else { + body_begin = min; + } - // [body_begin, post_doubt_begin) - Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body); - Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}}); - Stmt simplified_stmt = For::make(op->loop_var, 0, - post_doubt_begin - body_begin, op->for_type, op->device_api, body); - Stmt s = simplified_stmt; - if (pre_stmt.defined()) { - s = Block::make(pre_stmt, s); - } - if (post_stmt.defined()) { - s = Block::make(s, post_stmt); + Stmt post_stmt; + Expr post_doubt_begin; + if (true_itrv.as()->i.has_upper_bound()) { + post_doubt_begin = true_itrv.max() + 1; + if (!can_prove(true_itrv.max() == max)) { + if (!can_prove(max - post_doubt_begin >= 0)) { + LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0) + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max); + } + // [post_doubt_begin, max] + Stmt body = Substitute(op->body, + {{Var{op->loop_var}, op->loop_var + post_doubt_begin}}); + post_stmt = For::make(op->loop_var, 0, + max - post_doubt_begin + 1, op->for_type, op->device_api, body); } + } else { + post_doubt_begin = max + 1; + } - return Simplify(ConvertSSA(s)); + // [body_begin, post_doubt_begin) + Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body); + Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}}); + Stmt simplified_stmt = For::make(op->loop_var, 0, + post_doubt_begin - body_begin, op->for_type, op->device_api, body); + Stmt s = simplified_stmt; + if (pre_stmt.defined()) { + s = Block::make(pre_stmt, s); + } + if (post_stmt.defined()) { + s = Block::make(s, post_stmt); } - private: - std::unordered_map vars_; -}; + return Simplify(ConvertSSA(s)); +} Stmt LoopPartition(Stmt stmt) { stmt = LoopPartitioner().Mutate(stmt); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b3ba1c7e681e..fa2ba7235dfd 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -30,6 +30,11 @@ def test_deduce(): assert str(res2.max()) == "neg_inf" assert str(res2.min()) == "pos_inf" + e3 = (-b)+a*c-d + res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + ans3 = 2/c+1 + assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) + def test_check(): a = tvm.Var('a') b = tvm.Var('b') From b5786b5c1539354e153f2aa29e09b85bbf556b3d Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 4 Mar 2017 04:20:39 +0000 Subject: [PATCH 09/11] clean code --- src/pass/loop_partition.cc | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index d899c9aa83aa..91359e90bb5d 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2017 by Contributors * \file loop_partition.cc */ #include @@ -38,9 +38,9 @@ bool ExprUseVars(Expr expr, const std::unordered_set& vars) { class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr loop_var, - const std::unordered_map& vars) - : target_var_(loop_var), out_vars_(vars.size()), hint_map_(vars), relax_map_() { - for (auto kv : vars) out_vars_.insert(kv.first); + const std::unordered_map& dom_map) + : target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) { + for (const auto& kv : dom_map) out_vars_.insert(kv.first); } void Visit_(const For* op) { @@ -95,35 +95,25 @@ class LoopPartitioner : public IRMutator { LoopPartitioner() {} Stmt Mutate_(const For* op, const Stmt& stmt) { - if (is_const(op->min) && is_const(op->extent)) { - vars_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); - Stmt res = IRMutator::Mutate_(op, stmt); - vars_.erase(op->loop_var.get()); - return res; - } - - Stmt s = DoPartition(op, stmt); - - if (s.defined()) { - return s; - } else { - vars_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); - Stmt res = IRMutator::Mutate_(op, stmt); - vars_.erase(op->loop_var.get()); - return res; + if (!is_const(op->min) || !is_const(op->extent)) { + Stmt s = DoPartition(op, stmt); + if (s.defined()) return s; } + dom_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + Stmt res = IRMutator::Mutate_(op, stmt); + dom_map_.erase(op->loop_var.get()); + return res; } private: Stmt DoPartition(const For* op, const Stmt& stmt); - std::unordered_map vars_; + std::unordered_map dom_map_; }; Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { - PartitionFinder finder(op->loop_var, vars_); + PartitionFinder finder(op->loop_var, dom_map_); finder.Visit(op->body); const auto& partitions = finder.partitions; From 8d0bbd9891f4baa8d4ecfaa1357dba6ca8f805c8 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 4 Mar 2017 04:26:16 +0000 Subject: [PATCH 10/11] const auto& --- src/pass/loop_partition.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 91359e90bb5d..2bd3db2bc56f 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -123,7 +123,7 @@ Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { Expr max = op->min + op->extent - 1; Array sets; // merge partitions (take their intersect) - for (auto kv : partitions) { + for (const auto& kv : partitions) { sets.push_back(kv.second.interval); } IntSet true_itrv = Intersect(sets); From d2a90b252bc512779435555a321dc7cefb5da2d8 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 4 Mar 2017 04:38:21 +0000 Subject: [PATCH 11/11] add test_multi_if --- .../unittest/test_pass_loop_partition.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index f4df2e490050..fd0662c8d906 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -33,7 +33,28 @@ def test_multi_loop(): assert('if' not in str(stmt.body.first)) print(stmt) +def test_multi_if(): + i = tvm.Var('i') + j = tvm.Var('j') + k = tvm.Var('k') + m = tvm.Var('m') + n = tvm.Var('n') + stmt = tvm.make.For( + i, 0, 4, 0, 0, + tvm.make.For( + j, 0, n, 0, 0, + tvm.make.For( + k, 0, m, 0, 0, + tvm.make.Block( + tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)), + tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)) + )))) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt.body.first)) + print(stmt) + if __name__ == "__main__": test_basic() test_multi_loop() + test_multi_if()