From 7d25cc3bd34fc8e77bb22b5584711cabdb39bb56 Mon Sep 17 00:00:00 2001
From: gsq7474741
Date: Wed, 10 Nov 2021 01:49:34 +0800
Subject: [PATCH 1/2] add RRelu v0.1

---
 .../unittests/test_activation_nn_grad.py      | 25 ++++++++
 .../tests/unittests/test_activation_op.py     | 54 ++++++++++++++++++
 .../tests/unittests/test_imperative_layers.py |  3 +
 python/paddle/nn/__init__.py                  |  1 +
 python/paddle/nn/functional/__init__.py       |  1 +
 python/paddle/nn/functional/activation.py     | 56 ++++++++++++++++++
 python/paddle/nn/layer/__init__.py            |  1 +
 python/paddle/nn/layer/activation.py          | 57 +++++++++++++++++++
 8 files changed, 198 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
index 825d74388bc0b4..9c21d86e90303b 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -165,6 +165,31 @@ def test_grad(self):
        self.func(p)


+class TestRReluGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 7, 9]
        eps = 0.005
        dtype = np.float64
        seed = 2022

        x = layers.data('x', shape, False, dtype)
        x.persistable = True

        y = F.rrelu(x, seed=seed)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        x_arr[np.abs(x_arr) < 0.005] = 0.02

        gradient_checker.grad_check([x], y, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places = [fluid.CUDAPlace(0)]
        for p in places:
            self.func(p)


class TestELUDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):

diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index b82dd631c64890..09638b5a490eec 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1356,6 +1356,60 @@ def test_errors(self):
            F.leaky_relu(x_fp16)


+class TestRReluAPI(unittest.TestCase):
    # test paddle.nn.RReLU and paddle.nn.functional.rrelu
    def setUp(self):
        np.random.seed(1024)
        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
        self.one_np = np.array([-1.]).astype('float32')
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            minus_one = paddle.fluid.data('one', [1])
            rand_alpha = F.rrelu(minus_one, seed=2022)
            x = paddle.fluid.data('X', [10, 12])
            out1 = F.rrelu(x, seed=2022)
            m = paddle.nn.RReLU(seed=2022)
            out2 = m(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np,
                                'one': self.one_np},
                          fetch_list=[out1, out2, rand_alpha])
        out_ref = ref_leaky_relu(self.x_np, alpha=-res[2])
        for r in range(2):
            self.assertEqual(np.allclose(out_ref, res[r]), True)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        rand_alpha = F.rrelu(paddle.to_tensor(-1.), seed=2022)
        x = paddle.to_tensor(self.x_np)
        out1 = F.rrelu(x, seed=2022)
        m = paddle.nn.RReLU(seed=2022)
        out2 = m(x)
        out_ref = ref_leaky_relu(self.x_np, alpha=-rand_alpha.numpy().item())
        for r in [out1, out2]:
            self.assertEqual(np.allclose(out_ref, r.numpy()), True)

        paddle.enable_static()

    def test_errors(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, F.rrelu, 1)
            # The input dtype must be float16, float32, float64.
            x_int32 = paddle.fluid.data(
                name='x_int32', shape=[12, 10], dtype='int32')
            self.assertRaises(TypeError, F.rrelu, x_int32)
            # support float16 input
            x_fp16 = paddle.fluid.data(
                name='x_fp16', shape=[12, 10], dtype='float16')
            F.rrelu(x_fp16)


