diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index bf172f2c1d..340b5d03a5 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -160,6 +160,7 @@ def _normalize(
   use_scale: bool,
   bias_init: Initializer,
   scale_init: Initializer,
+  force_float32_reductions: bool = True
 ):
   """Normalizes the input of a normalization layer and optionally applies a
   learned scale and bias.
@@ -179,6 +180,9 @@ def _normalize(
     use_scale: If true, scale the output.
     bias_init: Initialization function for the bias term.
     scale_init: Initialization function for the scaling function.
+    force_float32_reductions: If false, the scale and bias parameters use the
+      param_dtype. Otherwise, they will have at least float32 precision due to
+      the mean and var being promoted to float32.

   Returns:
     The normalized input.
@@ -200,6 +204,8 @@ def _normalize(
     scale = mdl.param(
       'scale', scale_init, reduced_feature_shape, param_dtype
     ).reshape(feature_shape)
+    if not force_float32_reductions:
+      scale = jnp.asarray(scale, param_dtype)
     mul *= scale
     args.append(scale)
   y *= mul
@@ -207,6 +213,8 @@ def _normalize(
     bias = mdl.param(
       'bias', bias_init, reduced_feature_shape, param_dtype
     ).reshape(feature_shape)
+    if not force_float32_reductions:
+      bias = jnp.asarray(bias, param_dtype)
     y += bias
     args.append(bias)
   dtype = dtypes.canonicalize_dtype(*args, dtype=dtype)
@@ -346,7 +354,8 @@ def __call__(
       'batch_stats',
       'mean',
       lambda s: jnp.zeros(
-        s, jnp.float32 if self.force_float32_reductions else x.dtype
+        s,
+        jnp.float32 if self.force_float32_reductions else self.param_dtype,
       ),
       feature_shape,
     )
@@ -354,13 +363,23 @@ def __call__(
       'batch_stats',
       'var',
       lambda s: jnp.ones(
-        s, jnp.float32 if self.force_float32_reductions else x.dtype
+        s,
+        jnp.float32 if self.force_float32_reductions else self.param_dtype,
       ),
       feature_shape,
     )

     if use_running_average:
-      mean, var = ra_mean.value, ra_var.value
+      mean = (
+        ra_mean.value
+        if self.force_float32_reductions
+        else jnp.asarray(ra_mean.value, self.param_dtype)
+      )
+      var = (
+        ra_var.value
+        if self.force_float32_reductions
+        else jnp.asarray(ra_var.value, self.param_dtype)
+      )
     else:
       mean, var = _compute_stats(
         x,
@@ -393,6 +412,7 @@ def __call__(
       self.use_scale,
       self.bias_init,
       self.scale_init,
+      self.force_float32_reductions,
     )

@@ -509,6 +529,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
       self.use_scale,
       self.bias_init,
       self.scale_init,
+      self.force_float32_reductions,
     )

@@ -609,6 +630,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
       self.use_scale,
       initializers.zeros,
       self.scale_init,
+      self.force_float32_reductions,
     )

@@ -788,6 +810,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
       self.use_scale,
       self.bias_init,
       self.scale_init,
+      self.force_float32_reductions,
     )

@@ -912,6 +935,7 @@ def __call__(self, x, *, mask: jax.Array | None = None):
       self.use_scale,
       self.bias_init,
       self.scale_init,
+      self.force_float32_reductions,
     )
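
Below is a minimal usage sketch (not part of the patch) illustrating the behavioral change: with `force_float32_reductions=False`, `BatchNorm` now creates its running statistics with `param_dtype` rather than the input dtype, and `_normalize` casts the scale/bias parameters back to `param_dtype`. The specific shapes and dtype choices are assumptions made for illustration; only `force_float32_reductions`, `param_dtype`, and the `batch_stats` collection come from the patch itself.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# BatchNorm configured so that running statistics stay in param_dtype
# (bfloat16 here) instead of being promoted to float32.
bn = nn.BatchNorm(
    use_running_average=False,
    force_float32_reductions=False,
    param_dtype=jnp.bfloat16,
    dtype=jnp.bfloat16,
)

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
variables = bn.init(jax.random.key(0), x)

# With this patch, the 'batch_stats' collection is created with param_dtype
# when force_float32_reductions=False (previously it followed x.dtype).
print(variables['batch_stats']['mean'].dtype)  # expected: bfloat16
print(variables['batch_stats']['var'].dtype)   # expected: bfloat16

# Normal forward pass, collecting the updated running statistics.
y, updates = bn.apply(variables, x, mutable=['batch_stats'])
print(y.dtype)  # bfloat16
```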