Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS]LoopPartition #56

Merged
merged 13 commits into from
Mar 4, 2017
4 changes: 3 additions & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
using Halide::Internal::const_true;
using Halide::Internal::const_false;

inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
Expand All @@ -53,8 +55,8 @@ class Var : public Halide::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}

explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}

/*! \brief type indicate the container type */
using ContainerType = Variable;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class IRMutator {
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);

virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ Stmt InjectVirtualThread(Stmt stmt);
*/
Stmt LiftAllocate(Stmt stmt);

/*!
* \brief partition loops in the stmt
* \param stmt The stmt to do loop partition
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
4 changes: 3 additions & 1 deletion src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ TVM_REGISTER_API(_arith_EvalModular)

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
});

TVM_REGISTER_API(_IntervalSetGetMin)
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);

} // namespace ir
} // namespace tvm
69 changes: 49 additions & 20 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using Halide::Internal::Interval;
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
explicit VariablePathFinder(Var target) : target_(target) {}
explicit VariablePathFinder(Expr target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
Expand All @@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor {

private:
bool found_{false};
Var target_;
Expr target_;
std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
std::vector<const Node*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.path_;
Expand All @@ -56,11 +56,11 @@ class BoundDeducer: public IRVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Var target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {}
BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}

bool Init();
void Deduce();

void Visit(const NodeRef& e) final {
Expand Down Expand Up @@ -137,9 +137,14 @@ class BoundDeducer: public IRVisitor {
bool success{true};

private:
Var target_;
void Init();
void Transform();
void Relax();

Expr target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
const std::unordered_map<const Variable*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_;
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
Expand All @@ -163,10 +168,13 @@ class BoundDeduceInputChecker: public IRVisitor {
size_t target_count{0};
};

bool BoundDeducer::Init() {
void BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;
Transform();
}

void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) {
is_greater = false;
is_equal = false;
Expand All @@ -190,30 +198,35 @@ bool BoundDeducer::Init() {
} else {
success = false;
}
return success;
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;

Relax();
// get the path
path_ = GetPath(target_, expr_);
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_);
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
}

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
void BoundDeducer::Relax() {
if (is_greater) {
expr_ = EvalSet(expr_ , relax_map_).min();
result = EvalSet(result, relax_map_).max();
} else {
expr_ = EvalSet(expr_ , relax_map_).max();
result = EvalSet(result, relax_map_).min();
}
BoundDeducer d(v, e, dmap);
}

IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
Expand All @@ -225,5 +238,21 @@ IntSet DeduceBound(Var v, Expr e,
return IntSet::interval(min, max);
}

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
std::unordered_map<const Variable*, IntSet> hmap;
for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second;
}
std::unordered_map<const Variable*, IntSet> rmap;
for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second;
}
return DeduceBound(v, e, hmap, rmap);
}

} // namespace arith
} // namespace tvm
19 changes: 14 additions & 5 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ inline bool MatchPoint(const IntSet& a,
return i.is_single_point() && i.min.same_as(b);
}

IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) {
IntSet s = set[i].cover_interval();
IntSet Union(const Array<IntSet>& sets) {
if (sets.size() == 1) return sets[0];
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
IntSet s = sets[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max;
Expand All @@ -179,6 +179,15 @@ IntSet Union(const Array<IntSet>& set) {
return IntervalSet::make(x);
}

IntSet Intersect(const Array<IntSet>& sets) {
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
x = Interval::make_intersection(x, y);
}
return IntervalSet::make(x);
}

// type traits
template<typename OP>
struct is_logical_op {
Expand Down
20 changes: 17 additions & 3 deletions src/arithmetic/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <vector>

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -157,6 +158,13 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r,
*/
IntSet Union(const Array<IntSet>& sets);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand All @@ -169,11 +177,17 @@ inline const IntSetNode* IntSet::operator->() const {
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param dom_map The domain of each variable.
* \param hint_map The domain of variable, used to help deduce.
* \param relax The domain of each variable, used to relax the domain.
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Var v, Expr cond,
const Map<Var, IntSet>& dom_map);
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);


} // namespace arith
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion src/pass/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
if (op->else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
Expand Down
Loading