diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index ec922ab1564ef..0febfdd85c4a5 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
     return _op.nn.batch_norm(*inputs, **new_attrs)
 
 
+def _mx_instance_norm(inputs, attrs):
+    assert len(inputs) == 3
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", 1)
+    new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
+    return _op.nn.instance_norm(*inputs, **new_attrs)
+
+
 def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):
@@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
     "Dropout" : _mx_dropout,
     "BatchNorm" : _mx_batch_norm,
     "BatchNorm_v1" : _mx_batch_norm,
+    "InstanceNorm" : _mx_instance_norm,
     "LayerNorm" : _mx_layer_norm,
     "LRN" : _mx_lrn,
     "L2Normalization" : _mx_l2_normalize,
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 34e3fdd4e760c..a719f6885171e 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -758,6 +758,26 @@ def verify(shape, axis=1, fix_gamma=False):
     verify((2, 3, 4, 5), fix_gamma=True)
 
 
+def test_forward_instance_norm():
+    def verify(shape, axis=1, epsilon=1e-5):
+        x = np.random.uniform(size=shape).astype("float32")
+        gamma = np.random.uniform(size=(shape[axis])).astype("float32")
+        beta = np.random.uniform(size=(shape[axis])).astype("float32")
+        ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
+        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
+        shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x, gamma, beta)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((2, 3, 4, 5))
+    verify((32, 64, 80, 64))
+    verify((8, 6, 5))
+    verify((8, 7, 6, 5, 4))
+
+
 def test_forward_layer_norm():
     def verify(shape, axis=-1):
         x = np.random.uniform(size=shape).astype("float32")
@@ -926,6 +946,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
     test_forward_sequence_mask()
     test_forward_contrib_div_sqrt_dim()
     test_forward_batch_norm()
+    test_forward_instance_norm()
     test_forward_layer_norm()
     test_forward_one_hot()
     test_forward_convolution()