
Group Norm #11283

Closed
smorrel1 opened this issue Jun 14, 2018 · 12 comments

Comments

@smorrel1

Group Norm is more accurate than Batch Norm for small batches, which is useful for many vision tasks. See the Group Normalization paper.
PyTorch and TensorFlow (code below) have implementations. Could anyone port this please? It would be really helpful for many of us, but I'm not sure how to implement it myself. Thanks!

def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N, C, H, W]
    # gamma, beta: scale and offset, with shape [1, C, 1, 1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    x = tf.reshape(x, [N, G, C // G, H, W])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta

Figure 3. Python code of Group Norm based on TensorFlow.

@kalyc
Contributor

kalyc commented Jun 14, 2018

Thanks for submitting this issue, @smorrel1.
@sandeep-krishnamurthy could you add the label "Feature Request" to this issue?

@YueshangGu

So has Group Norm been added to MXNet?

@srikar2097

Hi @kalyc @sandeep-krishnamurthy, if I want to contribute a Group Norm implementation in Gluon, what is the process?

@eric-haibin-lin
Member

eric-haibin-lin commented Sep 25, 2018

@srikar2097 you can add a contrib block in Gluon. If you want some feedback on your plan to implement GroupNorm, you can subscribe to the dev list (https://mxnet.incubator.apache.org/community/) and send out an RFC email. If you already know how to implement it, feel free to submit a PR directly.
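
As a rough illustration of what such a contrib block could look like (the class name, the omission of the learnable scale/offset, and the op choices below are purely illustrative, not the project's actual API):

    # Illustrative sketch only: a minimal HybridBlock that group-normalizes its input.
    from mxnet.gluon import nn

    class GroupNormSketch(nn.HybridBlock):
        """Hypothetical skeleton of a Gluon contrib GroupNorm (no gamma/beta params)."""
        def __init__(self, num_groups=32, epsilon=1e-5, **kwargs):
            super(GroupNormSketch, self).__init__(**kwargs)
            self._num_groups = num_groups
            self._epsilon = epsilon

        def hybrid_forward(self, F, x):
            # Split channels into groups, normalize within each group, restore shape.
            y = F.reshape(x, (0, self._num_groups, -1))    # (N,C,H,W) -> (N,G,C//G*H*W)
            mean = F.mean(y, axis=-1, keepdims=True)
            var = F.mean(F.square(F.broadcast_minus(y, mean)), axis=-1, keepdims=True)
            y = F.broadcast_div(F.broadcast_minus(y, mean), F.sqrt(var + self._epsilon))
            return F.reshape_like(y, x)                    # back to (N,C,H,W)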

@chi-hung

chi-hung commented Nov 20, 2018

Well, I have implemented GroupNorm. It's slower than nn.BatchNorm, but it works (see the code below):

class GroupNorm(nn.HybridBlock):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
      https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, num_channels, num_groups=32, eps=1e-5,
                 multi_precision=False, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)

        with self.name_scope():
            self.weight = self.params.get('weight', grad_req='write',
                                          shape=(1, num_channels, 1, 1))
            self.bias = self.params.get('bias', grad_req='write',
                                        shape=(1, num_channels, 1, 1))
        self.C = num_channels
        self.G = num_groups
        self.eps = eps
        self.multi_precision = multi_precision

        assert self.C % self.G == 0

    def hybrid_forward(self, F, x, weight, bias):

        x_new = F.reshape(x, (0, self.G, -1))                                # (N,C,H,W) -> (N,G,H*W*C//G)

        if self.multi_precision:
            mean = F.mean(F.cast(x_new, "float32"),
                          axis=-1, keepdims=True)                            # (N,G,H*W*C//G) -> (N,G,1)
            mean = F.cast(mean, "float16")
        else:
            mean = F.mean(x_new, axis=-1, keepdims=True)

        centered_x_new = F.broadcast_minus(x_new, mean)                      # (N,G,H*W*C//G)

        if self.multi_precision:
            var = F.mean(F.cast(F.square(centered_x_new), "float32"),
                         axis=-1, keepdims=True)                             # (N,G,H*W*C//G) -> (N,G,1)
            var = F.cast(var, "float16")
        else:
            var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)

        x_new = F.broadcast_div(centered_x_new, F.sqrt(var + self.eps)       # (N,G,H*W*C//G) -> (N,C,H,W)
                                ).reshape_like(x)
        x_new = F.broadcast_add(F.broadcast_mul(x_new, weight), bias)
        return x_new
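
For what it's worth, here is a minimal usage sketch of the block above (the shapes and values are arbitrary, just to show that it runs):

    import mxnet as mx

    net = GroupNorm(num_channels=64, num_groups=32)
    net.initialize(ctx=mx.cpu())
    net.hybridize()

    x = mx.nd.random.normal(shape=(2, 64, 16, 16))  # small batch of 2
    y = net(x)
    print(y.shape)                                   # (2, 64, 16, 16)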

Clearly there are several issues, for example:

  • An operator such as F.moments() (which is quite common elsewhere) is not implemented in MXNet yet, so my implementation here might be slow.
  • The use of reshape_like() seems unavoidable, which means the input tensor has to be kept around and costs RAM.
  • When training with mixed precision, the implementation above casts the FP16 input to FP32 to avoid loss of precision while calculating the mean and variance. Casting an FP16 tensor to FP32 and then back to FP16 wastes time (this extra step could be eliminated if the layer were implemented at the CUDA level).

I think this layer is quite important, as not everyone has plenty of GPUs (if you do, then F.contrib.SyncBatchNorm() will work well).

P.S. a question for the MXNet authors:
There seems to be an op called F.SumSquare() (see: https://github.com/dmlc/gluon-cv/blob/0a699a5ccc21310c7ce41d4737f0de9f54fbf45a/gluoncv/model_zoo/syncbn.py#L206), which I guess is used to calculate the second-order moment. I didn't find it in MXNet's API... does this op really exist?
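
In case it helps, even without such a fused op the second-order moment can be recovered from a plain sum of squares via the identity var = E[x²] − (E[x])². A small illustrative sketch (ordinary NDArray ops, not MXNet's SumSquare API):

    import mxnet as mx

    x = mx.nd.random.normal(shape=(2, 4, 8, 8)).reshape(2, 2, -1)  # (N, G, C//G*H*W)
    n = x.shape[-1]
    s = x.sum(axis=2, keepdims=True)          # first-order: sum(x)
    ss = (x * x).sum(axis=2, keepdims=True)   # second-order: sum(x^2)
    mean = s / n
    var = ss / n - mean * mean                # E[x^2] - (E[x])^2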

@hustzeyu

@eric-haibin-lin Is the implementation above OK?

@eric-haibin-lin
Member

@chi-hung I think the sumSquare op resides in @zhanghang1989 's fork. @zhanghang1989 could you confirm?

@Jerryzcn
Contributor

@haojin2

@haojin2
Contributor

haojin2 commented Apr 18, 2019

Attempting implementation in backend...

@zhanghang1989
Contributor

@chi-hung I think the sumSquare op resides in @zhanghang1989 's fork. @zhanghang1989 could you confirm?

Yes, it lives in a branch of my fork. We implemented it for SyncBN.

@haojin2
Contributor

haojin2 commented May 15, 2019

Implementation is in #14959; it's almost done, just the gradient with respect to the data still needs some work.

@sxjscience
Member

sxjscience commented Jul 22, 2019

@smorrel1, @haojin2 has implemented GroupNorm and you may try it. Feel free to reopen the issue if you meet any problems.
