Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch normalization layer with test and examples #1965

Closed
wants to merge 3 commits into from

Conversation

ducha-aiki
Copy link
Contributor

Implemented batch normalization layer (see http://arxiv.org/abs/1502.03167) based on @ChenglongChen and @Russell91 code with fixes and improvements.
Also added shuffling pool by @jjkjkj of the input data to the data_layer to not to have same files together in same batch. Tests passes and rebased on master.
For illustration of the effectiveness two examples of CIFAR-10 classifier with sigmoid non-linearity with and without batch normalization.
bn_cifar

@weiliu89
Copy link

Nice! I guess cifar_baseline is using ReLU instead of Sigmoid? Do you have any training examples of using ReLU + BN? Does it converge much faster than not using BN as stated in the paper?

@ChenglongChen
Copy link

@ducha-aiki,
Nice work!
I myself have the problem to figure out how to implement per-epoch (or mini-batch) shuffling.

As for the fixed mean and variance for inference, I think we can hack that by using two extra vars to keep track of the (exponential) moving averaging mean or variance, and use those instead of the current batch_mean and batch_variance for normalization in TEST phase.

Besides, the current implementation keeps two "copies" of the blob (buffer_blob and x_norm) which might be a bit memory consuming when using big deep net. It might worth considering switching to for-loop rather than the current BLAS vectorization as @Russell91 did in his init commit.

@weiliu89
Copy link

@ChenglongChen

I used different "stage" at the beginning of TEST "phase" to compute mean and variance from a few training mini-batches. Moving average might work during training as mentioned in Section 3.1 in the paper.

@ChenglongChen
Copy link

@weiliu89,
I haven't totally understood what Caffe's params refer to. I have to check what's the difference between stage and phase. But, it sounds like you are doing what Algorithm 2 in the paper describes (the inference part), right?

@weiliu89
Copy link

@ChenglongChen

Yes. I refer to how Caffe handle phase, and add several stages that bn_layer does different thing in different stage. This is the easiest way I can think of to implement Algorithm 2. It seems working well, but I haven't debug it though

@sunbaigui
Copy link

@ducha-aiki does master branch supports googlenet network now?
why not merge this into dev branch, and experiment on
image
Figure 4 in the paper.
Has anyone tried these models yet?

@ducha-aiki
Copy link
Contributor Author

@weiliu89, we will add some graphs with ReLU-CIFAR later. For now, BN-model converges very fast, but a bit less accurate than non-BN. If we make a bit deeper model, than cifar-baseline, than BN converges faster and to more accurate one. And thank you for stage-phase suggestion

@sunbaigui, sure, if you sponsor us with some GPU. When I have trained my modification of the GoogLeNet, it took 3 weeks. Even if it would be 7 times faster, it is too much GPU time to spent for us :)
About master and dev - there is no dev anymore (see #1943 ) and master supports GoogLeNet.

@ChenglongChen we will think about loop-based implementation. However, you are welcome to make PR into this branch :)

@jjkjkj
Copy link
Contributor

jjkjkj commented Feb 26, 2015

cifar_baseline is example/cifar/train_full.sh
@weiliu89 I could not archive cifar_baseline accuracy with batch normalization, without net modification( simply removing lrn, adding bn before or after every relu in examples/cifar10/cifar10_full_train_test.prototxt and only playing with learning rate)
I think this due to over-simplicity of this net and the fact than lrn for this net is enough normalization already.

I also trained variation of vgg16 on cifar with and without batch normalization
The only difference in architecture, compared to original vgg16, is shrieked down ip layers and adding dropout(not sure this is necessary for bn network)
Also I forgot to turn off PReLU and MSR weight filler in this networks, but I think that this makes little to no difference.
1
https://gist.github.com/jjkjkj/39e87099e9381e6886a5

For no_bn net base_lr=0.001 causes net to diverge. For bn net lr is first guess, so maybe with bigger lr it will converge faster and better.

