
[Bug Fix] Fix GroupNorm Implementation #18199

Merged 2 commits on Apr 30, 2020
11 changes: 7 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -820,26 +820,29 @@ class GroupNorm(HybridBlock):
     """
     def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                  beta_initializer='zeros', gamma_initializer='ones',
-                 prefix=None, params=None):
+                 in_channels=0, prefix=None, params=None):
         super(GroupNorm, self).__init__(prefix=prefix, params=params)
         self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
         self._num_groups = num_groups
         self._epsilon = epsilon
         self._center = center
         self._scale = scale
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(num_groups,), init=gamma_initializer,
+                                     shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(num_groups,), init=beta_initializer,
+                                    shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)

     def hybrid_forward(self, F, data, gamma, beta):
         norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
Reviewer comment (project member) on the line above:
I just realized one quick issue: should we consider moving GroupNorm to npx? Currently, the layer won't be usable in the new numpy interface. @zhreshold

         return norm_data

     def __repr__(self):
-        s = '{name}({content})'
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))
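
To make the user-facing effect of this change concrete, here is a minimal usage sketch (not part of the PR; it assumes MXNet 1.x with the NDArray frontend, and the input shape `(2, 32, 8, 8)` is an arbitrary example):

```python
# Minimal sketch, not PR code: exercising the fixed Gluon GroupNorm layer.
import mxnet as mx
from mxnet.gluon import nn

net = nn.GroupNorm(num_groups=4, in_channels=32)  # in_channels is the new argument
net.initialize()

x = mx.nd.random.uniform(shape=(2, 32, 8, 8))     # (N, C, H, W)
y = net(x)

# With this fix the affine parameters are per-channel, not per-group.
print(net.gamma.shape)  # (32,) instead of the old (4,)
print(net.beta.shape)   # (32,)
print(y.shape)          # (2, 32, 8, 8)
```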
25 changes: 13 additions & 12 deletions src/operator/nn/group_norm-inl.h
@@ -136,16 +136,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
   TBlob data_grp = data.reshape(temp_data_shape);
   const TBlob& mean_grp = mean.reshape(moments_shape);
   const TBlob& std_grp = std.reshape(moments_shape);
-  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+  const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);

   // Calculate data = data - mean
   BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                      {data_grp, mean_grp},
-                                                     {kWriteTo}, {output});
+                                                     {kWriteTo}, {output_grp});

   // Calculate std
   const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
-  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
         s, std_, req[0], workspace, centered_out);
@@ -157,11 +157,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,

   // Calculate data = data / std
   BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
-                                               {output, std_grp},
-                                               {kWriteTo}, {output});
+                                               {output_grp, std_grp},
+                                               {kWriteTo}, {output_grp});

-  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  const TBlob& output = outputs[groupnorm::kOut];
+  mxnet::TShape new_param_shape(data_shape.ndim(), 1);
+  new_param_shape[1] = data_shape[1];

   const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
   const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
@@ -215,8 +216,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,

   Stream<xpu> *s = ctx.get_stream<xpu>();
   // Reshape gamma to be broadcastable
-  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  mxnet::TShape new_param_shape(dshape.ndim(), 1);
+  new_param_shape[1] = dshape[1];

   const TBlob& gamma = inputs[2].reshape(new_param_shape);

@@ -233,7 +234,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   // Prepare the necessary shapes for reduction
   mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
   BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
-  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+  BroadcastReduceShapeCompact(dshape, gamma.shape_,
                               &red_exclude_src_shape, &red_exclude_dst_shape);

   int N = red_src_shape.Size() / red_dst_shape.Size();
@@ -308,8 +309,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[0] != kNullOp) {
     const TBlob output_ = outputs[0].reshape(data_.shape_);
     BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                     {ograd, gamma},
-                                                     {kWriteTo}, {ograd_mult});
+                                                     {inputs[0], gamma},
+                                                     {kWriteTo}, {ograd_mult.reshape(data.shape_)});
     BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                      {ograd_mult, std_},
                                                      {kWriteTo}, {ograd_mult});
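
For reference, the computation that GroupNormCompute performs can be summarized in a NumPy sketch (my own illustration, not PR code): the input is normalized per (sample, group) block, and gamma/beta are then broadcast per channel, i.e. with shape (1, C, 1, 1), which is exactly the part this fix changes.

```python
# Hedged NumPy reference for the GroupNorm forward pass (illustration only).
import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    n, c, h, w = x.shape
    g = num_groups
    # Normalize each (sample, group) block, as the reshaped kernel tensors do.
    xg = x.reshape(n, g, c // g, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    std = np.sqrt(xg.var(axis=(2, 3, 4), keepdims=True) + eps)
    x_hat = ((xg - mean) / std).reshape(n, c, h, w)
    # The fix: gamma and beta have shape (C,) and are applied per channel.
    return x_hat * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)
```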
4 changes: 2 additions & 2 deletions src/operator/nn/group_norm.cc
@@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
     return false;
   }

-  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
-  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));

   out_shape->clear();
   out_shape->push_back(dshape);
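
A quick way to observe the corrected shape inference (my own check, assuming the legacy symbol API exposes the operator as mx.sym.GroupNorm): for an input of shape (N, C, H, W), gamma and beta should now be inferred as (C,) rather than (num_groups,).

```python
# Hedged shape-inference check (assumes MXNet 1.x symbol API; not PR code).
import mxnet as mx

data = mx.sym.Variable('data')
out = mx.sym.GroupNorm(data, num_groups=4)
arg_shapes, _, _ = out.infer_shape(data=(2, 32, 8, 8))
# arg_shapes covers (data, gamma, beta); with this fix gamma and beta come out
# as (32,) -- one scale/shift per channel -- instead of (4,).
print(arg_shapes)
```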
16 changes: 8 additions & 8 deletions tests/python/unittest/test_operator.py
@@ -1960,28 +1960,28 @@ def x_hat_helper(x, num_groups, eps):
         return x_hat, mean, std

     def np_groupnorm(data, gamma, beta, num_groups, eps):
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
-        out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
-        return out.reshape(dshape), mean, std
+        out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+        return out, mean, std

     def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
         new_shape = x_hat.shape
         dshape = data.shape
         dtype = data.dtype
         new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         acc_type = acc_types[str(dtype)]
         ograd = ograd.reshape(new_shape)
         data = data.reshape(new_shape)
         gamma = gamma.reshape(new_param_shape)
         beta = beta.reshape(new_param_shape)
         mean = mean.reshape(new_moments_shape)
         std = std.reshape(new_moments_shape)
-        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        x_hat_grad = ograd * gamma
+        beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
         ograd_mult = x_hat_grad / std
         red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
         data_grad = ograd_mult - red_out
@@ -1996,7 +1996,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
     height = random.randint(1, 5)
     width = random.randint(1, 5)
     dshape = (batch_size, num_channels, height, width)
-    param_shape = (num_groups,)
+    param_shape = (num_channels,)
     temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
     np_data = np.random.uniform(0.2, 1.0, dshape)
     np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
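
The changed reduction axes in the test reference follow from the grouped layout: x_hat and ograd have shape (N, G, C/G, H, W), so summing over (0, 3, 4) leaves a (G, C/G) array, and flatten() turns it into the per-channel gradient of shape (num_channels,), matching the new param_shape. A small shape check, mine rather than part of the test (names and sizes are arbitrary):

```python
# Hedged shape check for the updated gradient reference (illustration only).
import numpy as np

n, g, cpg, h, w = 2, 4, 8, 5, 5            # num_channels = g * cpg = 32
ograd = np.random.rand(n, g, cpg, h, w)    # upstream gradient in grouped layout
x_hat = np.random.rand(n, g, cpg, h, w)    # normalized data in grouped layout

beta_grad = np.sum(ograd, axis=(0, 3, 4)).flatten()
gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4)).flatten()

# Reducing over (0, 3, 4) leaves (G, C/G); flatten() gives (num_channels,) gradients.
assert beta_grad.shape == gamma_grad.shape == (g * cpg,)
```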