Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False #4314

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def _normalize(
use_scale: bool,
bias_init: Initializer,
scale_init: Initializer,
force_float32_reductions: bool = True
):
"""Normalizes the input of a normalization layer and optionally applies a learned scale and bias.

Expand All @@ -179,6 +180,9 @@ def _normalize(
use_scale: If true, scale the output.
bias_init: Initialization function for the bias term.
scale_init: Initialization function for the scaling function.
force_float32_reductions: If false, the scale and bias parameters use the
param_dtype. Otherwise, they will have at least float32 precision due to
the mean and var being promoted to float32.

Returns:
The normalized input.
Expand All @@ -200,13 +204,17 @@ def _normalize(
scale = mdl.param(
'scale', scale_init, reduced_feature_shape, param_dtype
).reshape(feature_shape)
if not force_float32_reductions:
scale = jnp.asarray(scale, param_dtype)
mul *= scale
args.append(scale)
y *= mul
if use_bias:
bias = mdl.param(
'bias', bias_init, reduced_feature_shape, param_dtype
).reshape(feature_shape)
if not force_float32_reductions:
bias = jnp.asarray(bias, param_dtype)
y += bias
args.append(bias)
dtype = dtypes.canonicalize_dtype(*args, dtype=dtype)
Expand Down Expand Up @@ -346,21 +354,32 @@ def __call__(
'batch_stats',
'mean',
lambda s: jnp.zeros(
s, jnp.float32 if self.force_float32_reductions else x.dtype
s,
jnp.float32 if self.force_float32_reductions else self.param_dtype,
),
feature_shape,
)
ra_var = self.variable(
'batch_stats',
'var',
lambda s: jnp.ones(
s, jnp.float32 if self.force_float32_reductions else x.dtype
s,
jnp.float32 if self.force_float32_reductions else self.param_dtype,
),
feature_shape,
)

if use_running_average:
mean, var = ra_mean.value, ra_var.value
mean = (
ra_mean.value
if self.force_float32_reductions
else jnp.asarray(ra_mean.value, self.param_dtype)
)
var = (
ra_var.value
if self.force_float32_reductions
else jnp.asarray(ra_var.value, self.param_dtype)
)
else:
mean, var = _compute_stats(
x,
Expand Down Expand Up @@ -393,6 +412,7 @@ def __call__(
self.use_scale,
self.bias_init,
self.scale_init,
self.force_float32_reductions,
)


Expand Down Expand Up @@ -509,6 +529,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
self.use_scale,
self.bias_init,
self.scale_init,
self.force_float32_reductions,
)


Expand Down Expand Up @@ -609,6 +630,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
self.use_scale,
initializers.zeros,
self.scale_init,
self.force_float32_reductions,
)


Expand Down Expand Up @@ -788,6 +810,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
self.use_scale,
self.bias_init,
self.scale_init,
self.force_float32_reductions,
)


Expand Down Expand Up @@ -912,6 +935,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
self.use_scale,
self.bias_init,
self.scale_init,
self.force_float32_reductions,
)


Expand Down
Loading