Group Norm #11283
Comments
Thanks for submitting this issue @smorrel1
So has Group Norm been added into mxnet?
Hi @kalyc @sandeep-krishnamurthy, if I want to contribute a Group Norm implementation in Gluon, what is the process?
@srikar2097 you can add a contrib block in gluon. If you want feedback on your plan to implement GroupNorm, you can subscribe to the dev list (https://mxnet.incubator.apache.org/community/) and send out an email for an RFC. If you already know how to implement it, feel free to submit a PR directly.
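For context, contributed Gluon layers generally live under mxnet.gluon.contrib.nn, and a new block is just a HybridBlock subclass. A minimal, illustrative skeleton (the class name MyNormBlock is hypothetical, not an existing MXNet layer):

from mxnet.gluon import nn

class MyNormBlock(nn.HybridBlock):
    """Illustrative skeleton for a contributed normalization-style block."""
    def __init__(self, num_channels, **kwargs):
        super(MyNormBlock, self).__init__(**kwargs)
        with self.name_scope():
            # learnable per-channel scale, registered as a block parameter
            self.weight = self.params.get('weight', shape=(1, num_channels, 1, 1))

    def hybrid_forward(self, F, x, weight):
        # F is mx.nd (imperative) or mx.sym (after hybridization)
        return F.broadcast_mul(x, weight)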
Well, I have implemented GroupNorm. It's slower than ...

from mxnet.gluon import nn

class GroupNorm(nn.HybridBlock):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.

    Reference:
        https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, num_channels, num_groups=32, eps=1e-5,
                 multi_precision=False, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        with self.name_scope():
            self.weight = self.params.get('weight', grad_req='write',
                                          shape=(1, num_channels, 1, 1))
            self.bias = self.params.get('bias', grad_req='write',
                                        shape=(1, num_channels, 1, 1))
        self.C = num_channels
        self.G = num_groups
        self.eps = eps
        self.multi_precision = multi_precision
        assert self.C % self.G == 0

    def hybrid_forward(self, F, x, weight, bias):
        x_new = F.reshape(x, (0, self.G, -1))            # (N,C,H,W) -> (N,G,H*W*C//G)
        if self.multi_precision:
            # compute the mean in float32 for numerical stability, then cast back
            mean = F.mean(F.cast(x_new, "float32"),
                          axis=-1, keepdims=True)        # (N,G,H*W*C//G) -> (N,G,1)
            mean = F.cast(mean, "float16")
        else:
            mean = F.mean(x_new, axis=-1, keepdims=True)
        centered_x_new = F.broadcast_minus(x_new, mean)  # (N,G,H*W*C//G)
        if self.multi_precision:
            var = F.mean(F.cast(F.square(centered_x_new), "float32"),
                         axis=-1, keepdims=True)         # (N,G,H*W*C//G) -> (N,G,1)
            var = F.cast(var, "float16")
        else:
            var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)
        # normalize, then reshape back to (N,C,H,W)
        x_new = F.broadcast_div(centered_x_new,
                                F.sqrt(var + self.eps)).reshape_like(x)
        # per-channel affine transform
        x_new = F.broadcast_add(F.broadcast_mul(x_new, weight), bias)
        return x_new

Clearly there are several issues, for example: ...
I think this layer is quite important, as not everyone has plenty of GPUs (if you have plenty, then ...). P.s. a question to the MXNet authors: ...
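In case it helps anyone trying the block above, a quick usage sketch (the shapes here are just an example I made up):

from mxnet import nd

net = GroupNorm(num_channels=64, num_groups=32)
net.initialize()
net.hybridize()

x = nd.random.normal(shape=(2, 64, 32, 32))  # small batch of 2 images
y = net(x)
print(y.shape)  # (2, 64, 32, 32)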
@eric-haibin-lin Is the implementation above OK?
@chi-hung I think the sumSquare op resides in @zhanghang1989's fork. @zhanghang1989 could you confirm?
Attempting an implementation in the backend...
Yes, it lives in a branch of my fork. We originally implemented it for SyncBN.
Implementation is in #14959; it's almost done, only the gradient with respect to the data still needs some work.
Group Norm is more accurate than Batch Norm for small batches and is useful for many vision tasks. See Group Normalization.
PyTorch and TensorFlow (code below) have implementations. Could anyone port this, please? It would be really helpful for many of us, but I'm not sure how to implement it myself. Thanks!
def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N, C, H, W]
    # gamma, beta: scale and offset, with shape [1, C, 1, 1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    x = tf.reshape(x, [N, G, C // G, H, W])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta
Figure 3. Python code of Group Norm based on TensorFlow.
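For reference, here is a rough NDArray translation of the figure above to MXNet (a sketch only, not an official port; the function name group_norm_nd is my own):

from mxnet import nd

def group_norm_nd(x, gamma, beta, G, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (1, C, 1, 1); G: number of groups
    N, C, H, W = x.shape
    x = x.reshape((N, G, C // G, H, W))
    # per-group mean and variance over the channel-subset and spatial axes
    mean = x.mean(axis=(2, 3, 4), keepdims=True)
    var = ((x - mean) ** 2).mean(axis=(2, 3, 4), keepdims=True)
    x = (x - mean) / nd.sqrt(var + eps)
    x = x.reshape((N, C, H, W))
    return x * gamma + beta

A hybridizable Gluon version needs the explicit broadcast_* operators, as in the implementation posted earlier in this thread.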