diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index 62a0ae021e..b134c8b86a 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -402,7 +402,7 @@ class LayerNorm(Module):
     >>> np.testing.assert_allclose(y, y2)

     >>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
-    >>> y2 = nn.InstanceNorm(channel_axes=-1).apply(variables, x)
+    >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
     >>> np.testing.assert_allclose(y, y2)

   Attributes:
@@ -610,7 +610,7 @@ class GroupNorm(Module):
     >>> np.testing.assert_allclose(y, y2)

     >>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
-    >>> y2 = nn.InstanceNorm(channel_axes=-1).apply(variables, x)
+    >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
     >>> np.testing.assert_allclose(y, y2)

   Attributes:
@@ -784,7 +784,7 @@ class InstanceNorm(Module):
     >>> y = layer.apply(variables, x)

     >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
-    >>> # non-channel axes and using the channel_axes as the feature_axes in LayerNorm
+    >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
     >>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x)
     >>> np.testing.assert_allclose(y, y2, atol=1e-7)
     >>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x)
@@ -800,9 +800,9 @@ class InstanceNorm(Module):
       by the next layer.
     bias_init: Initializer for bias, by default, zero.
     scale_init: Initializer for scale, by default, one.
-    channel_axes: Axes for channel. This is considered the feature axes for the
-      learned bias and scaling parameter. All other axes except the batch axes
-      (which is assumed to be the leading axis) will be reduced.
+    feature_axes: Axes for features. The learned bias and scaling parameters will
+      be in the shape defined by the feature axes. All other axes except the batch
+      axes (which is assumed to be the leading axis) will be reduced.
     axis_name: the axis name used to combine batch statistics from multiple
       devices. See ``jax.pmap`` for a description of axis names (default: None).
       This is only needed if the model is subdivided across devices, i.e. the
@@ -826,7 +826,7 @@ class InstanceNorm(Module):
   use_scale: bool = True
   bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
   scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
-  channel_axes: Axes = -1
+  feature_axes: Axes = -1
   axis_name: Optional[str] = None
   axis_index_groups: Any = None
   use_fast_variance: bool = True
@@ -843,11 +843,11 @@ def __call__(self, x, mask=None):
     Returns:
       Normalized inputs (the same shape as inputs).
""" - channel_axes = _canonicalize_axes(x.ndim, self.channel_axes) - if 0 in channel_axes: + feature_axes = _canonicalize_axes(x.ndim, self.feature_axes) + if 0 in feature_axes: raise ValueError('The channel axes cannot include the leading dimension ' 'as this is assumed to be the batch axis.') - reduction_axes = [i for i in range(1, x.ndim) if i not in channel_axes] + reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes] mean, var = _compute_stats( x, @@ -865,7 +865,7 @@ def __call__(self, x, mask=None): mean, var, reduction_axes, - channel_axes, + feature_axes, self.dtype, self.param_dtype, self.epsilon, diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 9cc9b0404c..9af3290189 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -431,12 +431,12 @@ def __call__(self, x): np.testing.assert_allclose(y1, y2, rtol=1e-4) @parameterized.parameters( - {'channel_axes': -1}, - {'channel_axes': (1, 2)}, - {'channel_axes': (1, 2, 3)}, - {'channel_axes': -1, 'use_fast_variance': False}, + {'feature_axes': -1}, + {'feature_axes': (1, 2)}, + {'feature_axes': (1, 2, 3)}, + {'feature_axes': -1, 'use_fast_variance': False}, ) - def test_instance_norm(self, channel_axes, use_fast_variance=True): + def test_instance_norm(self, feature_axes, use_fast_variance=True): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 @@ -447,21 +447,21 @@ def test_instance_norm(self, channel_axes, use_fast_variance=True): use_bias=False, use_scale=False, epsilon=e, - channel_axes=channel_axes, + feature_axes=feature_axes, use_fast_variance=use_fast_variance, ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) - canonicalized_channel_axes = [ + canonicalized_feature_axes = [ i if i >= 0 else (x.ndim + i) for i in ( - channel_axes if isinstance(channel_axes, tuple) else (channel_axes,) + feature_axes if isinstance(feature_axes, tuple) else (feature_axes,) ) ] reduction_axes = [ - i for i in range(1, x.ndim) if i not in canonicalized_channel_axes + i for i in range(1, x.ndim) if i not in canonicalized_feature_axes ] y_one_liner = ( x - x.mean(axis=reduction_axes, keepdims=True) @@ -470,29 +470,29 @@ def test_instance_norm(self, channel_axes, use_fast_variance=True): np.testing.assert_allclose(y_one_liner, y, atol=1e-6) @parameterized.parameters( - {'channel_axes': 0}, - {'channel_axes': -4}, - {'channel_axes': (0, 3)}, - {'channel_axes': (2, -4)}, + {'feature_axes': 0}, + {'feature_axes': -4}, + {'feature_axes': (0, 3)}, + {'feature_axes': (2, -4)}, ) - def test_instance_norm_raise_error(self, channel_axes): + def test_instance_norm_raise_error(self, feature_axes): with self.assertRaisesRegex( ValueError, 'The channel axes cannot include the leading dimension ' 'as this is assumed to be the batch axis.', ): x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) - layer = nn.InstanceNorm(channel_axes=channel_axes) + layer = nn.InstanceNorm(feature_axes=feature_axes) _ = layer.init(jax.random.key(1), x) @parameterized.parameters( { 'layer1': nn.LayerNorm(feature_axes=(1, 2)), - 'layer2': nn.InstanceNorm(channel_axes=(1, 2)), + 'layer2': nn.InstanceNorm(feature_axes=(1, 2)), }, { 'layer1': nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1), - 'layer2': nn.InstanceNorm(channel_axes=-1), + 'layer2': nn.InstanceNorm(feature_axes=-1), }, { 'layer1': nn.LayerNorm( @@ -502,7 +502,7 @@ def test_instance_norm_raise_error(self, channel_axes): scale_init=nn.initializers.uniform(), ), 'layer2': 
-              channel_axes=(1, -1),
+              feature_axes=(1, -1),
               bias_init=nn.initializers.uniform(),
               scale_init=nn.initializers.uniform(),
           ),
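Note: the rename is purely mechanical; parameter shapes, parameter names, and the computed statistics are unchanged, only the constructor argument is now `feature_axes` instead of `channel_axes`. A minimal sketch of the updated call site, assuming a Flax version that already includes this rename and a 4D NHWC input (it mirrors the updated doctest in the InstanceNorm/LayerNorm docstrings above):

    import jax
    import numpy as np
    import flax.linen as nn

    x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))  # (batch, H, W, C)

    # InstanceNorm now takes `feature_axes` (formerly `channel_axes`).
    layer = nn.InstanceNorm(feature_axes=-1)
    variables = layer.init(jax.random.key(1), x)
    y = layer.apply(variables, x)

    # Per the updated docstring, this is equivalent to a LayerNorm that reduces
    # the spatial axes and keeps the last axis as the feature axis.
    y2 = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
    np.testing.assert_allclose(y, y2, atol=1e-6)

Downstream code that still passes `channel_axes=` needs the one-line rename shown above.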