From 05b1cf216d47fca36290f2c2aa23b7e2746c5ca1 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 8 Dec 2020 18:32:25 +0000 Subject: [PATCH] Support mode=instance, spatial for MXNet l2_normalize --- python/tvm/relay/frontend/mxnet.py | 13 ++++++++++--- tests/python/frontend/mxnet/test_forward.py | 6 ++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 2242be1bcdeb..f2330c72e1f4 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1091,12 +1091,19 @@ def _mx_box_decode(inputs, attrs): def _mx_l2_normalize(inputs, attrs): new_attrs = {} mode = attrs.get_str("mode", "instance") - if mode != "channel": + if mode == "channel": + new_attrs["axis"] = [1] + elif mode == "instance": + ndim = len(_infer_type(inputs[0]).checked_type.shape) + new_attrs["axis"] = list(range(1, ndim)) + elif mode == "spatial": + ndim = len(_infer_type(inputs[0]).checked_type.shape) + new_attrs["axis"] = list(range(2, ndim)) + else: raise tvm.error.OpAttributeInvalid( - 'Value of attribute "mode" must equal "channel" for operator l2_normalize.' + 'Mode "{}" is not supported for operator l2_normalize.'.format(mode) ) new_attrs["eps"] = attrs.get_float("eps", 1e-10) - new_attrs["axis"] = [1] return _op.nn.l2_normalize(inputs[0], **new_attrs) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 79c587fc7f9e..f076a27755ad 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -638,6 +638,12 @@ def test_forward_l2_normalize(): mx_sym = mx.sym.L2Normalization(data, mode="channel") verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) + mx_sym = mx.sym.L2Normalization(data, mode="instance") + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) + + mx_sym = mx.sym.L2Normalization(data, mode="spatial") + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) + @tvm.testing.uses_gpu def test_forward_logistic_regression_output():