From 989b4819442943fca378f935a58df49b75624cc5 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Wed, 8 Apr 2020 09:15:41 +0530
Subject: [PATCH] [PYTORCH]celu, gelu, selu activations (#5263)

---
 python/tvm/relay/frontend/pytorch.py          | 38 +++++++++++++++++--
 tests/python/frontend/pytorch/test_forward.py | 34 ++++++++++++++++-
 2 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1f10e60f6ac0..46068a4e24ed 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -216,15 +216,44 @@ def _impl(inputs, input_types):
 def _leaky_relu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = int(inputs[1])
+        alpha = float(inputs[1])
         return _op.nn.leaky_relu(data, alpha)
     return _impl
 
 def _elu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = _expr.const(int(inputs[1]), dtype='float32')
-        return alpha * _op.nn.relu(alpha - _op.exp(data)) + _op.nn.relu(data)
+        alpha = _expr.const(float(inputs[1]))  # elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
+        return _op.minimum(_expr.const(0.0), alpha * (_op.exp(data) - _expr.const(1.0))) + _op.nn.relu(data)
+    return _impl
+
+def _celu():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        alpha = _expr.const(float(inputs[1]))  # celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
+        return _op.minimum(_expr.const(0.0), alpha * (_op.exp(data / alpha) - _expr.const(1.0))) + _op.nn.relu(data)
+    return _impl
+
+def _gelu():
+    def _impl(inputs, input_types):
+        import math
+        data = inputs[0]
+        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
+        def _pow3(x):
+            return x * x * x
+        return _expr.const(0.5) * data * (_expr.const(1.0) +
+            _op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
+            (data + _expr.const(0.044715) * _pow3(data))))
+    return _impl
+
+def _selu():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        # https://pytorch.org/docs/stable/nn.html#selu
+        alpha = _expr.const(1.6732632423543772848170429916717)
+        gamma = _expr.const(1.0507009873554804934193349852946)
+        return gamma * (_op.minimum(_expr.const(0.0), alpha *
+            (_op.exp(data) - _expr.const(1.0))) + _op.nn.relu(data))
     return _impl
 
 def _log_sigmoid():
@@ -1066,6 +1095,9 @@ def _wrap_const(c):
     "aten::prelu"                           : _prelu(),
     "aten::leaky_relu"                      : _leaky_relu(),
     "aten::elu"                             : _elu(),
+    "aten::celu"                            : _celu(),
+    "aten::gelu"                            : _gelu(),
+    "aten::selu"                            : _selu(),
     "aten::log_sigmoid"                     : _log_sigmoid(),
     "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
     "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index fb3f18ba079a..05bf7e460890 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -353,16 +353,43 @@ def test_forward_prelu():
 def test_forward_leakyrelu():
     torch.set_grad_enabled(False)
-    input_shape = [10, 10]
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.LeakyReLU().eval(), input_data=input_data)
     verify_model(torch.nn.LeakyReLU(negative_slope=0.05).eval(), input_data=input_data)
+    verify_model(torch.nn.LeakyReLU(negative_slope=1.0).eval(), input_data=input_data)
+    verify_model(torch.nn.LeakyReLU(negative_slope=1.25).eval(), input_data=input_data)
 
 def test_forward_elu():
     torch.set_grad_enabled(False)
-    input_shape = [10, 10]
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.ELU().eval(), input_data=input_data)
+    verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data)
+    verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)
     verify_model(torch.nn.ELU(alpha=1.3).eval(), input_data=input_data)
 
+def test_forward_celu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.CELU().eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=0.3).eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data)
+
+def test_forward_gelu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.GELU().eval(), input_data=input_data)
+
+def test_forward_selu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.SELU().eval(), input_data=input_data)
+
 def test_forward_log_sigmoid():
     torch.set_grad_enabled(False)
     input_shape = [10, 10]
     input_data = torch.rand(input_shape).float()
@@ -1131,6 +1158,9 @@ def forward(self, xs):
     test_forward_prelu()
     test_forward_leakyrelu()
     test_forward_elu()
+    test_forward_celu()
+    test_forward_gelu()
+    test_forward_selu()
     test_forward_log_sigmoid()
     test_forward_adaptiveavgpool()
     test_forward_maxpool2d()
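
Editor's note, not part of the patch: the converters above rewrite each activation in terms of relu/minimum plus exp, and gelu via the tanh approximation. The in-tree tests feed torch.rand inputs, which are all in [0, 1), so the x < 0 branch of these identities never executes there. Below is a minimal standalone sketch that cross-checks the same rewrites against the PyTorch reference on inputs that do cover the negative branch; variable names are illustrative, and it assumes torch.nn.functional exposes elu/celu/selu/gelu, as it does in the PyTorch versions this patch targets.

# Standalone sanity check of the activation rewrites used by the converters.
# gelu gets a looser tolerance because the converter uses the tanh
# approximation while torch.nn.functional.gelu is erf-based.
import math
import numpy as np
import torch
import torch.nn.functional as F

x = np.random.uniform(-3.0, 3.0, size=(1, 3, 10, 10)).astype("float32")
t = torch.from_numpy(x)
relu = lambda v: np.maximum(v, 0.0)

alpha = 1.3
elu = np.minimum(0.0, alpha * (np.exp(x) - 1.0)) + relu(x)
celu = np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0)) + relu(x)
gelu = 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
a = 1.6732632423543772848170429916717
g = 1.0507009873554804934193349852946
selu = g * (np.minimum(0.0, a * (np.exp(x) - 1.0)) + relu(x))

np.testing.assert_allclose(elu, F.elu(t, alpha).numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(celu, F.celu(t, alpha).numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(selu, F.selu(t).numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(gelu, F.gelu(t).numpy(), rtol=1e-3, atol=1e-3)

The min(0, alpha * (exp(x) - 1)) form is used rather than alpha * relu(1 - exp(x)) because the latter flips the sign of the negative branch unless alpha is negated; the check above fails against F.elu/F.celu if the relu form is substituted verbatim.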