diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4edf0b80de4c..59e3903af327 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1737,6 +1737,12 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "_equal" : _mx_compare(_op.equal, _rename), + "_not_equal" : _mx_compare(_op.not_equal, _rename), + "_greater" : _mx_compare(_op.greater, _rename), + "_greater_equal" : _mx_compare(_op.greater_equal, _rename), + "_lesser" : _mx_compare(_op.less, _rename), + "_lesser_equal" : _mx_compare(_op.less_equal, _rename), "elemwise_add" : _rename(_op.add), "elemwise_sub" : _rename(_op.subtract), "elemwise_mul" : _rename(_op.multiply), diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 4a9848e03b5e..7a82426a33e2 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -328,13 +328,19 @@ def test_forward_broadcast_ops(): def test_forward_elemwise_ops(): for op in ["elemwise_add", "elemwise_sub", "elemwise_mul", - "elemwise_div", "maximum", "minimum"]: + "elemwise_div", "maximum", "minimum", + operator.lt, operator.le, operator.eq, + operator.ne, operator.gt, operator.ge]: shape = (3, 4, 5) dtype = 'float32' a_np = np.random.uniform(size=shape).astype(dtype) b_np = np.random.uniform(size=shape).astype(dtype) - mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) - ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + if type(op) == str: + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + else: + mx_sym = op(mx.sym.var('a'), mx.sym.var('b')) + ref_res = op(mx.nd.array(a_np), mx.nd.array(b_np)) shapes = {'a': shape, 'b': shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list():