@ChenglongChen @weiliu89 @ducha-aiki About test phase. I tested cifar_vgg16 whith range of batch sizes (2-250) in test phase and found very small changes in accuracy (with batch size 2 accuracy is only 1% less than with 250)

@ducha-aiki
Copy link
Contributor Author

@shelhamer @longjon @jeffdonahue Could you please review this PR?

@weiliu89
Copy link

I think that current PR doesn't compute mean and variance from training images (or moving mean and variance) during testing phase, but it compute mean and variance from test mini-batch, which I think is not exactly the same as described in the paper. I am not sure how much it affects the test accuracy.

@yangyi02
Copy link

@jjkjkj For the cifar experiment, do you try the comparison adding bn before or after every relu? Does that matter?

@jjkjkj
Copy link
Contributor

jjkjkj commented Feb 26, 2015

@yangyi02 Yes i tried and found no difference(with examples/cifar10/cifar10_full_train_test.prototxt). As i said i think that this net is bad example for batch normalization.
As for cifar_vgg16, now i'm training net without dropout.

@nakosung
Copy link

nakosung commented Mar 2, 2015

To feed evaluation network mean/var, setting mean = beta, var = (1/gamma)^2 will be OK? (Learned beta, gamma is similar to true mean/var?)

@justfortest1
Copy link

@ducha-aiki when batchsize=1 in test phase, can it also work?

@ducha-aiki
Copy link
Contributor Author

@justfortest1 no.

@lsy1993311
Copy link

@weiliu89 Could you please share you implementation? I think different stages should be considered.

@weiliu89
Copy link

@lsy1993311 My Caffe version is old and I am not familiar with how to upload the code and haven't tested it. The high level idea is to include set_stage(string) and stage() in include/caffe/common.hpp (Refer to set_phase() and phase() in the same file). Then in src/caffe/solver.cpp, I add a function at the beginning of TestAll() which tries to compute mean & std. In the function, I set phase to TRAIN, and include two stages by using set_stage() as described before. The first stage is called "aggregation" which does several iteration of Forward pass to aggregate mean & std from a few mini-batches; the second stage is called "finalize" which compute the final mean & std by dividing the number of mini-batches you have passed. Finally, in batch_norm_layer, I can call Caffe::stage() and implement some additional thing in order to handle different stage (e.g. "aggregate" and "finalize"). I won't go in details into how to do it as it should be trivial.

However I don't have time to really debug this thoroughly. One thing to notice is that what I described above needs to compute mean & std every time I call TestAll() which might not necessary because it costs extra computation during training. On the other hand, you can only call the function in Snapshot() and use moving mean & std during training as described in the paper (can set a different stage for doing this during training).

@melgor
Copy link

melgor commented Mar 17, 2015

I have been testing that version with my data set using VGG16 model. And it works, speed up converge.
But it does not produce better result than normal version. Even more, batch size on test influence on accuracy.
@weiliu89 propose resolving that problem. I will try to implement it. But this is only mismatch between Caffe code and Paper?

@ducha-aiki
Copy link
Contributor Author

@melgor, as far as I see - yes. It would be great, if you help to implement @weiliu89 solution.

@melgor
Copy link

melgor commented Mar 18, 2015

I have just found implementation of @ChenglongChen which implement BN with right code in Test Phase. It save the mean and variance in BN layer. It looks like better implementation because it does not need to change Solver code. But it does not calculate mean and variance within all Train data, only update the value of statistic using ex: S_{t+1} = decay * Y_{t+1} + (1 - decay) * S_{t}, where decay is parameter.

What do you think about such implementation? @ChenglongChen, does it work better than mini-batch statistic?

More information here:
https://github.com/ChenglongChen/caffe-windows/blob/master/src/caffe/layers/bn_layer.cpp

@ChenglongChen
Copy link

Sorry guys. I have been caught up with work this moment, so I don't have time to test it out thoroughly. The use of exponentially weighted moving average (EWMA) is simply due to the fact that BN tends to keep the distribution of activation stable (?).

