From 9d883b88ce348ea79b423c2852a5dcfc5fea3d10 Mon Sep 17 00:00:00 2001 From: MORITA Kazutaka Date: Fri, 24 Apr 2020 11:54:26 +0900 Subject: [PATCH] [FRONTEND][MXNET] support elemwise logic ops (#5361) --- python/tvm/relay/frontend/mxnet.py | 6 ++++++ tests/python/frontend/mxnet/test_forward.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index be2f1105d960..775eb53d2592 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1751,6 +1751,12 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "_equal" : _mx_compare(_op.equal, _rename), + "_not_equal" : _mx_compare(_op.not_equal, _rename), + "_greater" : _mx_compare(_op.greater, _rename), + "_greater_equal" : _mx_compare(_op.greater_equal, _rename), + "_lesser" : _mx_compare(_op.less, _rename), + "_lesser_equal" : _mx_compare(_op.less_equal, _rename), "elemwise_add" : _rename(_op.add), "elemwise_sub" : _rename(_op.subtract), "elemwise_mul" : _rename(_op.multiply), diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 10edff9031fe..5e4c137fe392 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -328,13 +328,19 @@ def test_forward_broadcast_ops(): def test_forward_elemwise_ops(): for op in ["elemwise_add", "elemwise_sub", "elemwise_mul", - "elemwise_div", "maximum", "minimum"]: + "elemwise_div", "maximum", "minimum", + operator.lt, operator.le, operator.eq, + operator.ne, operator.gt, operator.ge]: shape = (3, 4, 5) dtype = 'float32' a_np = np.random.uniform(size=shape).astype(dtype) b_np = np.random.uniform(size=shape).astype(dtype) - mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) - ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + if type(op) == str: + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + else: + mx_sym = op(mx.sym.var('a'), mx.sym.var('b')) + ref_res = op(mx.nd.array(a_np), mx.nd.array(b_np)) shapes = {'a': shape, 'b': shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list():