Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Bound for Shape Variables #4486

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2a43573
add AssertLowerBound
yzhliu Dec 4, 2019
eb63b3a
add AssertLowerBound to ir_deep_compare.cc
yzhliu Dec 9, 2019
d1d3047
Make Range for AsserLowerBound; add support in codegen
yzhliu Dec 9, 2019
f3e3305
Expr bound
yzhliu Dec 10, 2019
d801459
use intrinsic assert_bound
yzhliu Dec 11, 2019
105a077
simplify nested assert_bound, deal with assert_bound in InternalVal
yzhliu Dec 17, 2019
81386e2
add test case for assert_bound rewrite simplify
yzhliu Dec 18, 2019
55bd45e
fix deduce bound accordingly
yzhliu Dec 22, 2019
c405d7b
fix floordiv IntervalSetEvaluator when b \in [0, n]
yzhliu Dec 23, 2019
92e319b
merge upstream
yzhliu Dec 23, 2019
e3639b9
fix lint
yzhliu Dec 23, 2019
9c94b5d
fix compile error
yzhliu Dec 23, 2019
38fb4d3
add assert_bound in hybrid script
yzhliu Dec 27, 2019
d078c83
fix lint
yzhliu Dec 27, 2019
ac26b27
fix auto buffer bind for assert_bound
yzhliu Dec 27, 2019
fe4975c
Merge remote-tracking branch 'upstream/master' into assert_bound_expr
yzhliu Dec 27, 2019
35d71b8
debug ci
yzhliu Dec 27, 2019
82adeab
revoke
yzhliu Dec 27, 2019
cfecf9d
retrigger
yzhliu Dec 27, 2019
10bc349
fix out of bound in path_ visit
yzhliu Dec 28, 2019
68ae224
fix test_any.py
yzhliu Dec 28, 2019
3fd47d4
Merge remote-tracking branch 'upstream/master' into assert_bound_expr
yzhliu Dec 29, 2019
8d16054
fix lint
yzhliu Dec 29, 2019
4f31cce
fix gpu unittest
yzhliu Dec 30, 2019
5277e81
merge from upstream
yzhliu Dec 31, 2019
3d15b40
polish bound deducer
yzhliu Jan 1, 2020
768b95b
fix build
yzhliu Jan 1, 2020
a905f1f
fix bound remover
yzhliu Jan 1, 2020
894681d
Merge remote-tracking branch 'upstream/master' into assert_bound_expr
yzhliu Jan 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ tvm.ir_pass
tvm.ir_pass.SplitPipeline
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.LowerIntrin
tvm.ir_pass.RemoveIntrin
tvm.ir_pass.LowerTVMBuiltin
tvm.ir_pass.NarrowChannelAccess

Expand Down
2 changes: 2 additions & 0 deletions docs/api/python/tvm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ The user facing API for computation declaration.
tvm.min
tvm.max
tvm.tag_scope
tvm.assert_bound

.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
Expand All @@ -70,3 +71,4 @@ The user facing API for computation declaration.
.. autofunction:: tvm.min
.. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope
.. autofunction:: tvm.assert_bound
10 changes: 10 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,16 @@ TVM_DLL Expr nearbyint(Expr x);
*/
TVM_DLL Expr trunc(Expr x);

/*!
* \brief Pass bound information of value.
* \param value The input expression.
* \param lower The lower bound of value (inclusive).
* \param upper The upper bound of value (inclusive).
* \return A Call node that indicates the lower and upper bound of the input expression.
* This intrinsic will be removed before codegen.
*/
TVM_DLL Expr assert_bound(Expr value, Expr lower, Expr upper);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1613,6 +1613,16 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";

/*!
* \brief tvm intrinsic for passing bound information of a value.
* It simply evaluates to the value itself, while helping BoundAnalyzer
* understand the upper and lower bound of that value.
* Expr tvm_assert_bound(Expr value, Expr lower_bound, Expr upper_bound) {
* return value;
* }
*/
constexpr const char* tvm_assert_bound = "tvm_assert_bound";

} // namespace intrinsic

/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,13 @@ LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);

/*!
* \brief Remove intrinsic function calls if possible.
* \param f The function to be processed.
* \return Transformed function.
*/
LoweredFunc RemoveIntrin(LoweredFunc f);

