From df5315621a0536e0913174bead82297a211af052 Mon Sep 17 00:00:00 2001
From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com>
Date: Sun, 26 Sep 2021 22:08:18 -0500
Subject: [PATCH 1/6] add prelu activation

---
 flax/linen/__init__.py   |  2 +-
 flax/linen/activation.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 08a8b83023..22498ba7a4 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -19,7 +19,7 @@
 # re-export commonly used modules and functions
 from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
                          log_softmax, relu, sigmoid, soft_sign, softmax,
-                         softplus, swish, silu, tanh)
+                         softplus, swish, silu, tanh, PReLU)
 from .attention import (MultiHeadDotProductAttention, SelfAttention,
                         dot_product_attention, make_attention_mask,
                         make_causal_mask, combine_masks)
diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index a0466b1c33..ee29441150 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -40,3 +40,35 @@
 from jax.numpy import tanh
 # pylint: enable=unused-import
+
+from typing import Any
+
+from flax.linen.module import Module, compact
+import jax.numpy as jnp
+
+
+Array = Any
+
+
+class PReLU(Module):
+  """Parametric Rectified Linear Unit (PReLU) activation function.
+
+  Attributes:
+    negative_slope_init: the value to initialize the negative slope.
+  """
+  negative_slope_init: float = 0.01
+  @compact
+  def __call__(self, inputs: Array) -> Array:
+    """Applies a convolution to the inputs.
+
+    Args:
+      inputs: the nd-array to apply the activation function to.
+
+    Returns:
+      The transformed input.
+    """
+    negative_slope = self.param(
+        'negative_slope',
+        lambda k: jnp.array(self.negative_slope_init, inputs.dtype)
+    )
+    return jnp.where(inputs >= 0, inputs, negative_slope * inputs)

From 6055a6c3c706e8d18c2fa367ca7ce435099cc756 Mon Sep 17 00:00:00 2001
From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com>
Date: Sun, 26 Sep 2021 22:08:33 -0500
Subject: [PATCH 2/6] add prelu activation to docs

---
 docs/flax.linen.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/flax.linen.rst b/docs/flax.linen.rst
index 1ab3b5b93d..cd87f0aaa1 100644
--- a/docs/flax.linen.rst
+++ b/docs/flax.linen.rst
@@ -114,6 +114,7 @@ Activation functions
     softmax
     softplus
     swish
+    PReLU
 
 
 Attention primitives

From 4da0d98a5255d3866bb41a3c834d0bd9336a3711 Mon Sep 17 00:00:00 2001
From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com>
Date: Sun, 26 Sep 2021 22:08:43 -0500
Subject: [PATCH 3/6] add linen activation tests

---
 tests/linen/linen_activation_test.py | 42 ++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 tests/linen/linen_activation_test.py

diff --git a/tests/linen/linen_activation_test.py b/tests/linen/linen_activation_test.py
new file mode 100644
index 0000000000..3f35c37745
--- /dev/null
+++ b/tests/linen/linen_activation_test.py
@@ -0,0 +1,42 @@
+# Copyright 2021 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flax.nn.activation."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from flax import linen as nn
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl()
+
+
+class ActivationTest(parameterized.TestCase):
+
+  def test_prelu(self):
+    rng = random.PRNGKey(0)
+    x = jnp.ones((4, 6, 5))
+    act = nn.PReLU()
+    y, _ = act.init_with_output(rng, x)
+    self.assertEqual(y.shape, x.shape)
+
+
+if __name__ == '__main__':
+  absltest.main()

From 40a54b9d531a0478027eb3b1a9a3fc20b28a7681 Mon Sep 17 00:00:00 2001
From: isaaccorley <22203655+isaaccorley@users.noreply.github.com>
Date: Thu, 21 Oct 2021 12:47:35 -0500
Subject: [PATCH 4/6] update per suggestions

---
 flax/linen/activation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index ee29441150..44165238a7 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -59,7 +59,7 @@ class PReLU(Module):
   negative_slope_init: float = 0.01
   @compact
   def __call__(self, inputs: Array) -> Array:
-    """Applies a convolution to the inputs.
+    """Applies an activation to the inputs.
 
     Args:
       inputs: the nd-array to apply the activation function to.
@@ -69,6 +69,6 @@ def __call__(self, inputs: Array) -> Array:
     """
     negative_slope = self.param(
         'negative_slope',
-        lambda k: jnp.array(self.negative_slope_init, inputs.dtype)
+        lambda k: jnp.asarray(self.negative_slope_init)
     )
     return jnp.where(inputs >= 0, inputs, negative_slope * inputs)

From 7b0749325dd9dd175eb8ccfbe9807d79afbcf6b2 Mon Sep 17 00:00:00 2001
From: isaaccorley <22203655+isaaccorley@users.noreply.github.com>
Date: Tue, 26 Oct 2021 19:42:27 -0500
Subject: [PATCH 5/6] revert change

---
 flax/linen/activation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index 44165238a7..61075733f5 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -69,6 +69,6 @@ def __call__(self, inputs: Array) -> Array:
     """
     negative_slope = self.param(
         'negative_slope',
-        lambda k: jnp.asarray(self.negative_slope_init)
+        lambda k: jnp.asarray(self.negative_slope_init, inputs.dtype)
     )
     return jnp.where(inputs >= 0, inputs, negative_slope * inputs)

From 73b9e0cad783c7e0a6766b2d89de0bd8a46babac Mon Sep 17 00:00:00 2001
From: isaaccorley <22203655+isaaccorley@users.noreply.github.com>
Date: Wed, 27 Oct 2021 19:37:44 -0500
Subject: [PATCH 6/6] updated per suggestions

---
 flax/linen/activation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index 61075733f5..83628297b5 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -69,6 +69,6 @@ def __call__(self, inputs: Array) -> Array:
     """
     negative_slope = self.param(
         'negative_slope',
-        lambda k: jnp.asarray(self.negative_slope_init, inputs.dtype)
+        lambda k: jnp.asarray(self.negative_slope_init, jnp.float32)
     )
-    return jnp.where(inputs >= 0, inputs, negative_slope * inputs)
+    return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs)
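
Note (not part of the patch series): a minimal usage sketch of the PReLU module added above, assuming the patches are applied to a Flax checkout.

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    x = jnp.array([[-1.0, 0.0, 2.0]])            # inputs with a negative entry
    act = nn.PReLU(negative_slope_init=0.01)     # learnable slope, initialized to 0.01

    variables = act.init(jax.random.PRNGKey(0), x)   # {'params': {'negative_slope': 0.01}}
    y = act.apply(variables, x)                      # [[-0.01, 0.0, 2.0]]

The final two patches store the 'negative_slope' parameter as float32 and cast it to the input dtype at call time, so the activation preserves the dtype of its inputs (e.g. under mixed precision).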