
[Relay, Topi] [TF, MXNet] Unravel Index operator #5082

Merged (11 commits, Mar 23, 2020)
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -47,6 +47,7 @@ List of operators
topi.strided_slice
topi.expand_dims
topi.reshape
topi.unravel_index
topi.squeeze
topi.concatenate
topi.split
@@ -147,6 +148,7 @@ topi
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.unravel_index
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
1 change: 1 addition & 0 deletions docs/frontend/tensorflow.rst
@@ -242,5 +242,6 @@ Supported Ops
- Transpose
- TruncateMod
- Unpack
- UnravelIndex
- Where
- ZerosLike
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -125,6 +125,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
tvm.relay.unravel_index


**Level 4: Broadcast and Reductions**
@@ -299,6 +300,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.repeat
.. autofunction:: tvm.relay.tile
.. autofunction:: tvm.relay.reverse
.. autofunction:: tvm.relay.unravel_index


Level 4 Definitions
6 changes: 6 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -321,6 +321,12 @@ struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
}
}; // struct ArgWhereAttrs

/*! \brief Attributes used in unravel_index operators */
struct UnRavelIndexAttrs : public tvm::AttrsNode<UnRavelIndexAttrs> {
TVM_DECLARE_ATTRS(UnRavelIndexAttrs, "relay.attrs.UnRavelIndexAttrs") {
Contributor: I don't think there's any need to define an attribute type for an operator without attributes. Although argwhere seems to do the same thing you have, other operators without attributes just don't use one (see nn.batch_flatten as one example). I'd argue we should try to avoid defining unnecessary attrs to prevent bloat.

Contributor Author: Ok. Thanks. This is good to know. I have removed the attrs for both unravel_index and argwhere.

}
}; // struct UnRavelIndexAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
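
As the exchange above concludes, unravel_index ends up carrying no attributes at all: the indices and the target shape both travel as ordinary tensor inputs. A minimal sketch of the resulting Python-level call (illustrative only, using the relay.unravel_index binding added elsewhere in this PR):

import numpy as np
from tvm import relay

indices = relay.const(np.array([22, 41, 37], dtype="int32"))
shape = relay.const(np.array([7, 6], dtype="int32"))
out = relay.unravel_index(indices, shape)  # no attribute arguments needed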
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -119,6 +119,13 @@ def impl(inputs, attrs):
return impl


def _mx_unravel_index(inputs, attrs):
assert len(inputs) == 1
shape = attrs.get_int_tuple("shape")
shape_expr = _expr.const(list(shape))
return _op.unravel_index(inputs[0], shape_expr)
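
For context, a short sketch of the MXNet op this converter handles (assuming MXNet's mx.nd.unravel_index API). MXNet takes the target shape as an operator attribute, while relay.unravel_index expects a tensor input, hence the _expr.const(list(shape)) wrapping above:

import mxnet as mx

out = mx.nd.unravel_index(mx.nd.array([22, 41, 37]), shape=(7, 6))
# out: [[3, 6, 6], [4, 5, 1]]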


def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
@@ -1825,6 +1832,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"_unravel_index": _mx_unravel_index,
"SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
10 changes: 7 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
@@ -626,6 +626,11 @@ def _impl(inputs, attr, params):
return inputs[0]
return _impl

def _unravel_index():
def _impl(inputs, attr, params):
return _op.unravel_index(inputs[0], inputs[1])
return _impl

def _crop_and_resize():
def _impl(inputs, attr, params):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
@@ -1736,6 +1741,7 @@ def _impl(inputs, attr, params):
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
'UnravelIndex' : _unravel_index(),
'Where' : _where(),
'ZerosLike' : AttrCvt('zeros_like'),

@@ -2509,9 +2515,7 @@ def _parse_param(self, key, value, name, shape):

array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._nodes[name] = [tvm.relay.const(new_array)]
self._nodes[name] = [tvm.relay.const(np_array)]

Contributor Author: Removed this because we want to pass the scalar as scalar only and not as a tensor of rank 1.
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
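
To make the scalar-versus-rank-1 distinction above concrete, a small sketch (assuming standard tvm.relay.const behaviour):

import numpy as np
import tvm

tvm.relay.const(np.array(7, dtype="int32"))    # rank-0 array -> scalar constant, shape ()
tvm.relay.const(np.array([7], dtype="int32"))  # what the removed code built: shape (1,)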
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -54,6 +54,7 @@
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_injective_schedule("unravel_index")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)


def unravel_index(indices, shape):
"""Convert a flat index or array of flat indices into a tuple of coordinate arrays.

Example::
- unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]

Parameters
----------
indices : relay.Expr
An integer array containing indices.

shape : relay.Expr
The shape of the array to unravel the indices into.

Returns
-------
result : relay.Expr
The tuple of coordinate arrays.
"""

return _make.unravel_index(indices, shape)
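
The semantics mirror NumPy's np.unravel_index, except that Relay stacks the per-dimension coordinates into a single tensor rather than returning a tuple. A plain-NumPy sketch of the docstring example:

import numpy as np

rows, cols = np.unravel_index([22, 41, 37], (7, 6))
# rows -> array([3, 6, 6])   e.g. 22 = 3 * 6 + 4
# cols -> array([4, 5, 1])
# relay.unravel_index returns these stacked: [[3, 6, 6], [4, 5, 1]]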
66 changes: 66 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2703,5 +2703,71 @@ RELAY_REGISTER_OP("one_hot")
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

/* relay.unravel_index */
TVM_REGISTER_NODE_TYPE(UnRavelIndexAttrs);

bool UnRavelIndexRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);

const auto* indices = types[0].as<TensorTypeNode>();
if (indices == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "unravel_index: expect input type to be TensorType but get " << types[0];
return false;
}
CHECK(indices->dtype.is_int()) << "indices of unravel_index must be a tensor of integers";

const auto* shape = types[1].as<TensorTypeNode>();
if (shape == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "unravel_index: expect input type to be TensorType but get " << types[1];
return false;
}
CHECK(shape->dtype.is_int()) << "shape of unravel_index must be a tensor of integers";

Array<IndexExpr> indices_shape = indices->shape;
Array<IndexExpr> shape_shape = shape->shape;
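// Output shape: one row per dimension of `shape`. For a 1-D `indices`
// tensor this is (shape_shape[0], indices_shape[0]); for a scalar index
// it is just (shape_shape[0],).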

Array<IndexExpr> oshape;
oshape.push_back(shape_shape[0]);
if (indices_shape.size() != 0) {
oshape.push_back(indices_shape[0]);
}
reporter->Assign(types[2], TensorType(oshape, indices->dtype));
return true;
}

Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
}

Expr MakeUnRavelIndex(Expr data, Expr shape) {
auto attrs = make_object<UnRavelIndexAttrs>();
static const Op& op = Op::Get("unravel_index");
return CallNode::make(op, {data, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex);

RELAY_REGISTER_OP("unravel_index")
.describe(
R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Example::
- unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<UnRavelIndexAttrs>()
.set_support_level(3)
.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
29 changes: 28 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
@@ -949,6 +949,32 @@ def verify(a_np, b_np):
verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))


def test_forward_unravel_index():
def verify(x, shape, dtype):
a_np = np.array(x).astype(dtype)
mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
shapes = {'a': a_np.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)

for target, ctx in ctx_list():
for kind in ["graph", "vm", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

for dtype in ["int32", "int64"]:
verify([0, 1, 2, 3], [2, 2], dtype)
verify([144, 13, 45], [6, 7, 10, 2], dtype)
verify([456], [6, 7, 10, 2], dtype)

# In the example below, 5 is out of bounds for an array of size 4.
# The MXNet implementation gives a different result than TVM;
# the TVM implementation is in line with TensorFlow.
# Ideally an error should be thrown, just as NumPy does.
# verify([0, 1, 2, 5], [2, 2], dtype)
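
For reference, the NumPy behaviour the comment above alludes to (a minimal sketch; the exact error text may vary across NumPy versions):

import numpy as np

np.unravel_index([0, 1, 2, 5], (2, 2))
# ValueError: index 5 is out of bounds for array with size 4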


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
@@ -1003,4 +1029,5 @@ def verify(a_np, b_np):
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()
test_forward_make_loss()
test_forward_unravel_index()
52 changes: 52 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -3037,6 +3037,57 @@ def test_forward_add_n():
_test_forward_add_n(in5)


#######################################################################
# Unravel Index
# ----------------------
def _test_forward_unravel_index(inputs):
tf.reset_default_graph()
with tf.Graph().as_default():
temp = []
for each in inputs:
temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
output = tf.unravel_index(temp[0], temp[1])
compare_tf_with_tvm([each for each in inputs], [
each.name for each in temp], output.name)


def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
tf.reset_default_graph()
with tf.Graph().as_default():
indices_1 = constant_op.constant(x, dtype=dtype)
dims_1 = constant_op.constant(y, dtype=dtype)
out_1 = array_ops.unravel_index(indices_1, dims_1)
compare_tf_with_tvm([], [], out_1.name)


def test_forward_unravel_index():
x = np.array([0, 1, 2, 3])
y = np.array([2, 2])
_test_forward_unravel_index([x, y])

x = np.array([0, 1, 2, 5])
y = np.array([2, 2])
_test_forward_unravel_index([x, y])

x = np.array([0, 1, 2, 5])
y = np.array([2])
_test_forward_unravel_index([x, y])

x = np.array([102, 300, 16])
y = np.array([10, 10, 9, 6])
_test_forward_unravel_index([x, y])

x = np.array([100])
y = np.array([10, 10, 9, 6])
_test_forward_unravel_index([x, y])

# Test scalar input
_test_forward_unravel_index_scalar(13, [1, 4, 5, 2])


#######################################################################
# Dilation2d
# ----------------------
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
strides, dilations, padding):
""" One iteration of dilation2d with given shapes and attributes """
@@ -3151,6 +3202,7 @@ def test_forward_dilation():
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_unravel_index()

# Reductions
test_forward_argminmax()
39 changes: 39 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -683,6 +683,44 @@ def verify_gather_nd(xshape, yshape, y_data):
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])


def test_unravel_index():
def verify_unravel_index(indices, shape, dtype):
x_data = np.array(indices).astype(dtype)
y_data = np.array(shape).astype(dtype)
x = relay.var("x", relay.TensorType(x_data.shape, dtype))
y = relay.var("y", relay.TensorType(y_data.shape, dtype))

z = relay.unravel_index(x, y)
zz = run_infer_type(z)

if len(x_data.shape) == 1:
out_shape = [y_data.shape[0], x_data.shape[0]]
else:
out_shape = [y_data.shape[0]]
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)

func = relay.Function([x, y], z)
ref_res = np.unravel_index(x_data, y_data)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

for dtype in ["int64", "int32"]:
verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
verify_unravel_index([144], [5, 5, 5, 2], dtype)
verify_unravel_index(144, [5, 5, 5, 2], dtype)
verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)

# In the example below, 5 is out of bounds for an array of size 4.
# NumPy throws an error for it, whereas the TVM implementation does not;
# instead it produces output that is in line with TensorFlow.
# verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)


if __name__ == "__main__":
test_arange()
test_cast()
@@ -713,3 +751,4 @@ def verify_gather_nd(xshape, yshape, y_data):
test_tile()
test_repeat()
test_gather_nd()
test_unravel_index()