Adds standardize initializer #2717

Merged
merged 3 commits on Dec 15, 2022
Changes from 2 commits
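
For context on what this PR exposes: `standardize` comes from `jax.nn` and normalizes its input to zero mean and unit variance along an axis. A minimal usage sketch, assuming the new `flax.linen.standardize` alias behaves exactly like `jax.nn.standardize` (which the diff below suggests, since it is a plain re-import):

    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])

    # Standardize along the last axis: subtract the mean and divide by the
    # standard deviation, giving (approximately) zero mean and unit variance.
    y = nn.standardize(x, axis=-1)

    print(y.mean(axis=-1))  # ~[0. 0.]
    print(y.std(axis=-1))   # ~[1. 1.]
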
51 changes: 42 additions & 9 deletions docs/api_reference/flax.linen.rst
@@ -47,8 +47,6 @@ Profiling
----------------------

.. automodule:: flax.linen
.. currentmodule:: flax.linen

.. autosummary::
:toctree: _autosummary

@@ -61,8 +59,6 @@ Inspection
----------------------

.. automodule:: flax.linen
.. currentmodule:: flax.linen

.. autosummary::
:toctree: _autosummary

@@ -73,8 +69,6 @@ Transformations
----------------------

.. automodule:: flax.linen.transforms
.. currentmodule:: flax.linen

.. autosummary::
:toctree: _autosummary

@@ -95,8 +89,6 @@ Metadata
----------------------

.. automodule:: flax.linen.meta
.. currentmodule:: flax.linen

.. autosummary::
:toctree: _autosummary

@@ -146,25 +138,66 @@ Pooling
pool


Initializers
------------------------

.. automodule:: flax.linen.initializers
.. autosummary::
:toctree: _autosummary

constant
delta_orthogonal
glorot_normal
glorot_uniform
he_normal
he_uniform
kaiming_normal
kaiming_uniform
lecun_normal
lecun_uniform
normal
ones
orthogonal
standardize
uniform
variance_scaling
xavier_normal
xavier_uniform
zeros

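As an aside (editor's illustration, not part of the diff): these names are initializer factories, typically passed to a layer's `kernel_init` or `bias_init` argument. A minimal sketch using the standard Flax Linen API:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Initializer factories such as lecun_normal() return a function of
    # (key, shape, dtype) that the layer calls when materializing parameters.
    layer = nn.Dense(features=4, kernel_init=nn.initializers.lecun_normal())
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))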

Activation functions
------------------------

.. automodule:: flax.linen.activation
.. autosummary::
:toctree: _autosummary

PReLU
celu
elu
gelu
glu
hard_sigmoid
hard_silu
hard_swish
hard_tanh
leaky_relu
log_sigmoid
log_softmax
logsumexp
one_hot
relu
relu6
selu
sigmoid
silu
soft_sign
softmax
softplus
standardize
swish
tanh


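Not part of the diff, but for orientation: the functions in this list are applied to intermediate activations inside a module. A hedged sketch of typical use (the MLP class here is hypothetical):

    import flax.linen as nn

    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(8)(x)
            x = nn.relu(x)  # any entry above slots in here, e.g. nn.standardize
            return nn.Dense(1)(x)
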
Combinators
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -41,6 +41,7 @@
soft_sign as soft_sign,
softmax as softmax,
softplus as softplus,
standardize as standardize,
swish as swish,
tanh as tanh
)
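
The `standardize as standardize` form mirrors the surrounding lines: redundant-looking, but it marks the name as an explicit re-export (the convention type checkers such as mypy honor under strict re-export rules). After this one-line change the function is reachable from the top-level namespace; a small sketch:

    import jax.numpy as jnp
    from flax import linen as nn

    x = jnp.arange(6.0).reshape(2, 3)
    y = nn.standardize(x)  # resolves via the new re-export above
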
1 change: 1 addition & 0 deletions flax/linen/activation.py
@@ -44,6 +44,7 @@
from jax.nn import soft_sign
from jax.nn import softmax
from jax.nn import softplus
from jax.nn import standardize
from jax.nn import swish
import jax.numpy as jnp
from jax.numpy import tanh
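
Since the module re-imports the function wholesale from `jax.nn`, the Flax name is the same object rather than a wrapper; a quick check (editor's illustration, not from the diff):

    import jax
    import flax.linen as nn

    # Both names point at the identical function object after this import.
    assert nn.activation.standardize is jax.nn.standardize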