-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added RMS normalization layer #2881
Conversation
Codecov Report
@@ Coverage Diff @@
## main #2881 +/- ##
==========================================
+ Coverage 81.45% 81.47% +0.02%
==========================================
Files 55 55
Lines 5779 5798 +19
==========================================
+ Hits 4707 4724 +17
- Misses 1072 1074 +2
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
@@ -335,6 +343,70 @@ def __call__(self, x): | |||
self.bias_init, self.scale_init) | |||
|
|||
|
|||
class RMSNorm(Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add an example of how to use the layer? I think we should starting doing this for every layer, like @cgarciae does in his RNN PR: https://github.com/google/flax/pull/2604/files#r1107264719.
4a05ad4
to
42ba933
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thanks!! -- I second Marc's ask to add a small usage example in the docstring.
@marcvanzee @levskaya I added a docstring, let me know if this works! |
@chiamp - I added a exception for the deprecation warning, your tests all seem to pass now! |
flax/linen/normalization.py
Outdated
epsilon: float = 1e-6 | ||
dtype: Optional[Dtype] = None | ||
param_dtype: Dtype = jnp.float32 | ||
use_bias: bool = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry I just noticed this - we probably don't want use_bias
and bias_init
here since we're never adjusting the offset?
Resolves #2849.
Added an optional argument
use_mean
in the_compute_stats
function inflax/linen/normalization.py
, which will compute the mean and variance if set toTrue
, and will set the mean to 0 and compute the variance without subtracting the mean if set toFalse
. The latter mode is useful as square rooting this "variance" value (which is done in the_normalize
function) will give you the RMS.