/*!
* \brief Lower custom datatypes.
*
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
shape = tuple(assert_bound(size, 0, None) for size in shape)
dtype = float32 if dtype is None else dtype
return _api_internal._Placeholder(
shape, dtype, name)
Expand Down Expand Up @@ -296,6 +297,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
shape = tuple(assert_bound(size, 0, None) for size in shape)
ndim = len(shape)
code = fcompute.__code__

Expand Down Expand Up @@ -1047,6 +1049,27 @@ def floormod(a, b):
return _make._OpFloorMod(a, b)


def assert_bound(value, lower=None, upper=None):
"""Pass bound information of value.

Parameters
----------
value : Expr
The input expression.
lower : Expr, optional
The lower bound of value (inclusive). Default None, i.e. unbounded below (-inf).
upper : Expr, optional
The upper bound of value (inclusive). Default None, i.e. unbounded above (+inf).

Returns
-------
res : Expr
Call node that indicates the lower and upper bound of the input expression.
This intrinsic will be removed before codegen.
"""
return _make._OpAssertBound(value, lower, upper)


_init_api("tvm.api")

#pylint: disable=unnecessary-lambda
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,16 @@ def get_binds(args, compact=False, binds=None):
binds = {} if binds is None else binds.copy()
cfg = current_build_config()
arg_list = []

def is_var(idx):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename -> is_var_or_assert_bound

if isinstance(idx, expr.Var) or \
(isinstance(idx, expr.Call) and idx.name == "tvm_assert_bound"):
return True
return False

for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, expr.Var) for i in x.shape)
any_dim = any(is_var(i) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = api.decl_buffer(x.shape,
Expand Down Expand Up @@ -499,7 +506,9 @@ def _build_for_device(flist, target, target_host):
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fdevice = [ir_pass.RemoveIntrin(x) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.RemoveIntrin(x) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

Expand Down
11 changes: 11 additions & 0 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,14 @@ def max_num_threads(func_id, args):
_internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
res = _tgt.current_target(args[0].value).max_num_threads
return _api.convert(res)

def tvm_assert_bound(func_id, args):
    """Translate a hybrid-script ``tvm_assert_bound`` call into an assert-bound expression.

    Parameters
    ----------
    func_id : str
        Name of the function being invoked; must be "tvm_assert_bound".
    args : sequence
        One to three arguments, in order: value, lower, upper.
        Bounds that are not supplied are forwarded as None.

    Returns
    -------
    res : Expr
        The result of ``_make._OpAssertBound(value, lower, upper)``.
    """
    n = len(args)
    _internal_assert(func_id == "tvm_assert_bound", "This function cannot be directly invoked!")
    _internal_assert(n >= 1, "At least 1 argument should be provided.")
    _internal_assert(n <= 3, "Accept at most 3 arguments.")
    # Pad the omitted trailing bounds with None so one call covers every accepted arity.
    padded = list(args) + [None] * (3 - n)
    return _make._OpAssertBound(*padded)
2 changes: 1 addition & 1 deletion python/tvm/hybrid/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def visit_Call(self, node):
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list")
"Function call id " + func_id + " not in intrinsics' list")
for elem in node.args:
self.visit(elem)

Expand Down
81 changes: 52 additions & 29 deletions python/tvm/hybrid/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,36 +110,59 @@ def max_num_threads(allow_none=True):
return target.current_target(allow_none).max_num_threads


def tvm_assert_bound(value, lower=None, upper=None): #pylint: disable=unused-argument
    """
    Attach a lower and an upper bound to a value.
    At present the bounds are not used: the value is handed back unchanged.

    Parameters
    ----------
    value: Expr
        The value the bounds refer to
    lower: Expr
        inclusive lower bound (currently ignored)
    upper: Expr
        inclusive upper bound (currently ignored)

    Returns
    -------
    res: Expr
        the very same object that was passed as value
    """
    result = value
    return result


HYBRID_GLOBALS = {
'unroll' : range,
'vectorize' : range,
'parallel' : range,
'const_range' : range,
'bind' : bind,
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) // b,
'max_num_threads': max_num_threads
'unroll' : range,
'vectorize' : range,
'parallel' : range,
'const_range' : range,
'bind' : bind,
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) // b,
'max_num_threads' : max_num_threads,
'tvm_assert_bound' : tvm_assert_bound
}


Expand Down
5 changes: 4 additions & 1 deletion src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ TVM_REGISTER_API("make.Allocate")
} \
})


REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
Expand Down Expand Up @@ -225,6 +224,10 @@ TVM_REGISTER_API("make._OpIfThenElse")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) {
return if_then_else(cond, true_value, false_value);
});
// Register assert_bound(value, lower, upper) under the API name "make._OpAssertBound".
TVM_REGISTER_API("make._OpAssertBound")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr value, Expr lower, Expr upper) {
return assert_bound(value, lower, upper);
});

} // namespace ir
} // namespace tvm
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerIntrin);
REGISTER_PASS(RemoveIntrin);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS(CombineContextCall);
Expand Down
52 changes: 51 additions & 1 deletion src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,41 @@ std::vector<const Object*> GetPath(Expr target, Expr expr) {
return v.path_;
}

class BoundRemover : public IRMutator {
public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE, we are in progress of updating IRMutators to new style, would be great if we can directly change it here related #4607

// Simplify e, then walk it in "remove" mode so that tvm_assert_bound calls
// are replaced by the value they wrap.
Expr Remove(Expr e) {
  remove_bounded_ = true;
  Expr simplified = ir::Simplify(e);
  return IRMutator::Mutate(simplified);
}

// Walk e with "remove" mode switched off, so that variables recorded in
// bounded_var_map_ are replaced by their stored expression again.
Expr Reset(Expr e) {
  remove_bounded_ = false;
  Expr restored = IRMutator::Mutate(e);
  return restored;
}

// In "remove" mode, replace a tvm_assert_bound call by its first argument and
// remember the original call for that variable; every other call takes the
// default mutation path.
Expr Mutate_(const Call* op, const Expr& e) final {
  if (!op->is_intrinsic(intrinsic::tvm_assert_bound) || !remove_bounded_) {
    return IRMutator::Mutate_(op, e);
  }
  Expr wrapped = op->args[0];
  // The wrapped value must be a plain variable at this point.
  const Variable* bounded_var = wrapped.as<Variable>();
  CHECK(bounded_var) << "Invalid value in " << e << ". It should have been simplified.";
  bounded_var_map_[bounded_var] = GetRef<Expr>(op);
  return wrapped;
}

// Outside "remove" mode, map a variable recorded in bounded_var_map_ back to
// its stored expression; otherwise return the expression unchanged.
Expr Mutate_(const Variable* op, const Expr& e) final {
  if (remove_bounded_) {
    return e;
  }
  // A single find() replaces the previous count() + operator[] double lookup.
  const auto it = bounded_var_map_.find(op);
  if (it == bounded_var_map_.end()) {
    return e;
  }
  return it->second;
}

private:
bool remove_bounded_ = false;
std::unordered_map<const Variable*, Expr> bounded_var_map_;
};

enum CompareOp {kGreater, kLess, kEqual};

// a visitor to deduce the bound of a variable from a expression
Expand All @@ -84,7 +119,7 @@ class BoundDeducer: public IRVisitor {

void Visit(const ObjectRef& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
if (iter_ < path_.size() && e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success_ = false;
Expand Down Expand Up @@ -295,6 +330,18 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() {
Init();
if (!success_) return;

// Variables that appear in both expr and result must not be eagerly
// simplified according to their bounds.
// e.g., i + n/4 >= n
// => i >= n - n/4
// If we eagerly simplified the left side given assert_bound(n, 0, +inf),
// we would get i + 0 >= n => i >= n, which is obviously incorrect.
// Thus we remove assert_bound here and restore it later.
BoundRemover bound_remover;
expr_ = bound_remover.Remove(expr_);
result_ = bound_remover.Remove(result_);

Relax();
if (!success_) return;
// get the path
Expand All @@ -306,6 +353,9 @@ void BoundDeducer::Deduce() {
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);

expr_ = bound_remover.Reset(expr_);
result_ = bound_remover.Reset(result_);
}

void BoundDeducer::Relax() {
Expand Down
5 changes: 5 additions & 0 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class ConstIntBoundAnalyzer::Impl :
return VisitRightShift(op);
} else if (op->is_intrinsic(Call::bitwise_and)) {
return VisitBitwiseAnd(op);
} else if (op->is_intrinsic(intrinsic::tvm_assert_bound)) {
Expr value = op->args[0];
Entry lower = VisitExpr(op->args[1]);
Entry upper = VisitExpr(op->args[2]);
return MakeBound(lower.min_value, upper.max_value);
} else {
return Everything(op->dtype);
}
Expand Down
Loading