PReLU activation implementation #1419
Given that all current activation functions reside in JAX, it seems more fitting to add this to JAX. Do you want to file an issue against their repo?
Thanks for the suggestion. The main reason I filed the issue here is that PReLU is a special case: it has a trainable parameter and, if I'm not mistaken, no other JAX activation does. I'm not sure if this changes your suggestion, but it's something to consider.
@isaaccorley - hey so sorry for the slow feedback on your suggestion here. 2 points:
So what if we added something like this?

```python
import flax.linen as nn
import jax.numpy as jnp


class PReLU(nn.Module):
    negative_slope_init: float = 0.01

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # A single trainable scalar slope, initialized from negative_slope_init.
        negative_slope = self.param(
            "negative_slope",
            lambda k: jnp.array(self.negative_slope_init, x.dtype)
        )
        return jnp.where(x >= 0, x, negative_slope * x)
```
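For context, a rough sketch of how the proposed module might be exercised; the shapes and PRNG key below are placeholders, not part of the proposal:

```python
import jax

# Hypothetical usage of the PReLU sketch above; shapes and keys are arbitrary.
x = jnp.ones((4, 8))
layer = PReLU(negative_slope_init=0.01)
variables = layer.init(jax.random.PRNGKey(0), x)
# `variables` holds a single trainable scalar under params/negative_slope.
y = layer.apply(variables, x)
```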
I'm indifferent on the implementation. The only thing to point out: since we are inheriting from Module, and other Modules have a dtype param, should we stray from that standard even though this is an activation function? I created a constant init func because jax itself seemed to be lacking one; however, I haven't received a response to the issue I posted in the jax repo requesting to add it, so I'm fine with just using a lambda.
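As an aside, a constant initializer along these lines is only a few lines. This is a sketch, and the name `constant_init` is mine rather than an existing API (JAX's `jax.nn.initializers.constant` now covers similar ground):

```python
import jax.numpy as jnp

def constant_init(value, dtype=jnp.float32):
    """Return an init function that ignores the RNG key and fills `shape` with `value`."""
    def init(key, shape=()):
        return jnp.full(shape, value, dtype)
    return init

# e.g. self.param("negative_slope", constant_init(self.negative_slope_init))
```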
Shall I make a PR then?
Yeah, if you'd like to make a PR, we could add the above.
I'll try to take a first stab at it since it will be my first time contributing to flax.
The current implementation of PReLU does not work like the other activation functions. The following example code raises an error at initialization:

```python
from typing import Sequence

import flax.linen as nn


class MLP(nn.Module):
    """Definition of MLP.

    Attributes
    ----------
    :param hidden_sizes: list of int corresponding to the number of neurons in each hidden layer.
    :param out_size: int corresponding to the number of neurons in the output layer.
    """
    hidden_sizes: Sequence[int]
    out_size: int

    @nn.compact
    def __call__(self, x, **kwargs):
        name = kwargs.pop('name', 'fc')
        for e, size in enumerate(self.hidden_sizes, 1):
            x = nn.Dense(size, name=name + str(e))(x)
            # x = nn.silu(x)
            x = nn.PReLU(x)  # this line fails at initialization
        x = nn.Dense(self.out_size, name=name + '_output')(x)
        return x
```
PReLU does not follow the same calling convention as the other activations. In the previous example, it needs to be `x = nn.PReLU()(x)`. The documentation was not obvious to me.
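To make the fix concrete, here is a minimal sketch of the working pattern, instantiating `nn.PReLU()` before applying it (the module and layer names below are illustrative, not from the thread):

```python
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyMLP(nn.Module):  # illustrative name
    hidden_sizes: Sequence[int]
    out_size: int

    @nn.compact
    def __call__(self, x):
        for e, size in enumerate(self.hidden_sizes, 1):
            x = nn.Dense(size, name='fc' + str(e))(x)
            x = nn.PReLU()(x)  # instantiate the Module, then call it on x
        return nn.Dense(self.out_size, name='fc_output')(x)


model = TinyMLP(hidden_sizes=[16, 16], out_size=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
```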
Because PReLU initializes a trainable scalar parameter, it has to be treated as a layer. I've added clarification and example usage to the docs in #3122.
I wanted to gauge interest on adding a PReLU activation. I noticed that flax.linen.activations are simply aliasing jax.nn activation functions, which also don't have a PReLU implementation. To add some background, PReLU is simply Leaky ReLU where the alpha (slope) parameter is trainable rather than fixed. This makes it simple to implement as a Module if desired.
Here's an example implementation from another project of mine.
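To illustrate the fixed-slope vs. trainable-slope distinction in code (just a sketch; `alpha` stands in for the parameter an optimizer would update):

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0])

# Leaky ReLU: the negative slope is a fixed hyperparameter.
y_fixed = jax.nn.leaky_relu(x, negative_slope=0.01)

# PReLU: the same formula, but the slope is a learned parameter.
alpha = jnp.array(0.01)  # in practice this lives in the model's params
y_learned = jnp.where(x >= 0, x, alpha * x)
```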