The algo2 in the paper is a bit complicate:

  1. before the TEST phase, we forward a few mini-batch to compute the mean & var for the 1st BN layer, then we save this mean & var for other round inference (& forward)
  2. we then forward those mini-batch to compute the mean & var for the 2nd BN layer, notice that the normalization part of the 1st BN layer is carried out using mean & var computed in step1 not the mini-batch statistics.
  3. similarly, we perform the above for the rest BN layers.
  4. after computing all the mean & var, we then have the inference BN network.

@jjkjkj
Copy link
Contributor

jjkjkj commented Mar 25, 2015

@lsy1993311 As exepcted: slightly faster initial training but strong overfitting(it's natural whet parameters >> dataset). So, BN does not always remove need of dropout.

@nakosung
Copy link

What if removing BN layers at testing phase? I mean that no normalization/reconstruction will be used during testing.

@borisgin
Copy link

Hi Dmitro, nice work!
When I tryied to train cifar_bn, I got endless warnings "force_color: ...". Looks like a minor bug, So I just commented these lines in data_layer. cpp.
Another observation: when I put bn layer just before convolutional & Ip layer, I obtained faster convergence then when bn_layer is located after convolutional layer:
cifar_fast_bn

@ducha-aiki
Copy link
Contributor Author

@weiliu89, thanks for catch!

@borisgin Hi Boris, thanks for for observation. It is interesting, that is it very architecture dependent: we have tried on other architectures and there was no difference, as stated in original paper. Still lot of place for exploring :)
The force_color bug is fixed now, thanks.

@andrei-pokrovsky
Copy link

weiliu89>>I think that current PR doesn't compute mean and variance from training images (or moving mean and variance) during testing phase, but it compute mean and variance from test mini-batch, which I think is not exactly the same as described in the paper. I am not sure how much it affects the test accuracy.

FWIW I agree, that's different from the paper description. Also what if the test data contains only one sample (single image inference)?

@xuzhm
Copy link

xuzhm commented Aug 5, 2015

bug?? occurs "nan" .........
20150805193632

after debug found that batch_variance_ have negative number .....
20150805202636

@AIROBOTAI
Copy link

Hi @ducha-aiki, and others, thanks for your excellent work! Here I have a question about your code. Could you please explain to me what is the purpose for the codes starting from the line No. 153 to line NO. 185 in the function "void DataLayer::InternalThreadEntry()" in data_layer.cpp? I just cannot figure out why these codes should be there when the "datum.encoded()" is true. Thx a lot in advance!

@ducha-aiki
Copy link
Contributor Author

