diff --git a/docs/flax.linen.rst b/docs/flax.linen.rst
index 1ab3b5b93d..cd87f0aaa1 100644
--- a/docs/flax.linen.rst
+++ b/docs/flax.linen.rst
@@ -114,6 +114,7 @@ Activation functions
     softmax
     softplus
     swish
+    PReLU
 
 
 Attention primitives
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 08a8b83023..22498ba7a4 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -19,7 +19,7 @@
 # re-export commonly used modules and functions
 from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
                          log_softmax, relu, sigmoid, soft_sign, softmax,
-                         softplus, swish, silu, tanh)
+                         softplus, swish, silu, tanh, PReLU)
 from .attention import (MultiHeadDotProductAttention, SelfAttention,
                         dot_product_attention, make_attention_mask,
                         make_causal_mask, combine_masks)
diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index a0466b1c33..83628297b5 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -40,3 +40,35 @@
 from jax.numpy import tanh
 # pylint: enable=unused-import
 
+
+from typing import Any
+
+from flax.linen.module import Module, compact
+import jax.numpy as jnp
+
+
+Array = Any
+
+
+class PReLU(Module):
+  """Parametric Rectified Linear Unit (PReLU) activation function.
+
+  Attributes:
+    negative_slope_init: the value to initialize the negative slope.
+  """
+  negative_slope_init: float = 0.01
+  @compact
+  def __call__(self, inputs: Array) -> Array:
+    """Applies the PReLU activation to the inputs.
+
+    Args:
+      inputs: the nd-array to apply the activation function to.
+
+    Returns:
+      The transformed input.
+    """
+    negative_slope = self.param(
+        'negative_slope',
+        lambda k: jnp.asarray(self.negative_slope_init, jnp.float32)
+    )
+    return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs)
diff --git a/tests/linen/linen_activation_test.py b/tests/linen/linen_activation_test.py
new file mode 100644
index 0000000000..3f35c37745
--- /dev/null
+++ b/tests/linen/linen_activation_test.py
@@ -0,0 +1,42 @@
+# Copyright 2021 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flax.linen.activation."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from flax import linen as nn
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl()
+
+
+class ActivationTest(parameterized.TestCase):
+
+  def test_prelu(self):
+    rng = random.PRNGKey(0)
+    x = jnp.ones((4, 6, 5))
+    act = nn.PReLU()
+    y, _ = act.init_with_output(rng, x)
+    self.assertEqual(y.shape, x.shape)
+
+
+if __name__ == '__main__':
+  absltest.main()
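
For reference, a minimal usage sketch of the new PReLU module inside a Linen model. The SimpleMLP class, layer sizes, and input shape below are illustrative only and are not part of this change:

import jax
import jax.numpy as jnp
from flax import linen as nn


class SimpleMLP(nn.Module):
  """Toy MLP used only to demonstrate PReLU; not part of this PR."""
  features: int = 16

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    # Each PReLU instance owns one learnable scalar parameter,
    # 'negative_slope', initialized from negative_slope_init (0.01 by default).
    x = nn.PReLU()(x)
    return nn.Dense(1)(x)


rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 8))
model = SimpleMLP()
variables = model.init(rng, x)   # contains params/PReLU_0/negative_slope
y = model.apply(variables, x)    # forward pass using the stored slope

Because the slope lives in the 'params' collection, it is trained by the optimizer together with the Dense kernels, which is what distinguishes PReLU from the fixed-slope leaky_relu already exported from jax.nn.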