def gelu(x, approximate):
    if approximate:
        y_ref = 0.5 * x * (1.0 + np.tanh(

diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py
index 3561405ae090bd..c1733cb77aa30a 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py
@@ -60,6 +60,9 @@ def test_layer_str(self):
        module = nn.LeakyReLU()
        self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')

+        module = nn.RReLU()
+        self.assertEqual(str(module), 'RReLU(negative_slope=0.01)')
+
        module = nn.Sigmoid()
        self.assertEqual(str(module), 'Sigmoid()')

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 064052c07695de..cca16f6341f5d7 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -38,6 +38,7 @@
 from .layer.activation import SELU  # noqa: F401
 from .layer.activation import Silu  # noqa: F401
 from .layer.activation import LeakyReLU  # noqa: F401
+from .layer.activation import RReLU  # noqa: F401
 from .layer.activation import Sigmoid  # noqa: F401
 from .layer.activation import Hardsigmoid  # noqa: F401
 from .layer.activation import LogSigmoid  # noqa: F401

diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 1af53e0826be87..24abfad72a73c0 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -24,6 +24,7 @@
 from .activation import hardsigmoid  # noqa: F401
 from .activation import hardswish  # noqa: F401
 from .activation import leaky_relu  # noqa: F401
+from .activation import rrelu  # noqa: F401
 from .activation import log_sigmoid  # noqa: F401
 from .activation import maxout  # noqa: F401
 from .activation import prelu  # noqa: F401

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index a39c00075a3de1..7659feb0dc325b 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -27,6 +27,7 @@
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
 import paddle
 from paddle import _C_ops
+import numpy as np

 __all__ = []

@@ -435,6 +436,61 @@ def leaky_relu(x, negative_slope=0.01, name=None):
    return out


+def rrelu(x, lower=0.125, upper=0.333, seed=0, name=None):
    r"""
    rrelu activation

    .. math::

        rrelu(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right.

    where the :math:`negative\_slope` is sampled from the uniform distribution
    :math:`U(lower, upper)`.

    Args:
        x (Tensor): The input Tensor with data type float32, float64.
        lower(float, optional): The lower bound of the uniform distribution. Default is 0.125.
        upper(float, optional): The upper bound of the uniform distribution. Default is 0.333.
        seed(int, optional): The random seed of the uniform distribution engine. If seed is 0,
            the seed of the global default generator is used (which can be set by paddle.seed).
            Note that if seed is not 0, this operator generates the same random negative_slope
            on every call. Default is 0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-2., 0., 1.])
            out = F.rrelu(x)  # [-2 * negative_slope, 0., 1.], negative_slope is sampled from U(0.125, 0.333)

    """
    # Sample a single negative_slope from U(lower, upper) and reuse the
    # leaky_relu kernel with that slope.
    if seed != 0:
        np.random.seed(seed)
    negative_slope = float(np.random.uniform(lower, upper))
    if in_dygraph_mode():
        return _C_ops.leaky_relu(x, 'alpha', negative_slope)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'rrelu')
    helper = LayerHelper('rrelu', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='leaky_relu',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'alpha': negative_slope})
    return out


def prelu(x, weight, name=None):
    """
    prelu activation.

diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index eb7535b16c6e1e..1abe7fe520f2f6 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -23,6 +23,7 @@
 from .activation import ReLU  # noqa: F401
 from .activation import ReLU6  # noqa: F401
 from .activation import LeakyReLU  # noqa: F401
+from .activation import RReLU  # noqa: F401
 from .activation import Sigmoid  # noqa: F401
 from .activation import Softmax  # noqa: F401
 from .activation import LogSoftmax  # noqa: F401

diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index cf0ac79ca8ff6f..0c710b34ae536c 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -603,6 +603,63 @@ def extra_repr(self):
        return 'negative_slope={}{}'.format(self._negative_slope, name_str)


+class RReLU(Layer):
    r"""
    Randomized Leaky ReLU Activation.

    .. math::

        RReLU(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right.

    where the :math:`negative\_slope` is sampled from the uniform distribution
    :math:`U(lower, upper)`.

    Parameters:
        lower(float, optional): The lower bound of the uniform distribution. Default is 0.125.
        upper(float, optional): The upper bound of the uniform distribution. Default is 0.333.
        seed(int, optional): The random seed of the uniform distribution engine. If seed is 0,
            the seed of the global default generator is used (which can be set by paddle.seed).
            Note that if seed is not 0, this operator generates the same random negative_slope
            on every call. Default is 0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: Tensor with any shape.
        - output: Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            m = paddle.nn.RReLU()
            x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
            out = m(x)  # [-2 * negative_slope, 0., 1.], negative_slope is sampled from U(0.125, 0.333)
+ """ + + def __init__(self, lower=0.125, upper=0.333, seed=0, name=None): + super(RReLU, self).__init__() + self._lower = lower + self._upper = upper + self._seed = seed + self._name = name + + def forward(self, x): + return F.rrelu(x, self._lower, self._upper, self._seed, self._name) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'lower={}, upper={}, seed={}{}'.format(self._lower, self._upper, + self._seed, name_str) + + class Sigmoid(Layer): """ this interface is used to construct a callable object of the ``Sigmoid`` class. This layer calcluate the `sigmoid` of input x. From c0cde21303ad9381eb3c5509848e6b8065818c14 Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Wed, 10 Nov 2021 02:25:06 +0800 Subject: [PATCH 2/2] add RRelu v0.2 --- python/paddle/fluid/tests/unittests/test_imperative_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index c1733cb77aa30a..e218c5ba8e31b4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -61,7 +61,7 @@ def test_layer_str(self): self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)') module = nn.RReLU() - self.assertEqual(str(module), 'RReLU(negative_slope=0.01)') + self.assertEqual(str(module), 'RReLU(lower=0.125, upper=0.333, seed=0)') module = nn.Sigmoid() self.assertEqual(str(module), 'Sigmoid()')