diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 44165238a7..61075733f5 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -69,6 +69,6 @@ def __call__(self, inputs: Array) -> Array: """ negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init) + lambda k: jnp.asarray(self.negative_slope_init, inputs.dtype) ) return jnp.where(inputs >= 0, inputs, negative_slope * inputs)