renamed channel_axes to feature_axes in InstanceNorm #3667

Merged: 1 commit, Jan 30, 2024
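This PR is a straight rename of InstanceNorm's `channel_axes` argument to `feature_axes`, with the docstrings and tests updated to match. As a rough before/after sketch of a call site (the input shape and key values below are illustrative, not taken from this PR):

```python
import jax
import flax.linen as nn

x = jax.random.normal(jax.random.key(0), (2, 8, 8, 4))  # (batch, height, width, channels)

# Before this change the per-feature parameter axes were configured as
#   nn.InstanceNorm(channel_axes=-1)
# After this change the same configuration is written as:
layer = nn.InstanceNorm(feature_axes=-1)
variables = layer.init(jax.random.key(1), x)
y = layer.apply(variables, x)  # normalizes over the spatial axes, per channel and per example
```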
22 changes: 11 additions & 11 deletions flax/linen/normalization.py
@@ -402,7 +402,7 @@ class LayerNorm(Module):
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
- >>> y2 = nn.InstanceNorm(channel_axes=-1).apply(variables, x)
+ >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

Attributes:
@@ -610,7 +610,7 @@ class GroupNorm(Module):
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
- >>> y2 = nn.InstanceNorm(channel_axes=-1).apply(variables, x)
+ >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

Attributes:
@@ -784,7 +784,7 @@ class InstanceNorm(Module):
>>> y = layer.apply(variables, x)

>>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
- >>> # non-channel axes and using the channel_axes as the feature_axes in LayerNorm
+ >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
>>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x)
@@ -800,9 +800,9 @@ class InstanceNorm(Module):
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
- channel_axes: Axes for channel. This is considered the feature axes for the
- learned bias and scaling parameter. All other axes except the batch axes
- (which is assumed to be the leading axis) will be reduced.
+ feature_axes: Axes for features. The learned bias and scaling parameters will
+ be in the shape defined by the feature axes. All other axes except the batch
+ axes (which is assumed to be the leading axis) will be reduced.
axis_name: the axis name used to combine batch statistics from multiple
devices. See ``jax.pmap`` for a description of axis names (default: None).
This is only needed if the model is subdivided across devices, i.e. the
@@ -826,7 +826,7 @@ class InstanceNorm(Module):
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
- channel_axes: Axes = -1
+ feature_axes: Axes = -1
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
@@ -843,11 +843,11 @@ def __call__(self, x, mask=None):
Returns:
Normalized inputs (the same shape as inputs).
"""
- channel_axes = _canonicalize_axes(x.ndim, self.channel_axes)
- if 0 in channel_axes:
+ feature_axes = _canonicalize_axes(x.ndim, self.feature_axes)
+ if 0 in feature_axes:
raise ValueError('The channel axes cannot include the leading dimension '
'as this is assumed to be the batch axis.')
- reduction_axes = [i for i in range(1, x.ndim) if i not in channel_axes]
+ reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes]

mean, var = _compute_stats(
x,
@@ -865,7 +865,7 @@ def __call__(self, x, mask=None):
mean,
var,
reduction_axes,
- channel_axes,
+ feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
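The doctest changes above lean on an equivalence stated in these docstrings: `InstanceNorm(feature_axes=-1)` matches a `LayerNorm` that reduces only the spatial axes with per-channel parameters, and a `GroupNorm` with one channel per group. A minimal standalone version of that check, adapted from the doctests in this diff (the input shape and tolerance are illustrative):

```python
import jax
import numpy as np
import flax.linen as nn

x = jax.random.normal(jax.random.key(0), (2, 8, 8, 4))  # NHWC input

layer = nn.InstanceNorm(feature_axes=-1)
variables = layer.init(jax.random.key(1), x)
y = layer.apply(variables, x)

# LayerNorm reducing only the spatial axes, with per-channel scale/bias,
# computes the same statistics as InstanceNorm(feature_axes=-1).
y2 = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
np.testing.assert_allclose(y, y2, atol=1e-6)

# GroupNorm with one channel per group is another equivalent formulation.
y3 = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
np.testing.assert_allclose(y, y3, atol=1e-6)
```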
36 changes: 18 additions & 18 deletions tests/linen/linen_test.py
@@ -431,12 +431,12 @@ def __call__(self, x):
np.testing.assert_allclose(y1, y2, rtol=1e-4)

@parameterized.parameters(
- {'channel_axes': -1},
- {'channel_axes': (1, 2)},
- {'channel_axes': (1, 2, 3)},
- {'channel_axes': -1, 'use_fast_variance': False},
+ {'feature_axes': -1},
+ {'feature_axes': (1, 2)},
+ {'feature_axes': (1, 2, 3)},
+ {'feature_axes': -1, 'use_fast_variance': False},
)
- def test_instance_norm(self, channel_axes, use_fast_variance=True):
+ def test_instance_norm(self, feature_axes, use_fast_variance=True):
rng = random.key(0)
key1, key2 = random.split(rng)
e = 1e-5
@@ -447,21 +447,21 @@ def test_instance_norm(self, channel_axes, use_fast_variance=True):
use_bias=False,
use_scale=False,
epsilon=e,
- channel_axes=channel_axes,
+ feature_axes=feature_axes,
use_fast_variance=use_fast_variance,
)
y, _ = model_cls.init_with_output(key2, x)
self.assertEqual(x.dtype, y.dtype)
self.assertEqual(x.shape, y.shape)

- canonicalized_channel_axes = [
+ canonicalized_feature_axes = [
i if i >= 0 else (x.ndim + i)
for i in (
- channel_axes if isinstance(channel_axes, tuple) else (channel_axes,)
+ feature_axes if isinstance(feature_axes, tuple) else (feature_axes,)
)
]
reduction_axes = [
- i for i in range(1, x.ndim) if i not in canonicalized_channel_axes
+ i for i in range(1, x.ndim) if i not in canonicalized_feature_axes
]
y_one_liner = (
x - x.mean(axis=reduction_axes, keepdims=True)
@@ -470,29 +470,29 @@ def test_instance_norm_raise_error(self, channel_axes):
np.testing.assert_allclose(y_one_liner, y, atol=1e-6)

@parameterized.parameters(
- {'channel_axes': 0},
- {'channel_axes': -4},
- {'channel_axes': (0, 3)},
- {'channel_axes': (2, -4)},
+ {'feature_axes': 0},
+ {'feature_axes': -4},
+ {'feature_axes': (0, 3)},
+ {'feature_axes': (2, -4)},
)
- def test_instance_norm_raise_error(self, channel_axes):
+ def test_instance_norm_raise_error(self, feature_axes):
with self.assertRaisesRegex(
ValueError,
'The channel axes cannot include the leading dimension '
'as this is assumed to be the batch axis.',
):
x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
- layer = nn.InstanceNorm(channel_axes=channel_axes)
+ layer = nn.InstanceNorm(feature_axes=feature_axes)
_ = layer.init(jax.random.key(1), x)

@parameterized.parameters(
{
'layer1': nn.LayerNorm(feature_axes=(1, 2)),
- 'layer2': nn.InstanceNorm(channel_axes=(1, 2)),
+ 'layer2': nn.InstanceNorm(feature_axes=(1, 2)),
},
{
'layer1': nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1),
- 'layer2': nn.InstanceNorm(channel_axes=-1),
+ 'layer2': nn.InstanceNorm(feature_axes=-1),
},
{
'layer1': nn.LayerNorm(
@@ -502,7 +502,7 @@ def test_instance_norm_raise_error(self, channel_axes):
scale_init=nn.initializers.uniform(),
),
'layer2': nn.InstanceNorm(
- channel_axes=(1, -1),
+ feature_axes=(1, -1),
bias_init=nn.initializers.uniform(),
scale_init=nn.initializers.uniform(),
),
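The updated test checks `InstanceNorm` against a hand-written normalization that reduces every non-batch, non-feature axis. A condensed sketch of that check (the input shape is illustrative, and the `jnp.sqrt(var + epsilon)` denominator is reconstructed from context because the diff truncates the one-liner):

```python
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn

e = 1e-5
feature_axes = (1, 2)
x = jax.random.normal(jax.random.key(0), (2, 5, 4, 4, 5))  # illustrative shape

model = nn.InstanceNorm(
    use_bias=False, use_scale=False, epsilon=e, feature_axes=feature_axes
)
y, _ = model.init_with_output(jax.random.key(1), x)

# Canonicalize negative axes, then reduce every non-batch, non-feature axis.
axes = feature_axes if isinstance(feature_axes, tuple) else (feature_axes,)
feat = [a if a >= 0 else x.ndim + a for a in axes]
reduction_axes = tuple(i for i in range(1, x.ndim) if i not in feat)

y_ref = (x - x.mean(axis=reduction_axes, keepdims=True)) / jnp.sqrt(
    x.var(axis=reduction_axes, keepdims=True) + e
)
np.testing.assert_allclose(y_ref, y, atol=1e-6)

# Note: feature_axes must not include axis 0; InstanceNorm raises a ValueError
# because the leading axis is assumed to be the batch axis.
```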