-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] DeduceBound #40
Changes from 19 commits
2b9cb33
461d778
bccdfce
31478b2
2b4d091
2539265
0ea07ab
cf9f3ba
8f72e50
9106944
1f1ff8f
b409040
e4bee27
e3a5f9e
b1617a8
7abe378
f829694
96ded33
5a8fa91
71349f6
f3e3fa9
d5aedde
35683a8
d9794bb
696976a
434835a
2527b2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# pylint: disable=protected-access, no-member | ||
"""Arithmetic data structure and utility""" | ||
from __future__ import absolute_import as _abs | ||
from ._ctypes._node import NodeBase, register_node | ||
from . import _api_internal | ||
|
||
@register_node | ||
class IntSet(NodeBase): | ||
pass | ||
|
||
@register_node | ||
class IntervalSet(IntSet): | ||
def min(self): | ||
return _api_internal._IntervalSetGetMin(self) | ||
|
||
def max(self): | ||
return _api_internal._IntervalSetGetMax(self) | ||
|
||
@register_node | ||
class StrideSet(IntSet): | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* Implementation of API functions related to arith | ||
* \file api_arith.cc | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir.h> | ||
#include <tvm/api_registry.h> | ||
#include "../arithmetic/int_set.h" | ||
#include "../arithmetic/int_set_internal.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
TVM_REGISTER_API(_arith_intset_single_point) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::single_point(args[0]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_intset_range) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::range(args[0], args[1]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_DeduceBound) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = DeduceBound(args[0], args[1], args[2]); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMin) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
IntSet s = args[0].operator IntSet(); | ||
*ret = s.as<IntervalSet>()->i.min; | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMax) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
IntSet s = args[0].operator IntSet(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider do it in one line
|
||
*ret = s.as<IntervalSet>()->i.max; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe consider add a function max to integer set, instead of exposing IntervalSet |
||
}); | ||
|
||
} // namespace arith | ||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file bound_deducer.cc | ||
* \brief Utility to deduce bound of expression | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir_pass.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <tvm/api_registry.h> | ||
#include <unordered_set> | ||
#include <unordered_map> | ||
#include "./int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
using namespace ir; | ||
using Halide::Internal::Interval; | ||
|
||
// a visitor to find the path to the target variable | ||
// from a expression. | ||
class VariablePathFinder: public IRVisitor { | ||
public: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. VariablePathFinder There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to look out for errors when a variable appears in multiple locations in the expression |
||
explicit VariablePathFinder(Var target) : target_(target) {} | ||
|
||
void Visit(const NodeRef& node) final { | ||
if (!success) return; | ||
if (visited_.count(node.get()) != 0 && | ||
!node.same_as(target_)) { | ||
return; | ||
} | ||
visited_.insert(node.get()); | ||
|
||
if (!found_) path_.push_back(node.get()); | ||
if (node.same_as(target_)) { | ||
if (!found_) { | ||
found_ = true; | ||
} else { | ||
// target variable appears at multiple location | ||
success = false; | ||
return; | ||
} | ||
} | ||
IRVisitor::Visit(node); | ||
if (!found_) path_.pop_back(); | ||
} | ||
|
||
std::vector<const Node*> path_; | ||
bool success{true}; | ||
|
||
private: | ||
bool found_{false}; | ||
Var target_; | ||
std::unordered_set<const Node*> visited_; | ||
}; | ||
|
||
// get the path to the variable, | ||
// return empty vector to represent failure | ||
std::vector<const Node*> GetPath(Var target, Expr expr) { | ||
VariablePathFinder v(target); | ||
v.Visit(expr); | ||
return v.success ? v.path_ : std::vector<const Node*>(); | ||
} | ||
|
||
// a visitor to deduce the bound of a variable from a expression | ||
class BoundDeducer: public IRVisitor { | ||
public: | ||
BoundDeducer(Var target, Expr expr, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doing everything in constructor have a problem of not being able to throw exception out, consider do it in another function, say Deduce |
||
const std::unordered_map<const Variable*, IntSet>& dom_map) | ||
: target_(target), expr_(expr), dom_map_(dom_map) { | ||
// get the path | ||
path_ = GetPath(target, expr); | ||
if (path_.empty()) { | ||
success = false; | ||
return; | ||
} | ||
iter_ = 0; | ||
result = make_zero(expr.type()); | ||
// get the sign of every subexpr | ||
expr_map_ = EvalSetForEachSubExpr(expr, dom_map); | ||
|
||
Visit(expr); | ||
} | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (!success) return; | ||
if (e.get() == path_[iter_++]) { | ||
IRVisitor::Visit(e); | ||
} else { | ||
success = false; | ||
return; | ||
} | ||
} | ||
|
||
void Visit_(const LT* op) final { | ||
is_greater = false; | ||
is_equal = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems to be bad practice to detect it inside visitor. What if we have something like
Since only simple case is handled, consider do the detection outside the visitor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I add a check before deduce |
||
result = op->b; | ||
Visit(op->a); | ||
} | ||
|
||
void Visit_(const LE* op) final { | ||
is_greater = false; | ||
is_equal = true; | ||
result = op->b; | ||
Visit(op->a); | ||
} | ||
|
||
void Visit_(const GT* op) final { | ||
is_greater = true; | ||
is_equal = false; | ||
result = op->b; | ||
Visit(op->a); | ||
} | ||
|
||
void Visit_(const GE* op) final { | ||
is_greater = true; | ||
is_equal = true; | ||
result = op->b; | ||
Visit(op->a); | ||
} | ||
|
||
void Visit_(const Add* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
result -= left ? op->b : op->a; | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Sub* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
if (left) { | ||
result += op->b; | ||
} else { | ||
result -= op->a; | ||
result = -1 * result; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to -result, or 0- result, negation should be overloaded already? |
||
is_greater = !is_greater; | ||
} | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Mul* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
Expr operand = left ? op->b : op->a; | ||
|
||
SignType sign; | ||
if (operand.type().is_uint()) { | ||
sign = kPositive; | ||
} else { | ||
sign = expr_map_[operand].sign_type(); | ||
} | ||
|
||
if (sign == SignType::kNegative) { | ||
is_greater = !is_greater; | ||
} else if (sign == SignType::kUnknown) { | ||
// unable to get the sign of operand | ||
success = false; | ||
return; | ||
} | ||
|
||
// always use relax bound | ||
if (is_greater) { | ||
result = result / operand + 1; | ||
} else { | ||
result = result / operand - 1; | ||
} | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
Expr result; | ||
bool is_greater{true}; | ||
bool is_equal{true}; | ||
bool success{true}; | ||
|
||
private: | ||
Var target_; | ||
Expr expr_; | ||
const std::unordered_map<const Variable*, IntSet>& dom_map_; | ||
std::vector<const Node*> path_; | ||
size_t iter_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. directly initialize it to 0 here |
||
ExprIntSetMap expr_map_; | ||
}; | ||
|
||
// assuming e >= 0, deduce the bound of variable from it. | ||
// return empty set to represent deduce failure. | ||
IntSet DeduceBound(Var v, Expr e, | ||
const Map<Var, IntSet>& dom_map) { | ||
std::unordered_map<const Variable*, IntSet> dmap; | ||
for (auto kv : dom_map) { | ||
dmap[kv.first.get()] = kv.second; | ||
} | ||
BoundDeducer d(v, e, dmap); | ||
if (!d.success) return IntSet(); | ||
Expr min = Interval::neg_inf, max = Interval::pos_inf; | ||
if (d.is_greater) { | ||
min = d.is_equal ? d.result : d.result+1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space between add |
||
} else { | ||
max = d.is_equal ? d.result : d.result-1; | ||
} | ||
return IntSet::range(min, max); | ||
} | ||
|
||
} // namespace arith | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this include