diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 977a899ec93d..708e0259927a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -503,6 +503,34 @@ def _impl(inputs, input_types): scale=scale) return _impl +def _get_dims(data): + import torch + if isinstance(data, _expr.Expr): + dims = _infer_shape(data) + elif isinstance(data, list): + dims = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + dims = data.shape + else: + msg = "Data type %s could not be parsed" % type(data) + raise AssertionError(msg) + return dims + +def _layer_norm(): + def _impl(inputs, input_types): + data = inputs[0] + ndims = len(_get_dims(inputs[1])) + assert ndims == 1, "Support only normalization over last one dimension." + + return _op.nn.layer_norm(data, + gamma=inputs[1], + beta=inputs[2], + axis=-1, + epsilon=float(inputs[4]), + center=False, + scale=False) + return _impl + def _transpose(): def _impl(inputs, input_types): data = inputs[0] @@ -1050,6 +1078,7 @@ def _wrap_const(c): "aten::contiguous" : _contiguous(), "aten::batch_norm" : _batch_norm(), "aten::instance_norm" : _instance_norm(), + "aten::layer_norm" : _layer_norm(), "aten::transpose" : _transpose(), "aten::transpose_" : _transpose(), "aten::t" : _transpose(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e7c2e0841a87..fa32dca2d5b2 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -561,6 +561,9 @@ def test_forward_instancenorm(): (torch.nn.InstanceNorm3d(16), inp_3d)]: verify_model(ins_norm.eval(), input_data=inp) +def test_forward_layernorm(): + inp = torch.rand((20, 5, 10, 10)) + verify_model(torch.nn.LayerNorm(10).eval(), input_data=inp) def test_forward_transpose(): torch.set_grad_enabled(False) @@ -1132,6 +1135,7 @@ def forward(self, xs): test_forward_contiguous() test_forward_batchnorm() test_forward_instancenorm() + test_forward_layernorm() test_forward_transpose() test_forward_size() test_forward_view()