Skip to content

Commit

Permalink
fix BatchNorm for fp16 (#36376) (#36691)
Browse files Browse the repository at this point in the history
* fix BatchNorm for fp16
  • Loading branch information
GuoxiaWang authored Oct 27, 2021
1 parent 64643d5 commit 417b22d
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,32 +564,42 @@ def __init__(self,
self._use_global_stats = use_global_stats

if get_default_dtype() == 'float16':
set_default_dtype('float32')
self._dtype = 'float32'
else:
self._dtype = get_default_dtype()

param_shape = [num_features]

# create parameter
if weight_attr == False:
self.weight = self.create_parameter(
attr=None, shape=param_shape, default_initializer=Constant(1.0))
attr=None,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
self.weight.stop_gradient = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.

if bias_attr == False:
self.bias = self.create_parameter(
attr=None,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(0.0),
is_bias=True)
self.bias.stop_gradient = True
else:
self.bias = self.create_parameter(
attr=self._bias_attr, shape=param_shape, is_bias=True)
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.

moving_mean_name = None
Expand All @@ -600,6 +610,7 @@ def __init__(self,
moving_variance_name = name + "_variance"

self._mean = self.create_parameter(
dtype=self._dtype,
attr=ParamAttr(
name=moving_mean_name,
initializer=Constant(0.0),
Expand All @@ -609,6 +620,7 @@ def __init__(self,
self._mean.stop_gradient = True

self._variance = self.create_parameter(
dtype=self._dtype,
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
Expand Down

0 comments on commit 417b22d

Please sign in to comment.