Skip to content

Commit

Permalink
Doc error handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed May 18, 2020
1 parent e1609a6 commit f19b35c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
23 changes: 8 additions & 15 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,27 +610,20 @@ def verify(shape, indices_src, axis, mode="clip"):
def test_forward_gather_nd():
def verify(xshape, yshape, y_data, error=False):
x_data = np.random.uniform(size=xshape).astype("float32")
#ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
#mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
try:
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
except Exception as e:
if not error:
raise e
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
verify((1, 4), (1, 1), [[0]])
verify((4,), (1,), [1], error=True)

def test_forward_bilinear_resize():
# add tests including scale_height and scale_width when mxnet is updated to version 1.5
Expand Down
1 change: 0 additions & 1 deletion topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,6 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
*
* \param data The source array.
* \param indices The indices of the values to extract.
* \param one_dim_support To allow user to input 1 dim tensor.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
Expand Down

0 comments on commit f19b35c

Please sign in to comment.