@AIROBOTAI it needs for data shuffling. With batch normalization, it is important that network don`t see same images together in batch, so this lines implement shuffling.
As for the encoded datum - it is only mode we use, so we implemented shuffling only for it.

@AIROBOTAI
Copy link

@ducha-aiki Thanks for your prompt reply! But what should I do if my datum is NOT encoded? I have checked the return value of datum.encoded() to find it to be false. So in this case, those lines of code for shuffling will be jumped over.

@ducha-aiki
Copy link
Contributor Author

@AIROBOTAI then you can regenerate LMDB with encoded key, or add same lines to unencoded branch of if :)

@ctensmeyer
Copy link

One alternative to true shuffling is to do random skips. The DataLayer has
a parameter called rand_skip that causes the DB cursor to start in a random
position in the DB. It is easy to extend that concept so that the DB skips
a random number of instances every time it advances its cursor. You can do
so by modifying the last part of the InternalThreadEntry() to run the
cursor->next() in a loop for a random number of iterations. This way the
mini-batch is never composed of the same instances and the probability of
two instances being in the same batch is inversely proportional to the size
of the skip. I find in practice that this doesn't slow down DB reads very
much because it is still kind of sequential.

Hope that helps.

On Sat, Aug 22, 2015 at 11:02 AM, Dmytro Mishkin notifications@github.com
wrote:

@AIROBOTAI https://github.com/AIROBOTAI then you can regenerate LMDB
with encoded key, or add same lines to unencoded branch of if :)


Reply to this email directly or view it on GitHub
#1965 (comment).

@AIROBOTAI
Copy link

@ducha-aiki I have modified some seemingly confusing codes in your code to make it more straight to me. Now it works, thanks again!

@AIROBOTAI
Copy link

@waldol1 thanks for your suggestion! Your method seems more easy to use than the shuffling pool proposed in this pull request. I'd also like to know whether you have tested the test accuracy using your method and how is the performance. @ducha-aiki what's your comments on this new shuffling method? Thanks for your all!

@ctensmeyer
Copy link

I haven't tested this shuffling method with regards to BatchNorm, but it
seemed to help a little on an mnist autoencoder that I was testing when the
training set was small.

On Mon, Aug 24, 2015 at 9:06 AM, AIROBOTAI notifications@github.com wrote:

@waldol1 https://github.com/waldol1 thanks for your suggestion! Your
method seems more easy to use than the shuffling pool proposed in this pull
request. I'd also like to know whether you have tested the test accuracy
using your method and how is the performance. @ducha-aiki
https://github.com/ducha-aiki what's your comments on this new
shuffling method? Thanks for your all!


Reply to this email directly or view it on GitHub
#1965 (comment).

@talda
Copy link

talda commented Aug 30, 2015

@ducha-aiki @shelhamer - is there still plan to pull this in?
If so two suggestions:

  1. separate the data shuffling into another PR, also find a solution that works for all types of data layers.
  2. keep a running average of batch statistics so this PR could be used for classifying a single image.

@bhack
Copy link
Contributor

bhack commented Aug 30, 2015

@talda This is still to young to be reviewed. It is only six months old. :)

@ducha-aiki
Copy link
Contributor Author

@bhack actually, this PR is not needed at all, if believe to Google.
Here are slides http://lsun.cs.princeton.edu/slides/Christian.pdf
Pre-last one says:
"Releasing Pretrained Inception and MultiBox
Academic criticism: Results are hard to reproduce
We will be releasing pretrained Caffe models for:
● GoogLeNet (Inception 5)
● BN-Inception (Inception 6)
● MultiBox-Inception proposal generator (based on
Inception 6)
Contact @Yangqing "

@talda
Copy link

talda commented Aug 30, 2015

@bhack too bad Github does not have like/upvote button for comments. I would defiantly upvote your previous comment.

@erogol
Copy link
Contributor

erogol commented Aug 30, 2015

@ducha-aiki Can I classify a single image by using that PR's modifications and batch normalization ?

@ducha-aiki
Copy link
Contributor Author

@erogol No.

@erogol
Copy link
Contributor

erogol commented Sep 3, 2015

@ducha-aiki I applied moving average and it now works.

@ducha-aiki
Copy link
Contributor Author

Closed, because of #3229 merge.The scale + bias could be taken from #2996

@ducha-aiki ducha-aiki closed this Oct 23, 2015
venkai added a commit to venkai/caffe that referenced this pull request May 2, 2017
Added bn_layer.[cpp/cu] with corresponding hpp file.
Performs batch-normalization with in-place scale/shift.

Originally created by
ducha-aiki: https://github.com/ducha-aiki
ChenglongChen: https://github.com/ChenglongChen
Russell91: https://github.com/Russell91
jjkjkj: https://github.com/jjkjkj

detailed discussion of this implementation can be found at:
BVLC#1965
venkai added a commit to venkai/caffe that referenced this pull request May 5, 2017
Added bn_layer.[cpp/cu] with corresponding hpp file.
Performs batch-normalization with in-place scale/shift.

Originally created by
ducha-aiki: https://github.com/ducha-aiki
ChenglongChen: https://github.com/ChenglongChen
Russell91: https://github.com/Russell91
jjkjkj: https://github.com/jjkjkj

detailed discussion of this implementation can be found at:
BVLC#1965
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.