From fec514fd2754c831cf02d3d23556a813bc5fd80d Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 27 Mar 2019 01:00:45 +0000 Subject: [PATCH] add support for mxnet smooth_l1 --- python/tvm/relay/frontend/mxnet.py | 10 ++++++++++ tests/python/frontend/mxnet/test_forward.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 758793c980d68..8a402fde4c5f6 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -566,6 +566,15 @@ def _mx_embedding(inputs, _): return _op.take(weight, indices.astype('int32'), axis=0) +def _mx_smooth_l1(inputs, attrs): + scalar = attrs.get_float("scalar", 1.0) + scalar_sq = scalar * scalar + mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype='float32')) + return _op.where(mask, + _expr.const(scalar_sq / 2.0, dtype='float32') * inputs[0] * inputs[0], + _op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq)) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -701,6 +710,7 @@ def _mx_embedding(inputs, _): "Embedding" : _mx_embedding, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, + "smooth_l1" : _mx_smooth_l1, # vision "_contrib_BilinearResize2D" : _mx_upsampling, "_contrib_MultiBoxPrior" : _mx_multibox_prior, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index aad666ca75b4a..faccfbfd12fe9 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -464,6 +464,14 @@ def verify(data_shape, weight_shape): verify((2, 2), (4, 5)) verify((2, 3, 4), (4, 5)) + +def test_forward_smooth_l1(): + data = mx.sym.var('data') + mx_sym = mx.sym.smooth_l1(data) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) + mx_sym = mx.sym.smooth_l1(data, scalar=1.0) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -498,3 +506,4 @@ def verify(data_shape, weight_shape): test_forward_broadcast_axis() test_forward_full() test_forward_embedding() + test_forward_smooth_l1()