From 5c4ee50e5a851019b7ef77d991adb2041cc14f0f Mon Sep 17 00:00:00 2001
From: "Huang, Guangtai"
Date: Thu, 30 Apr 2020 03:41:37 +0800
Subject: [PATCH 1/2] init

---
 python/mxnet/gluon/nn/basic_layers.py  |  4 ++--
 src/operator/nn/group_norm-inl.h       | 25 +++++++++++++------------
 src/operator/nn/group_norm.cc          |  4 ++--
 tests/python/unittest/test_operator.py | 16 ++++++++--------
 4 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 70b0a71841f1..0054c5289ae1 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -828,10 +828,10 @@ def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
         self._center = center
         self._scale = scale
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(num_groups,), init=gamma_initializer,
+                                     shape=(0,), init=gamma_initializer,
                                      allow_deferred_init=True)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(num_groups,), init=beta_initializer,
+                                    shape=(0,), init=beta_initializer,
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, data, gamma, beta):
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
index 69d5a304dc2c..143e2168d113 100644
--- a/src/operator/nn/group_norm-inl.h
+++ b/src/operator/nn/group_norm-inl.h
@@ -136,16 +136,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
   TBlob data_grp = data.reshape(temp_data_shape);
   const TBlob& mean_grp = mean.reshape(moments_shape);
   const TBlob& std_grp = std.reshape(moments_shape);
-  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+  const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);
 
   // Calculate data = data - mean
   BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                      {data_grp, mean_grp},
-                                                     {kWriteTo}, {output});
+                                                     {kWriteTo}, {output_grp});
 
   // Calculate std
   const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
-  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::square>(
         s, std_, req[0], workspace, centered_out);
@@ -157,11 +157,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
 
   // Calculate data = data / std
   BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
-                                                   {output, std_grp},
-                                                   {kWriteTo}, {output});
+                                                   {output_grp, std_grp},
+                                                   {kWriteTo}, {output_grp});
 
-  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  const TBlob& output = outputs[groupnorm::kOut];
+  mxnet::TShape new_param_shape(data_shape.ndim(), 1);
+  new_param_shape[1] = data_shape[1];
 
   const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
   const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
@@ -215,8 +216,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
 
   // Reshape gamma to be broadcastable
-  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  mxnet::TShape new_param_shape(dshape.ndim(), 1);
+  new_param_shape[1] = dshape[1];
 
   const TBlob& gamma = inputs[2].reshape(new_param_shape);
 
@@ -233,7 +234,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   // Prepare the necessary shapes for reduction
   mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
   BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
-  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+  BroadcastReduceShapeCompact(dshape, gamma.shape_,
                               &red_exclude_src_shape, &red_exclude_dst_shape);
 
   int N = red_src_shape.Size() / red_dst_shape.Size();
@@ -308,8 +309,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[0] != kNullOp) {
     const TBlob output_ = outputs[0].reshape(data_.shape_);
     BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                     {ograd, gamma},
-                                                     {kWriteTo}, {ograd_mult});
+                                                     {inputs[0], gamma},
+                                                     {kWriteTo}, {ograd_mult.reshape(data.shape_)});
     BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                      {ograd_mult, std_},
                                                      {kWriteTo}, {ograd_mult});
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index 6b8fe9bbd4c9..c939b4499c94 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
     return false;
   }
 
-  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
-  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));
 
   out_shape->clear();
   out_shape->push_back(dshape);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 32812b12eca8..0baa941d142d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1960,10 +1960,10 @@ def x_hat_helper(x, num_groups, eps):
         return x_hat, mean, std
 
     def np_groupnorm(data, gamma, beta, num_groups, eps):
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
-        out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
-        return out.reshape(dshape), mean, std
+        out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+        return out, mean, std
 
     def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
@@ -1971,7 +1971,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         dshape = data.shape
         dtype = data.dtype
         new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         acc_type = acc_types[str(dtype)]
         ograd = ograd.reshape(new_shape)
         data = data.reshape(new_shape)
@@ -1979,9 +1979,9 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         beta = beta.reshape(new_param_shape)
         mean = mean.reshape(new_moments_shape)
         std = std.reshape(new_moments_shape)
-        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        x_hat_grad = ograd * gamma
+        beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
         ograd_mult = x_hat_grad / std
         red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
         data_grad = ograd_mult - red_out
@@ -1996,7 +1996,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
     height = random.randint(1, 5)
     width = random.randint(1, 5)
     dshape = (batch_size, num_channels, height, width)
-    param_shape = (num_groups,)
+    param_shape = (num_channels,)
     temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
     np_data = np.random.uniform(0.2, 1.0, dshape)
     np_gamma = np.random.uniform(-1.0, 1.0, param_shape)

From 32703ae8ed0ccc9ec751978017cd49cf4f812aa7 Mon Sep 17 00:00:00 2001
From: "Huang, Guangtai"
Date: Thu, 30 Apr 2020 16:04:50 +0800
Subject: [PATCH 2/2] add in_channels

---
 python/mxnet/gluon/nn/basic_layers.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 0054c5289ae1..797392a6a36a 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -820,7 +820,7 @@ class GroupNorm(HybridBlock):
     """
     def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                  beta_initializer='zeros', gamma_initializer='ones',
-                 prefix=None, params=None):
+                 in_channels=0, prefix=None, params=None):
         super(GroupNorm, self).__init__(prefix=prefix, params=params)
         self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
         self._num_groups = num_groups
@@ -828,10 +828,10 @@ def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
         self._center = center
         self._scale = scale
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(0,), init=gamma_initializer,
+                                     shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(0,), init=beta_initializer,
+                                    shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, data, gamma, beta):
@@ -839,7 +839,10 @@ def hybrid_forward(self, F, data, gamma, beta):
         return norm_data
 
     def __repr__(self):
-        s = '{name}({content})'
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))
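
For reviewers, a minimal usage sketch of what these two patches change, assuming an MXNet build that includes both; the layer configuration and shapes below are illustrative only, not part of the patch:

    import mxnet as mx
    from mxnet.gluon import nn

    # gamma/beta are now per-channel -- shape (num_channels,) rather than
    # (num_groups,) -- and GroupNorm accepts an optional in_channels argument.
    net = nn.GroupNorm(num_groups=4, in_channels=32)
    net.initialize()

    x = mx.nd.random.uniform(shape=(2, 32, 8, 8))  # (N, C, H, W); C divisible by num_groups
    y = net(x)

    print(net)              # __repr__ now reports in_channels
    print(net.gamma.shape)  # (32,): one scale per channel

Leaving in_channels at its default of 0 keeps the deferred-initialization path: the parameter shapes are inferred from the channel dimension on the first forward pass.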