normalization modification #963

Merged · 4 commits · May 11, 2019

Changes from all commits
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -70,6 +70,8 @@ To release a new version, please update the changelog as followed:
## [Unreleased]

### Added
- Layer
- `InstanceNorm`, `InstanceNorm1d`, `InstanceNorm2d`, `InstanceNorm3d` (PR #963)

### Changed
- remove `tl.layers.initialize_global_variables(sess)` (PR #931)
@@ -82,13 +84,15 @@ To release a new version, please update the changelog as followed:
### Deprecated

### Fixed
- In `BatchNorm`, keep dimensions of mean and variance to suit `channels first` (PR #963)

### Removed

### Security

### Contributors
- @zsdonghao: #931
- @yd-yin: #963


## [2.0.0-alpha] - 2019-05-04
30 changes: 30 additions & 0 deletions docs/modules/layers.rst
Expand Up @@ -67,8 +67,14 @@ Layer list
batch_transformer

BatchNorm
BatchNorm1d
BatchNorm2d
BatchNorm3d
LocalResponseNorm
InstanceNorm
InstanceNorm1d
InstanceNorm2d
InstanceNorm3d
LayerNorm
GroupNorm
SwitchNorm
@@ -364,6 +370,18 @@ Batch Normalization
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: BatchNorm

Batch Normalization 1D
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: BatchNorm1d

Batch Normalization 2D
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: BatchNorm2d

Batch Normalization 3D
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: BatchNorm3d

Local Response Normalization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LocalResponseNorm
@@ -372,6 +390,18 @@ Instance Normalization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: InstanceNorm

Instance Normalization 1D
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: InstanceNorm1d

Instance Normalization 2D
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: InstanceNorm2d

Instance Normalization 3D
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: InstanceNorm3d

Layer Normalization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LayerNorm
245 changes: 196 additions & 49 deletions tensorlayer/layers/normalization.py
@@ -17,6 +17,9 @@
'BatchNorm2d',
'BatchNorm3d',
'InstanceNorm',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
'LayerNorm',
'GroupNorm',
'SwitchNorm',
@@ -259,7 +262,7 @@ def build(self, inputs_shape):
self.moving_var = self._get_weights("moving_var", shape=params_shape, init=self.moving_var_init)

def forward(self, inputs):
mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
if self.is_train:
# update moving_mean and moving_var
self.moving_mean = moving_averages.assign_moving_average(
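
The keepdims change above is the substance of the `channels first` fix noted in the changelog: the reduced mean/variance must keep their singleton axes to broadcast against NCHW input. A minimal sketch of the difference in plain TensorFlow (not the TensorLayer API; shapes are illustrative):

import tensorflow as tf

x = tf.random.normal([2, 3, 4, 4])  # NCHW: 2 samples, 3 channels, 4x4 spatial
axes = [0, 2, 3]                    # reduction axes for channels_first

# keepdims=True -> mean/var have shape (1, 3, 1, 1) and broadcast
# against the channel axis, as intended.
mean, var = tf.nn.moments(x, axes, keepdims=True)
y = (x - mean) / tf.sqrt(var + 1e-5)

# keepdims=False -> mean/var have shape (3,), which broadcasting aligns
# with the *last* (width) axis: a shape error here, or silently wrong
# normalization whenever the channel count happens to equal the width.
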
@@ -388,81 +391,225 @@ def _get_param_shape(self, inputs_shape):


class InstanceNorm(Layer):
"""The :class:`InstanceNorm` class is a for instance normalization.
"""
The :class:`InstanceNorm` is an instance normalization layer for both fully-connected and convolution outputs.
See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

Parameters
-----------
act : activation function.
The activation function of this layer.
epsilon : float
A small float added to the variance to avoid division by zero.
beta_init : initializer or None
The initializer for initializing beta, if None, skip beta.
Usually you should not skip beta unless you know exactly what you are doing.
gamma_init : initializer or None
The initializer for initializing gamma, if None, skip gamma.
When the instance normalization layer is used instead of 'biases', or the next layer is linear, this can be
disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__.
num_features : int
Number of features in the input tensor. Needed to build the layer when using InstanceNorm1d, InstanceNorm2d or
InstanceNorm3d in a dynamic model, but should be left as None when using InstanceNorm. Default None.
data_format : str
One of 'channels_last' (default) or 'channels_first'.
name : None or str
A unique layer name.


Examples
---------
With TensorLayer

>>> net = tl.layers.Input([None, 50, 50, 32], name='input')
>>> net = tl.layers.InstanceNorm()(net)

Notes
-----
The :class:`InstanceNorm` layer accepts 3D/4D/5D input in a static model, but should not be used
in a dynamic model, where layers are built upon class initialization. The argument 'num_features' is therefore
only to be used with the subclasses :class:`InstanceNorm1d`, :class:`InstanceNorm2d` and :class:`InstanceNorm3d`,
all of which are suitable in both static and dynamic models.
"""

def __init__(
self, act=None, epsilon=0.00001, beta_init=tl.initializers.zeros(),
gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002), num_features=None,
data_format='channels_last', name=None
):
super(InstanceNorm, self).__init__(name=name)
self.act = act
self.epsilon = epsilon
self.beta_init = beta_init
self.gamma_init = gamma_init
self.num_features = num_features
self.data_format = data_format

if num_features is not None:
if not isinstance(self, InstanceNorm1d) and not isinstance(self, InstanceNorm2d) and not isinstance(
self, InstanceNorm3d):
raise ValueError(
"Please use InstanceNorm1d or InstanceNorm2d or InstanceNorm3d instead of InstanceNorm "
"if you want to specify 'num_features'."
)
self.build(None)
self._built = True

logging.info(
"InstanceNorm %s: epsilon: %f act: %s" %
"InstanceNorm %s: epsilon: %f act: %s " %
(self.name, epsilon, self.act.__name__ if self.act is not None else 'No Activation')
)

def __repr__(self):
actstr = self.act.__name__ if self.act is not None else 'No Activation'
s = '{classname}(num_features={num_features}, epsilon={epsilon}'
s += ', ' + actstr
if self.name is not None:
s += ', name="{name}"'
s += ')'
return s.format(classname=self.__class__.__name__, **self.__dict__)

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = len(inputs_shape) - 1
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

channels = inputs_shape[axis]
params_shape = [1] * len(inputs_shape)
params_shape[axis] = channels

axes = [i for i in range(len(inputs_shape)) if i != 0 and i != axis]
return params_shape, axes

def build(self, inputs_shape):
params_shape, self.axes = self._get_param_shape(inputs_shape)

self.beta, self.gamma = None, None
if self.beta_init:
self.beta = self._get_weights("beta", shape=params_shape, init=self.beta_init)

if self.gamma_init:
self.gamma = self._get_weights("gamma", shape=params_shape, init=self.gamma_init)

def forward(self, inputs):
mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
if self.act:
outputs = self.act(outputs)
return outputs

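
Before the subclasses, a usage sketch contrasting the two modes described in the Notes above (layer names are those added in this PR; `act` and the shapes are illustrative):

import tensorflow as tf
import tensorlayer as tl

# Static model: the input shape is known, so InstanceNorm infers its
# parameter shapes when the layer is connected.
ni = tl.layers.Input([None, 50, 50, 32])
nn = tl.layers.InstanceNorm(act=tf.nn.relu)(ni)

# Dynamic model: layers are built at __init__ time, before any input is
# seen, so the channel count must be given explicitly -- only the
# 1d/2d/3d subclasses accept num_features.
norm = tl.layers.InstanceNorm2d(num_features=32)

# The base class rejects num_features by design:
# tl.layers.InstanceNorm(num_features=32)  # raises ValueError
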
class InstanceNorm1d(InstanceNorm):
"""The :class:`InstanceNorm1d` applies Instance Normalization over 3D input (a mini-instance of 1D
inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
See more details in :class:`InstanceNorm`.

Examples
---------
With TensorLayer

>>> # in static model, no need to specify num_features
>>> net = tl.layers.Input([None, 50, 32], name='input')
>>> net = tl.layers.InstanceNorm1d()(net)
>>> # in dynamic model, build by specifying num_features
>>> conv = tl.layers.Conv1d(32, 5, 1, in_channels=3)
>>> bn = tl.layers.InstanceNorm1d(num_features=32)

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 2
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 3
params_shape[axis] = channels

axes = [i for i in range(3) if i != 0 and i != axis]
return params_shape, axes

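To make the axis bookkeeping concrete, this is what `_get_param_shape` above resolves to for `InstanceNorm1d` (a worked trace of the code, not new behaviour):

# channels_last, input (N, L, C) = (None, 50, 32):
#   axis = 2, params_shape = [1, 1, 32], axes = [1]
#   -> moments are taken over L only: one mean/var per sample, per channel.
#
# channels_first, input (N, C, L) = (None, 32, 50):
#   axis = 1, params_shape = [1, 32, 1], axes = [2]
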
class InstanceNorm2d(InstanceNorm):
"""The :class:`InstanceNorm2d` applies Instance Normalization over 4D input (a mini-instance of 2D
inputs with additional channel dimension) of shape (N, H, W, C) or (N, C, H, W).
See more details in :class:`InstanceNorm`.

Examples
---------
With TensorLayer

>>> # in static model, no need to specify num_features
>>> net = tl.layers.Input([None, 50, 50, 32], name='input')
>>> net = tl.layers.InstanceNorm2d()(net)
>>> # in dynamic model, build by specifying num_features
>>> conv = tl.layers.Conv2d(32, (5, 5), (1, 1), in_channels=3)
>>> bn = tl.layers.InstanceNorm2d(num_features=32)

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 3
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 4
params_shape[axis] = channels

axes = [i for i in range(4) if i != 0 and i != axis]
return params_shape, axes


class InstanceNorm3d(InstanceNorm):
"""The :class:`InstanceNorm3d` applies Instance Normalization over 5D input (a mini-instance of 3D
inputs with additional channel dimension) with shape (N, D, H, W, C) or (N, C, D, H, W).
See more details in :class:`InstanceNorm`.

Examples
---------
With TensorLayer

>>> # in static model, no need to specify num_features
>>> net = tl.layers.Input([None, 50, 50, 50, 32], name='input')
>>> net = tl.layers.InstanceNorm3d()(net)
>>> # in dynamic model, build by specifying num_features
>>> conv = tl.layers.Conv3d(32, (5, 5, 5), (1, 1, 1), in_channels=3)
>>> bn = tl.layers.InstanceNorm3d(num_features=32)

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 4
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 5
params_shape[axis] = channels

axes = [i for i in range(5) if i != 0 and i != axis]
return params_shape, axes
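
A quick numerical check of the behaviour the three subclasses share, in plain TensorFlow and independent of the TensorLayer wrappers (tolerances are arbitrary; identity beta/gamma matches the default initializers' means):

import numpy as np
import tensorflow as tf

x = tf.random.normal([2, 8, 8, 16])           # (N, H, W, C), channels_last
mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
y = (x - mean) / tf.sqrt(var + 1e-5)          # beta=0, gamma=1 case

# Every (sample, channel) slice is now ~zero-mean, ~unit-variance:
m = tf.reduce_mean(y, axis=[1, 2])            # shape (2, 16)
v = tf.math.reduce_variance(y, axis=[1, 2])   # shape (2, 16)
assert np.allclose(m.numpy(), 0.0, atol=1e-4)
assert np.allclose(v.numpy(), 1.0, atol=1e-2)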


# FIXME : not sure about the correctness, need testing