
PReLU activation implementation #1419

Closed
isaaccorley opened this issue Jul 11, 2021 · 10 comments · Fixed by #1570
Labels
Priority: P2 - no schedule: Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@isaaccorley
Contributor

I wanted to gauge interest in adding a PReLU activation. I noticed that the flax.linen activations are simply aliases of the jax.nn activation functions, which also lack a PReLU implementation.

To add some background, PReLU is simply Leaky ReLU where the alpha (slope) parameter is trainable and not fixed. This makes it simple to implement as a Module if desired.

Here's an example implementation from another project of mine.

from functools import partial
from typing import Any, Sequence

import jax.numpy as jnp
import flax.linen as nn


# Constant-fill initializer: like jnp.ones, but scales the output by the given constant value
def constant(key, shape: Sequence[int], value: Any, dtype: Any = jnp.float32) -> jnp.ndarray:
    value = jnp.asarray(value, dtype)
    return jnp.ones(shape, dtype) * value


class PReLU(nn.Module):
    negative_slope_init: float = 0.01
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = jnp.asarray(x, self.dtype)
        negative_slope = self.param(
            "negative_slope",
            partial(constant, value=self.negative_slope_init, dtype=self.dtype),
            (1,)
        )
        return jnp.where(x >= 0, x, negative_slope * x)
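
For context, a rough usage sketch (assuming the definitions above; the PRNG key and input shape are arbitrary) showing that the slope lands in the params dict as a trainable parameter:

import jax

model = PReLU()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x)
# variables["params"]["negative_slope"] has shape (1,) and is updated during training
y = model.apply(variables, x)
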
@marcvanzee
Collaborator

Given that all current activation functions reside in JAX, it seems more fitting to add this to JAX. Do you want to file an issue against their repo?

@jheek added the "Priority: P2 - no schedule" label on Jul 12, 2021
@isaaccorley
Contributor Author

isaaccorley commented Jul 12, 2021

Thanks for the suggestion. The main reason I filed the issue here is that PReLU seems to be a special case: it has a trainable parameter, and, if I'm not mistaken, no other jax activation does.

I'm not sure if this changes your suggestion, but it's something to consider.

@levskaya
Collaborator

levskaya commented Sep 8, 2021

@isaaccorley - hey so sorry for the slow feedback on your suggestion here.

Two points:

  • Instead of defining a constant init func, we can just declare a jnp scalar array of the correct dtype.
  • I think an activation "function" should strictly follow the dtype of its argument, so no dtype attribute; just derive it from x.

So what if we added something like this?

class PReLU(nn.Module):
    negative_slope_init: float = 0.01
    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        negative_slope = self.param(
            "negative_slope",
            lambda k: jnp.array(self.negative_slope_init, x.dtype)
        )
        return jnp.where(x >= 0, x, negative_slope * x)
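
As a quick sketch of the dtype-following behaviour (assuming the class above), initializing with a bfloat16 input should give a bfloat16 parameter:

import jax
import jax.numpy as jnp

x = jnp.ones((4,), dtype=jnp.bfloat16)
variables = PReLU().init(jax.random.PRNGKey(0), x)
# variables["params"]["negative_slope"].dtype is bfloat16, derived from x rather than a dtype attribute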

@isaaccorley
Contributor Author

I'm indifferent about the implementation. The only thing I'd point out is that, since we are inheriting from Module and other Modules have a dtype param, should we stray from that convention even though this is an activation function?

I created a constant init func because jax itself seemed to be lacking one; however, I haven't received a response to the issue I posted in the jax repo requesting one, so I'm fine with just using a lambda.

@levskaya
Collaborator

levskaya commented Sep 8, 2021

  • Other Modules have a dtype param to control the precision of their intermediate values, and a simple activation function like this doesn't have intermediates. We don't require modules to surface a dtype= attribute; it's just a convention for the core layers, to give users the ability to control the floating-point types of the "insides".

  • The "constant" functions you're looking for already exist: jnp.full and jnp.full_like (see the sketch below).
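
For illustration, a minimal sketch of how jnp.full can replace the custom constant initializer (the 0.01 fill value is just the default slope from above):

import jax.numpy as jnp

def constant_init(key, shape, dtype=jnp.float32):
    # Equivalent to the custom `constant` helper above, using jnp.full directly.
    return jnp.full(shape, 0.01, dtype)

# jnp.full((1,), 0.01)             -> array([0.01])
# jnp.full_like(some_array, 0.01)  -> array of 0.01s with some_array's shape and dtype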

@isaaccorley
Copy link
Contributor Author

  1. Makes sense, thanks for clarifying that.
  2. Thanks for pointing me to jnp.full; I wasn't aware of it.

Shall I make a PR then?

@levskaya
Collaborator

levskaya commented Sep 9, 2021

Yeah, if you'd like to make a PR, I think we could add the above to activations.py (after all the passthrough function imports). But no pressure: if you don't have time, we can add it ourselves soon.

@isaaccorley
Contributor Author

I'll try to take a first stab at it since it will be my first time contributing to flax.

@mfouesneau

mfouesneau commented May 26, 2023

The current implementation of PReLU does not work like the other activation functions.

The following example code raises an error at initialization:

class MLP(nn.Module):
    """Definition of an MLP.

    Attributes
    ----------
    :param hidden_sizes: list of int, the number of neurons in each hidden layer.
    :param out_size: int, the number of neurons in the output layer.
    """
    hidden_sizes: Sequence[int]
    out_size: int
    
    @nn.compact
    def __call__(self, x, **kwargs):
        name = kwargs.pop('name', 'fc')
        for e, size in enumerate(self.hidden_sizes, 1):
            x = nn.Dense(size, name=name + str(e))(x)
            # x = nn.silu(x)
            x = nn.PReLU(x)
        x = nn.Dense(self.out_size, name=name + '_output')(x)
        return x 

The traceback points into Dense.__call__ in flax/linen/linear.py:

    @compact
    def __call__(self, inputs: Array) -> Array:
      """Applies a linear transformation to the inputs along the last dimension.
      ...
      """
      kernel = self.param('kernel',
                          self.kernel_init,
-->                       (jnp.shape(inputs)[-1], self.features),
                          self.param_dtype)
      if self.use_bias:
        bias = self.param('bias', self.bias_init, (self.features,),
                          self.param_dtype)

IndexError: tuple index out of range

PReLU does not follow the same calling convention as the other activations.

In the previous example, it needs to be

x = nn.PReLU()(x) 

This was not obvious to me from the documentation.

@levskaya
Collaborator

Because PReLU initializes a trainable scalar parameter, it has to be treated as a layer. I've added clarification and example usage to the docs in #3122
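
For example (a minimal sketch, assuming the API discussed above), applying PReLU adds a parameter entry alongside Dense's, which is why the construct-then-call pattern is required:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)
        return nn.PReLU()(x)  # construct the layer, then apply it to x

params = Block().init(jax.random.PRNGKey(0), jnp.ones((2, 3)))["params"]
# params has entries for both 'Dense_0' and 'PReLU_0' (the latter holding 'negative_slope')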
