Skip to content

Commit

Permalink
Check added at mxnet frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed May 18, 2020
1 parent fb87c0e commit e1609a6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,10 @@ def _mx_take(inputs, attrs):
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)

def _mx_gather_nd(inputs, attrs):
assert len(inputs) == 2
assert len(_infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions"
return _op.gather_nd(inputs[0], inputs[1])

def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
Expand Down Expand Up @@ -1770,7 +1774,6 @@ def impl(inputs, input_types):
"zeros_like",
"ones_like",
"where",
"gather_nd",
"cos",
"cosh",
"sin",
Expand Down Expand Up @@ -1918,6 +1921,7 @@ def impl(inputs, input_types):
"pad" : _mx_pad,
"Pad" : _mx_pad,
"take" : _mx_take,
"gather_nd" : _mx_gather_nd,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
Expand Down
26 changes: 17 additions & 9 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,21 +608,29 @@ def verify(shape, indices_src, axis, mode="clip"):
verify((3,4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
def verify(xshape, yshape, y_data):
def verify(xshape, yshape, y_data, error=False):
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
#ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
#mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
try:
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
except Exception as e:
if not error:
raise e

verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
verify((1, 4), (1, 1), [[0]])
verify((4,), (1,), [1], error=True)

def test_forward_bilinear_resize():
# add tests including scale_height and scale_width when mxnet is updated to version 1.5
Expand Down
8 changes: 8 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,14 @@ def test_forward_gather_nd():
np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(4), [4]).astype('float32'),
np.asarray([1]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(4), [1, 4]).astype('float32'),
np.asarray([3]).astype('int32')
)

#######################################################################
# StridedSlice
Expand Down

0 comments on commit e1609a6

Please sign in to comment.