From c9be99097077eeeb7ed7c1a0b826f92a7149bfa3 Mon Sep 17 00:00:00 2001 From: Samuel Date: Wed, 3 Jun 2020 13:59:38 +0530 Subject: [PATCH] [MXNET]Softmin, trunc op support added (#5715) --- python/tvm/relay/frontend/mxnet.py | 7 +++++++ tests/python/frontend/mxnet/test_forward.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index c75612dd4916..7f9950b808c3 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -846,6 +846,11 @@ def _mx_softsign(inputs, attrs): return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0])) +def _mx_softmin(inputs, attrs): + axis = attrs.get_int("axis", -1) + return _op.nn.softmax(_op.negative(inputs[0]), axis) + + def _mx_hard_sigmoid(inputs, attrs): x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5) return _op.clip(x, a_min=0.0, a_max=1.0) @@ -1829,6 +1834,7 @@ def impl(inputs, input_types): "floor", "ceil", "round", + "trunc", "sign", "sigmoid", "negative", @@ -1938,6 +1944,7 @@ def impl(inputs, input_types): "log_softmax" : _softmax_op(_op.nn.log_softmax), "Softmax" : _softmax_op(_op.nn.softmax), "softsign" : _mx_softsign, + "softmin" : _mx_softmin, "hard_sigmoid" : _mx_hard_sigmoid, "reciprocal" : _mx_reciprocal, # per op specialization diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 5ed2fb890b71..463b50f5b265 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -372,8 +372,17 @@ def test_forward_elemwise_ops(): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +def test_forward_softmin(): + data = mx.sym.var('data') + mx_sym = mx.sym.softmin(data) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + mx_sym = mx.sym.softmin(data, axis=2) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + def test_forward_unary_ops(): - for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", + for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc", "softsign", "hard_sigmoid", "cos", "sin", "tan", "cosh", "sinh", "tanh", @@ -1191,6 +1200,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size test_forward_rrelu() test_forward_prelu() test_forward_softrelu() + test_forward_softmin() test_forward_fc_flatten() test_forward_clip() test_forward_split()