[Relay][Op] Support symbolic TopK, Ones, Zeros and Full (apache#5459)
* Support symbolic TopK, Ones, Zeros and Full

* Fix pylint

* Add docstring for topk shape func

* Fix grad

* Fix lazy_gradient_init

* Fix parser

* Fix print ir text

* Fix lint

* Improve pattern_util

* Fix topk

* Fix build

* Use Optional for attribute

* Fix clang-format

* Minor fix

* Fix pylint

* Fix build warning

* Fix parser

* Move ToScalar

* Fix lint

* Fix lint

* Make topk shape func data-independent when k is constant.

* Fix lint

* Minor fix
kevinthesun authored and trevor-m committed Jun 18, 2020
1 parent 5deb83b commit a486905
Showing 22 changed files with 435 additions and 186 deletions.
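
Before the per-file diffs, a minimal sketch of what "symbolic" means here for topk, assuming the Python wrapper shown further below (the variable names and the scalar form of k are illustrative, not taken from the commit); zeros, ones, full and broadcast_to get the analogous treatment for their shape argument:

import tvm
from tvm import relay

# k is a relay expression supplied at run time instead of a Python int.
data = relay.var("data", shape=(8, 128), dtype="float32")
k = relay.var("k", shape=(), dtype="int64")
values, indices = relay.topk(data, k=k, axis=-1, ret_type="both")

mod = tvm.IRModule.from_expr(relay.Function([data, k], relay.Tuple([values, indices])))
print(mod)  # the value/index outputs now have a dynamic extent along the sorted axis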
5 changes: 3 additions & 2 deletions include/tvm/relay/attrs/algorithm.h
@@ -26,6 +26,7 @@

#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>

#include <string>

@@ -52,14 +53,14 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
};

struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
int k;
Optional<Integer> k;
int axis;
bool is_ascend;
std::string ret_type;
DataType dtype;

TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select");
TVM_ATTR_FIELD(k).describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both").describe(
"The return type [both, values, indices]."
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Expand Up @@ -111,7 +111,7 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {

/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
Array<IndexExpr> shape;
Optional<Array<Integer>> shape;
DataType dtype;

TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
6 changes: 5 additions & 1 deletion include/tvm/runtime/ndarray.h
@@ -462,7 +462,11 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format";
auto read_ret = strm->Read(ret->data, data_byte_size);
// Only check non-empty data
if (ndim > 0 && shape[0] != 0) {
CHECK(read_ret) << "Invalid DLTensor file format";
}
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
2 changes: 2 additions & 0 deletions python/tvm/relay/_parser.py
@@ -116,6 +116,8 @@ def __call__(self, args, attrs, type_args):
attrs = {}
if self.operator is op.reshape:
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
68 changes: 68 additions & 0 deletions python/tvm/relay/op/_algorithm.py
@@ -18,7 +18,11 @@
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import

from tvm.te.hybrid import script
from tvm.runtime import convert

from . import strategy
from . import op as _reg
from .op import OpPattern, register_pattern
from .op import register_strategy

@@ -29,3 +33,67 @@
# topk
register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)

@script
def _topk_shape_func_input_data(data, k, axis):
    ndim = len(data.shape)
    val_out = output_tensor((ndim,), "int64")
    indices_out = output_tensor((ndim,), "int64")

    for i in const_range(ndim):
        if i != axis:
            val_out[i] = int64(data.shape[i])
            indices_out[i] = int64(data.shape[i])
        else:
            if k[0] < 1:
                val_out[i] = int64(data.shape[i])
                indices_out[i] = int64(data.shape[i])
            else:
                val_out[i] = int64(k[0])
                indices_out[i] = int64(k[0])
    return val_out, indices_out

@script
def _topk_shape_func_input_shape(data_shape, k, axis):
    ndim = data_shape.shape[0]
    val_out = output_tensor((ndim,), "int64")
    indices_out = output_tensor((ndim,), "int64")

    for i in const_range(ndim):
        if i != axis:
            val_out[i] = int64(data_shape[i])
            indices_out[i] = int64(data_shape[i])
        else:
            if k < 1:
                val_out[i] = int64(data_shape[i])
                indices_out[i] = int64(data_shape[i])
            else:
                val_out[i] = int64(k)
                indices_out[i] = int64(k)
    return val_out, indices_out

@_reg.register_shape_func("topk", True)
def topk_shape_func(attrs, inputs, _):
    """
    Shape func for topk.
    """
    axis = attrs.axis
    if attrs.k is not None:
        if axis < 0:
            axis += inputs[0].shape[0]
        val_out, indices_out = \
            _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
    else:
        if axis < 0:
            axis += len(inputs[0].shape)
        val_out, indices_out = \
            _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
    ret_type = attrs.ret_type
    if ret_type == "both":
        ret = [val_out, indices_out]
    elif ret_type == "values":
        ret = [val_out]
    else:
        ret = [indices_out]

    return ret
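
The registration above dispatches between two hybrid scripts: when k is carried as a constant attribute, the output shapes follow from the input shape tensor alone, while the k-as-input path has to read k's value at run time. A plain-Python sketch of the shape arithmetic both variants perform (illustrative only, not the hybrid-script code):

def topk_out_shape(data_shape, k, axis):
    # Negative axes are normalized the same way topk_shape_func does.
    if axis < 0:
        axis += len(data_shape)
    # k < 1 means "keep every element along the sorted axis"; both the
    # values and the indices outputs share this shape.
    return tuple(d if i != axis or k < 1 else k for i, d in enumerate(data_shape))

assert topk_out_shape((8, 128), 5, -1) == (8, 5)    # k known and positive
assert topk_out_shape((8, 128), 0, -1) == (8, 128)  # k < 1 keeps the full axis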
41 changes: 18 additions & 23 deletions python/tvm/relay/op/_tensor.py
@@ -17,10 +17,9 @@
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""

from tvm.runtime import convert
from tvm.te.hybrid import script
import topi
from topi.util import get_const_tuple

from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern
@@ -93,7 +92,7 @@
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type):
assert not inputs
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 0.0)]

register_broadcast_schedule("zeros")
@@ -110,7 +109,7 @@ def zeros_like_compute(attrs, inputs, output_type):
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type):
assert not inputs
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 1.0)]

register_broadcast_schedule("ones")
Expand All @@ -132,31 +131,26 @@ def clip_compute(attrs, inputs, output_type):

register_injective_schedule("clip")

@script
def _cast_shape_function(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out

def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]

# full
@script
def _full_shape_func(shape):
out_ndim = len(shape)
out_ndim = shape.shape[0]
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = int64(shape[i])
return out

def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
Shape func for full.
"""
return [_full_shape_func(inputs[1])]

def no_data_full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros and ones.
"""
shape = get_const_tuple(attrs.shape)
return [_full_shape_func(convert(shape))]
return [_full_shape_func(inputs[0])]

@script
def _broadcast_shape_func(x, y, ndim):
@@ -198,13 +192,14 @@ def elemwise_shape_func(attrs, inputs, _):
"""
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("zeros", True, no_data_full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones", True, no_data_full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full", True, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
register_shape_func("broadcast_to", True, full_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
8 changes: 4 additions & 4 deletions python/tvm/relay/op/_tensor_grad.py
@@ -232,14 +232,14 @@ def divide_grad(orig, grad):

@register_gradient("zeros")
def zeros_grad(orig, grad):
"""Returns []"""
return []
"""Returns [shape]"""
return [orig.args[0]]


@register_gradient("ones")
def ones_grad(orig, grad):
"""Returns []"""
return []
"""Returns [shape]"""
return [orig.args[0]]


@register_gradient("zeros_like")
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -120,6 +120,8 @@ def _concatenate_shape_func(inputs, axis):
@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0]
return [_concatenate_shape_func(inputs, convert(axis))]

@script
9 changes: 6 additions & 3 deletions python/tvm/relay/op/algorithm.py
@@ -17,7 +17,7 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import TupleWrapper
from ..expr import TupleWrapper, const

def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
Expand Down Expand Up @@ -48,7 +48,8 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
return _make.argsort(data, axis, is_ascend, dtype)


def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int32"):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, can be one of ("both", "values", "indices").
@@ -58,7 +59,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
data : relay.Expr
The input data tensor.
k : int, optional
k : int or relay.Expr, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
@@ -81,6 +82,8 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
out : relay.Expr or List[relay.Expr]
The computed result.
"""
if isinstance(k, int):
k = const(k, "int64")
out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -598,7 +598,9 @@ def argsort_strategy(attrs, inputs, out_type, target):
def wrap_compute_topk(topi_compute):
"""Wrap topk compute"""
def _compute_topk(attrs, inputs, out_type):
k = get_const_int(attrs.k)
k = inputs[1]
if attrs.k is not None:
k = attrs.k
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
10 changes: 7 additions & 3 deletions python/tvm/relay/op/tensor.py
@@ -20,7 +20,7 @@
from tvm.runtime import TVMContext as _TVMContext

from . import _make
from ..expr import Tuple
from ..expr import Tuple, const


# We create a wrapper function for each operator in the
@@ -928,7 +928,7 @@ def zeros(shape, dtype):
Parameters
----------
shape : tuple of int
shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type
@@ -939,6 +939,8 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
return _make.zeros(shape, dtype)


@@ -963,7 +965,7 @@ def ones(shape, dtype):
Parameters
----------
shape : tuple of int
shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type
@@ -974,6 +976,8 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
return _make.ones(shape, dtype)


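
For illustration (not part of the diff): the wrappers above now accept either form, converting a Python tuple to an int32 constant and passing a relay expression straight through. relay.Any() marks a dimension that is only known at run time; the variable names are illustrative.

from tvm import relay

# Static shape: a plain tuple is turned into an int32 constant by the wrapper.
a = relay.ones((2, 3), dtype="float32")

# Symbolic shape: any relay expression yielding a 1-D integer tensor,
# for example the runtime shape of a tensor with a dynamic batch dimension.
x = relay.var("x", shape=(relay.Any(), 3), dtype="float32")
b = relay.zeros(relay.shape_of(x, dtype="int32"), dtype="float32")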
8 changes: 6 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -299,7 +299,7 @@ def full(fill_value, shape=(), dtype=""):
fill_value : relay.Expr
The value to fill. Must be a scalar.
shape : tuple of int
shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type, optional (defaults to data type of the fill value)
@@ -310,6 +310,8 @@ def full(fill_value, shape=(), dtype=""):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
return _make.full(fill_value, shape, dtype)


@@ -527,14 +529,16 @@ def broadcast_to(data, shape):
data : relay.Expr
The input tensor.
shape : shape
shape : tuple of int or relay.Expr
Provide the shape to broadcast to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
return _make.broadcast_to(data, shape)

def broadcast_to_like(data, broadcast_type):
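
As with zeros and ones above, full and broadcast_to now take a symbolic target shape. A short sketch under the same assumptions (names illustrative, not from the commit):

from tvm import relay

x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
target_shape = relay.shape_of(x, dtype="int32")

# The fill value stays a scalar expression; the target shape can be symbolic.
filled = relay.full(relay.const(7.0, "float32"), target_shape, dtype="float32")

# broadcast_to accepts the same symbolic shape expression.
row = relay.var("row", shape=(1, relay.Any()), dtype="float32")
bcast = relay.broadcast_to(row, target_shape)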