[Relay, Topi] [TF, MXNet] Unravel Index operator (apache#5082)
* first cut unravel_index

* merge fixes

* change rates to dilations

* unravel_index op relay, topi, mxnet, tf

* doc changes

* small changes

* remove empty unravel and argwhere attrs

* remove empty unravel and argwhere attrs
maheshambule authored and zhiics committed Apr 17, 2020
1 parent 4730703 commit 2e55c27
Showing 16 changed files with 353 additions and 15 deletions.
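For context, the operator mirrors NumPy's np.unravel_index: it converts flat indices into one coordinate array per dimension of a target shape. A minimal sketch of the semantics (illustrative only, not the TOPI implementation):

import numpy as np

# Reference semantics the new operator follows.
indices = np.array([22, 41, 37])
print(np.unravel_index(indices, (7, 6)))
# (array([3, 6, 6]), array([4, 5, 1]))

# The same result via repeated div/mod, starting from the last dimension.
def unravel(indices, shape):
    coords = []
    for dim in reversed(shape):
        coords.append(indices % dim)
        indices = indices // dim
    return list(reversed(coords))

print(unravel(np.array([22, 41, 37]), (7, 6)))
# [array([3, 6, 6]), array([4, 5, 1])]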
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -47,6 +47,7 @@ List of operators
topi.strided_slice
topi.expand_dims
topi.reshape
topi.unravel_index
topi.squeeze
topi.concatenate
topi.split
@@ -147,6 +148,7 @@ topi
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.unravel_index
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
1 change: 1 addition & 0 deletions docs/frontend/tensorflow.rst
@@ -242,5 +242,6 @@ Supported Ops
- Transpose
- TruncateMod
- Unpack
- UnravelIndex
- Where
- ZerosLike
3 changes: 2 additions & 1 deletion docs/langref/relay_op.rst
@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
tvm.relay.unravel_index


**Level 4: Broadcast and Reductions**
@@ -217,4 +218,4 @@ This level supports dialect operators.
:nosignatures:

tvm.relay.qnn.op.requantize
-   tvm.relay.qnn.op.conv2d
+   tvm.relay.qnn.op.conv2d
6 changes: 0 additions & 6 deletions include/tvm/relay/attrs/transform.h
@@ -315,12 +315,6 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
  }
};  // struct OneHotAttrs

-/*! \brief Attributes for ArgWhere operator */
-struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
-  TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
-  }
-};  // struct ArgWhereAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -120,6 +120,13 @@ def impl(inputs, attrs):
    return impl


def _mx_unravel_index(inputs, attrs):
    assert len(inputs) == 1
    shape = attrs.get_int_tuple("shape")
    shape_expr = _expr.const(list(shape))
    return _op.unravel_index(inputs[0], shape_expr)


def _mx_zeros(inputs, attrs):
    assert len(inputs) == 0
    shape = attrs.get_int_tuple("shape")
@@ -1826,6 +1833,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"_unravel_index": _mx_unravel_index,
"SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
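For context, MXNet's _unravel_index carries the target shape as an operator attribute, which the _mx_unravel_index converter above re-expresses as a Relay constant. A quick check against MXNet itself (assuming a build that exposes mx.nd.unravel_index):

import mxnet as mx

# MXNet takes the shape as an attribute rather than a tensor input.
out = mx.nd.unravel_index(mx.nd.array([22, 41, 37]), shape=(7, 6))
print(out)
# [[3. 6. 6.]
#  [4. 5. 1.]]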
10 changes: 7 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
@@ -627,6 +627,11 @@ def _impl(inputs, attr, params):
        return inputs[0]
    return _impl

def _unravel_index():
    def _impl(inputs, attr, params):
        return _op.unravel_index(inputs[0], inputs[1])
    return _impl

def _crop_and_resize():
    def _impl(inputs, attr, params):
        # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
@@ -1744,6 +1749,7 @@ def _impl(inputs, attr, params):
    'Transpose'    : _transpose(),
    'TruncateMod'  : _elemwise('mod'),
    'Unpack'       : _unpack(),
    'UnravelIndex' : _unravel_index(),
    'Where'        : _where(),
    'ZerosLike'    : AttrCvt('zeros_like'),

@@ -2517,9 +2523,7 @@ def _parse_param(self, key, value, name, shape):

        array_ndim = len(np_array.shape)
        if array_ndim == 0:
-           new_array = np.empty([1], dtype=np_array.dtype)
-           new_array[0] = np_array
-           self._nodes[name] = [tvm.relay.const(new_array)]
+           self._nodes[name] = [tvm.relay.const(np_array)]
        else:
            self._params[name] = tvm.nd.array(np_array)
            self._nodes[name] = [_expr.var(name,
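For context, TensorFlow's UnravelIndex takes both the indices and the dims as tensor inputs, which is why the converter above can forward inputs[0] and inputs[1] to _op.unravel_index unchanged. A TF 1.x-style sketch of the op being converted (placeholder shapes are illustrative):

import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
    indices = tf.placeholder(shape=(3,), dtype=tf.int32)
    dims = tf.placeholder(shape=(2,), dtype=tf.int32)
    out = tf.unravel_index(indices, dims)  # shape [2, 3]: one row per dimension
    with tf.Session() as sess:
        print(sess.run(out, feed_dict={indices: np.array([22, 41, 37]),
                                       dims: np.array([7, 6])}))
# [[3 6 6]
#  [4 5 1]]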
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -54,6 +54,7 @@
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_injective_schedule("unravel_index")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
             [0, 0, 1]]
    """
    return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)


def unravel_index(indices, shape):
    """Convert a flat index or array of flat indices into a tuple of coordinate arrays.

    Example::
        - unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]

    Parameters
    ----------
    indices : relay.Expr
        An integer array containing indices.

    shape : relay.Expr
        The shape of the target array.

    Returns
    -------
    result : relay.Expr
        The tuple of coordinate arrays.
    """
    return _make.unravel_index(indices, shape)
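A minimal end-to-end sketch of the new Relay API (executor kind and dtypes are illustrative):

import numpy as np
from tvm import relay

x = relay.var("x", relay.TensorType((3,), "int64"))
y = relay.var("y", relay.TensorType((2,), "int64"))
func = relay.Function([x, y], relay.unravel_index(x, y))

intrp = relay.create_executor("graph")
op_res = intrp.evaluate(func)(np.array([22, 41, 37], dtype="int64"),
                              np.array([7, 6], dtype="int64"))
print(op_res.asnumpy())
# [[3 6 6]
#  [4 5 1]]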
72 changes: 69 additions & 3 deletions src/relay/op/tensor/transform.cc
@@ -806,15 +806,13 @@ bool ArgWhereRel(const Array<Type>& types,
TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed([](Expr data) {
  static const Op& op = Op::Get("argwhere");
-  auto attrs = make_object<ArgWhereAttrs>();
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op, {data}, Attrs(), {});
});

RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
-.set_attrs_type<ArgWhereAttrs>()
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
@@ -2662,5 +2660,73 @@ RELAY_REGISTER_OP("one_hot")
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

/* relay.unravel_index */
bool UnRavelIndexRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);

  const auto* indices = types[0].as<TensorTypeNode>();
  if (indices == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "unravel_index: expect input type to be TensorType but get " << types[0];
    return false;
  }
  CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer";

  const auto* shape = types[1].as<TensorTypeNode>();
  if (shape == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "unravel_index: expect input type to be TensorType but get " << types[1];
    return false;
  }
  CHECK(shape->dtype.is_int()) << "shape of unravel_index must be tensor of integer";

  Array<IndexExpr> indices_shape = indices->shape;
  Array<IndexExpr> shape_shape = shape->shape;

  Array<IndexExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (indices_shape.size() != 0) {
    oshape.push_back(indices_shape[0]);
  }
  reporter->Assign(types[2], TensorType(oshape, indices->dtype));
  return true;
}

Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
                                      const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
  return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
}

Expr MakeUnRavelIndex(Expr data, Expr shape) {
  static const Op& op = Op::Get("unravel_index");
  return CallNode::make(op, {data, shape}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
.set_body_typed(MakeUnRavelIndex);

RELAY_REGISTER_OP("unravel_index")
.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.

Example::
  - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_support_level(3)
.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
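The output-shape rule that UnRavelIndexRel encodes, restated in Python for clarity (a sketch, not the C++ above):

def unravel_index_out_shape(indices_shape, shape_shape):
    # `shape` must be a 1-D tensor of length ndim; the output always has
    # one row per target dimension.
    oshape = [shape_shape[0]]
    # 1-D indices add one column per index; scalar indices yield a 1-D result.
    if len(indices_shape) != 0:
        oshape.append(indices_shape[0])
    return oshape

assert unravel_index_out_shape((3,), (2,)) == [2, 3]  # 3 indices, 2-D shape
assert unravel_index_out_shape((), (4,)) == [4]       # scalar index, 4-D shape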
29 changes: 28 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
@@ -949,6 +949,32 @@ def verify(a_np, b_np):
    verify(np.asarray([4.0], 'float32'), np.asarray([3.0], 'float32'))


def test_forward_unravel_index():
    def verify(x, shape, dtype):
        a_np = np.array(x).astype(dtype)
        mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
        ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
        shapes = {'a': a_np.shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)

        for target, ctx in ctx_list():
            for kind in ["graph", "vm", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(a_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    for dtype in ["int32", "int64"]:
        verify([0, 1, 2, 3], [2, 2], dtype)
        verify([144, 13, 45], [6, 7, 10, 2], dtype)
        verify([456], [6, 7, 10, 2], dtype)

        # In the example below, index 5 is out of bounds for an array of size 4.
        # MXNet produces a different result than TVM; TVM's behavior matches
        # TensorFlow's. Ideally an error should be raised, as NumPy does.
        # verify([0, 1, 2, 5], [2, 2], dtype)


if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
@@ -1003,4 +1029,5 @@ def verify(a_np, b_np):
    test_forward_convolution()
    test_forward_deconvolution()
    test_forward_cond()
-   test_forward_make_loss()
+   test_forward_make_loss()
+   test_forward_unravel_index()
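For reference, the NumPy behavior the commented-out case above alludes to (a quick illustration):

import numpy as np

try:
    np.unravel_index(np.array([0, 1, 2, 5]), (2, 2))
except ValueError as err:
    print(err)  # index 5 is out of bounds for array with size 4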
52 changes: 52 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -3057,6 +3057,57 @@ def test_forward_add_n():
    _test_forward_add_n(in5)


#######################################################################
# Unravel Index
# ----------------------
def _test_forward_unravel_index(inputs):
    tf.reset_default_graph()
    with tf.Graph().as_default():
        temp = []
        for each in inputs:
            temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
        output = tf.unravel_index(temp[0], temp[1])
        compare_tf_with_tvm(list(inputs), [each.name for each in temp],
                            output.name)


def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
    tf.reset_default_graph()
    with tf.Graph().as_default():
        indices_1 = constant_op.constant(x, dtype=dtype)
        dims_1 = constant_op.constant(y, dtype=dtype)
        out_1 = array_ops.unravel_index(indices_1, dims_1)
        compare_tf_with_tvm([], [], out_1.name)


def test_forward_unravel_index():
    x = np.array([0, 1, 2, 3])
    y = np.array([2, 2])
    _test_forward_unravel_index([x, y])

    x = np.array([0, 1, 2, 5])
    y = np.array([2, 2])
    _test_forward_unravel_index([x, y])

    x = np.array([0, 1, 2, 5])
    y = np.array([2])
    _test_forward_unravel_index([x, y])

    x = np.array([102, 300, 16])
    y = np.array([10, 10, 9, 6])
    _test_forward_unravel_index([x, y])

    x = np.array([100])
    y = np.array([10, 10, 9, 6])
    _test_forward_unravel_index([x, y])

    # Test scalar input
    _test_forward_unravel_index_scalar(13, [1, 4, 5, 2])


#######################################################################
# Dilation2d
# ----------------------
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
                     strides, dilations, padding):
    """ One iteration of dilation2d with given shapes and attributes """
@@ -3173,6 +3224,7 @@ def test_forward_dilation():
    test_forward_squared_difference()
    test_forward_add_n()
    test_forward_floormod()
    test_forward_unravel_index()

    # Reductions
    test_forward_argminmax()
39 changes: 39 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -683,6 +683,44 @@ def verify_gather_nd(xshape, yshape, y_data):
    verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
    verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])


def test_unravel_index():
    def verify_unravel_index(indices, shape, dtype):
        x_data = np.array(indices).astype(dtype)
        y_data = np.array(shape).astype(dtype)
        x = relay.var("x", relay.TensorType(x_data.shape, dtype))
        y = relay.var("y", relay.TensorType(y_data.shape, dtype))

        z = relay.unravel_index(x, y)
        zz = run_infer_type(z)

        if len(x_data.shape) == 1:
            out_shape = [y_data.shape[0], x_data.shape[0]]
        else:
            out_shape = [y_data.shape[0]]
        assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)

        func = relay.Function([x, y], z)
        ref_res = np.unravel_index(x_data, y_data)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

    for dtype in ["int64", "int32"]:
        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
        verify_unravel_index([144], [5, 5, 5, 2], dtype)
        verify_unravel_index(144, [5, 5, 5, 2], dtype)
        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)

        # In the example below, index 5 is out of bounds for an array of size 4.
        # NumPy raises an error for it, whereas TVM does not; instead TVM
        # produces output that matches TensorFlow's behavior.
        # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)


if __name__ == "__main__":
test_arange()
test_cast()
@@ -713,3 +751,4 @@ def verify_gather_nd(xshape, yshape, y_data):
    test_tile()
    test_repeat()
    test_gather_nd()
    test_unravel_index()