Skip to content

Commit

Permalink
Merge pull request #1570 from isaaccorley:activations/prelu
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 406135790
  • Loading branch information
Flax Authors committed Oct 28, 2021
2 parents 4cbf6bf + 73b9e0c commit 004a118
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/flax.linen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Activation functions
softmax
softplus
swish
PReLU


Attention primitives
Expand Down
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# re-export commonly used modules and functions
from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
log_softmax, relu, sigmoid, soft_sign, softmax,
softplus, swish, silu, tanh)
softplus, swish, silu, tanh, PReLU)
from .attention import (MultiHeadDotProductAttention, SelfAttention,
dot_product_attention, make_attention_mask,
make_causal_mask, combine_masks)
Expand Down
32 changes: 32 additions & 0 deletions flax/linen/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,35 @@

from jax.numpy import tanh
# pylint: enable=unused-import

from typing import Any

from flax.linen.module import Module, compact
import jax.numpy as jnp


Array = Any


class PReLU(Module):
"""Parametric Rectified Linear Unit (PReLU) activation function.
Attributes:
negative_slope_init: the value to initialize the negative slope.
"""
negative_slope_init: float = 0.01
@compact
def __call__(self, inputs: Array) -> Array:
"""Applies an activation to the inputs.
Args:
inputs: the nd-array to apply the activation function to.
Returns:
The transformed input.
"""
negative_slope = self.param(
'negative_slope',
lambda k: jnp.asarray(self.negative_slope_init, jnp.float32)
)
return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs)
42 changes: 42 additions & 0 deletions tests/linen/linen_activation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for flax.nn.activation."""

from absl.testing import absltest
from absl.testing import parameterized

from flax import linen as nn

import jax
from jax import random
import jax.numpy as jnp


# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


class ActivationTest(parameterized.TestCase):

def test_prelu(self):
rng = random.PRNGKey(0)
x = jnp.ones((4, 6, 5))
act = nn.PReLU()
y, _ = act.init_with_output(rng, x)
self.assertEqual(y.shape, x.shape)


if __name__ == '__main__':
absltest.main()

0 comments on commit 004a118

Please sign in to comment.