Implement the model that won the classification task of ImageNet 2013 #33

Closed · wants to merge 1 commit

Conversation

@kloudkl (Contributor) commented Jan 15, 2014

Changes relative to imagenet(_val/_deploy).prototxt: data cropsize 225; conv1 kernelsize7, stride 2; conv2 group 1, stride 2.

This fixes #32.

@shelhamer (Member)

Thanks for the model definition! Have you trained and evaluated this?

@kloudkl (Contributor, Author) commented Jan 19, 2014

It took Caffe more than 9 days to train on the ImageNet dataset using an NVIDIA Tesla K20 GPU, which is a quite high-end device. Currently I only have access to a GTX 560 Ti, whose memory is not enough to hold the model parameters. Hopefully someone who is interested in reproducing the winning results and has the hardware resources will solve this problem. It would be even better to share the trained model on the Caffe website, as the Caffe authors have done.

@sguada (Contributor) commented Jan 19, 2014

I have done some initial tests, but the new net is 40% slower to train, so it will take about 2 weeks. Now I'm trying other options to speed up the training, like changing the learning rates and weight decay. I will share the network if it gets good results.

Sergio

@kloudkl (Contributor, Author) commented Jan 23, 2014

Thank you for your effort, @sguada! What are the results of your initial tests? If it takes too much time, you might first train and test on a smaller dataset, such as a portion of the ImageNet classification dataset, to verify that the model works as Zeiler described. If its performance is worse than expected, some debugging will be necessary, such as checking whether or not conv2 performs grouped convolution. Maybe not all of the exact implementation details can be derived from the paper.

@SWu commented Jan 23, 2014

A couple of considerations (basing off http://www.matthewzeiler.com/pubs/arxive2013/arxive2013.pdf):

The ImageNet preprocessing proposed by the Caffe tutorial resizes to 256x256 without preserving the aspect ratio; Zeiler rescales so the shorter dimension is 256 and then takes a 256x256 center crop.
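
A numeric sketch of the difference (the helper below is my own illustration; it only computes the geometry, not the actual image operation):

```python
def zeiler_crop_geometry(h, w, size=256):
    """Rescale so the shorter side equals `size`, then center-crop size x size.
    Returns the rescaled dimensions and the crop's top-left corner."""
    scale = size / min(h, w)
    nh, nw = round(h * scale), round(w * scale)
    return (nh, nw), ((nh - size) // 2, (nw - size) // 2)

# A 480x640 photo keeps its aspect ratio: rescaled to 256x341,
# then cropped at offset (0, 42). The Caffe tutorial's preprocessing
# would instead squash it straight to 256x256.
dims, corner = zeiler_crop_geometry(480, 640)
```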

The conv2, conv4, and conv5 layers in the original model definition have 'group: 2' turned on. If I'm not mistaken, this is the sparse architecture Krizhevsky used because training was split over 2 GPUs. Zeiler mentions that he uses dense connections instead.
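
For intuition, the effect of 'group: 2' on the weight count can be sketched (a hedged illustration, assuming AlexNet's conv2 shapes of 256 filters, 5x5 kernel, 96 input channels):

```python
def conv_params(out_ch, in_ch, k, group=1):
    """Weight count of a k x k convolution; with groups, each filter
    only sees in_ch / group of the input channels."""
    return out_ch * (in_ch // group) * k * k

dense = conv_params(256, 96, 5)            # 614,400 weights (Zeiler, dense)
sparse = conv_params(256, 96, 5, group=2)  # 307,200 weights (Krizhevsky, group: 2)
```

So switching conv2 to dense connections doubles its parameters, which is part of why the new net trains slower.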

Zeiler mentions the use of 224x224 crops, but that yields one fewer first-layer filter output per dimension than reported in his paper. Zeiler also initializes all biases to 0, instead of the alternating 0s and 1s that Krizhevsky used in different layers. I haven't checked, but the padding of some layers may also need to be changed to match the layer dimensions reported in the paper.

Most importantly, Zeiler mentions that they 'renormalize each filter in the convolutional layers whose RMS value exceeds a fixed radius of 10^-1 to this fixed radius' as key to preventing individual filters from dominating the first layer. I don't think this is implemented currently (in my understanding this is different from local contrast normalization layers, and instead normalizes the convolutional filter weights to not exceed some variance).

Finally, the numbers reported in that paper are still significantly behind the actual winning system (and he mentions that the performance in the paper was surpassed in the ILSVRC 2013 competition), so there are probably more tweaks that he made which are unpublished.

I have been training a network based on my interpretation of the above (without the convolutional filter RMS normalization) but it is very slow on my Tesla M2090. After about 3 weeks and 60 epochs, top 1 error for validation is about 41.5%, which is still higher than Krizhevsky's result; we will see if that improves much further.

Does anyone else have insight into the details of the conv RMS normalization?

@sguada (Contributor) commented Jan 23, 2014

@kloudkl I have been doing some small-scale training, I mean training for a few epochs and comparing the validation performance to the log we have from our training of the Krizhevsky network.

@SWu To fix the misalignment between the 224x224 crop and the first-layer filters reported by Zeiler in his paper, one could either pad 1 pixel (which is done automatically in cuda-convnet but not in Caffe) or just use a 225x225 crop, as @kloudkl suggested.
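
The alignment can be checked with the standard convolution output-size formula (a sketch; `conv_out` is my own helper, not code from either framework):

```python
def conv_out(w, k, s, p=0):
    """Spatial output size of a convolution: floor((w + 2p - k) / s) + 1."""
    return (w + 2 * p - k) // s + 1

# Zeiler reports 110x110 first-layer maps for a 7x7 kernel with stride 2.
# With a 224 crop, (224 - 7) is odd, so stride 2 cannot reach the last column:
print(conv_out(224, 7, 2))        # 109 -- one filter position short
# Either padding by 1 pixel or cropping at 225x225 restores the alignment:
print(conv_out(224, 7, 2, p=1))   # 110
print(conv_out(225, 7, 2))        # 110
```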

@SWu it seems to me that the top-1 validation you are getting after 60 epochs is very low; as you can see in the figure below, the Krizhevsky network achieves top-1 validation of 0.4058 at 20 epochs, 0.5529 at 40 epochs, and 0.574 at 60 epochs. So @SWu I don't think your network is going to improve much further after 60 epochs.

[figure: alexnet_log — validation curve of the Krizhevsky (AlexNet) training run]

It is true that at this point Caffe doesn't have a layer to renormalize the filters as Zeiler described in his paper, so that could be the reason why the performance is worse. We could try to add it, so if you want to work on this let me know.

What I have mostly been doing is adjusting the base_lr, gamma, weight_decay, and stepsize of the solver to account for the change in batch size from 256 to 128 and to improve the speed of training. So far I'm not able to match the training speed, but my second attempt is getting closer. See the figure below.

[figure: zeiler_log — training curves for the Zeiler-style network]

@SWu commented Jan 23, 2014

@sguada When I say top1 validation error, I mean (1 - accuracy) reported by the accuracy layer. So it is actually fluctuating at around 58.5% accuracy, which is ~1% higher than your log of Krizhevsky network at that point.

@sguada (Contributor) commented Jan 23, 2014

@SWu sorry, I misread your previous post. That is not bad then; according to Zeiler's paper his network got 38.4% top-1 validation error after training for 70 epochs, so it seems you are getting close. But if you haven't changed the learning rates, stepsize, and gamma that we have for the alexnet in Caffe, I would not expect it to improve much after 60 epochs.

Do you have a log file of your training? If you don't mind could you share it? Or your prototxt files?

@SWu commented Jan 23, 2014

The diff of the prototxt: http://pastebin.com/M49MTupT

Changes to the solver prototxt:
stepsize: 200,000 instead of 100,000 since the batch size was halved
test_iter/test_interval: 2000 instead of 1000
max_iter: 900,000 instead of 4,500,000

By the way, am I correct in thinking that the 4,500,000 max_iter in the imagenet_solver.prototxt is a typo and it should actually be 450,000? ImageNet has ~1,280,000 images, so with a batch size of 256, every 5000 iterations is an epoch, and 90 epochs should be 450,000.
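
The arithmetic behind that check (the ~1,280,000 figure is the approximate ILSVRC-2012 training set size, not an exact count):

```python
train_images = 1280000  # approximate ImageNet (ILSVRC-2012) training set size
batch_size = 256
iters_per_epoch = train_images // batch_size   # 5000 iterations per epoch
max_iter = 90 * iters_per_epoch                # 450,000 for 90 epochs

# Halving the batch to 128 doubles the iterations per epoch, which is
# why the stepsize above was doubled from 100,000 to 200,000.
```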

I don't have a log file since I'm printing directly to stderr, but I am very close to your numbers for 20 and 40 epochs, and ~1% higher at 60 epochs (actually, I haven't quite reached 60 yet; it's 75% through the 59th epoch).

@sguada (Contributor) commented Jan 24, 2014

@SWu yeah, there was a typo in the imagenet_solver.prototxt; it should be 450,000, which represents 90 epochs.

Before, I was also printing to stderr, but now I always redirect it to a file by adding " 2> log.txt" so I can look at it and analyze it later. If you have been saving snapshots during training, you could test them later and see how the performance changed.

It is interesting that at 20 and 40 epochs the performance was close, but you needed to wait until 60 epochs to see an improvement. I would have thought the improvement would be there from the beginning.

The other thing I did to adjust for the halved batch size was to reduce the weight_decay by half, but I'm not sure how much that will affect the final performance.

shelhamer added a commit that referenced this pull request Jan 26, 2014
max iteration no. is 450,000 (= 90 epochs)

caught by @SWu
#33 (comment)
@sguada (Contributor) commented Feb 28, 2014

@SWu did you finish training the network? Could you share your results?

@kloudkl (Contributor, Author) commented Feb 28, 2014

@SWu, @sguada, whoever first shares the new record-breaking model please open a new pull request and I will close this one right after that.

@SWu commented Apr 22, 2014

Finally got my hands on a better GPU (a K40 :) ) and was able to retry some things, including fixing the order of the LRN and MaxPool layers and tweaking the padding. This gave ~59.95% validation accuracy after about 2 weeks. One observation with these new changes is that I start seeing better validation accuracy immediately, even in the first few epochs, compared to your logs. See the prototxt and validation log here: http://pastebin.com/hb2Tp3rd

This still does not have the convolution re-normalization described in Zeiler's paper. Is someone working on that currently?

@by-liu commented May 20, 2014

@SWu I cannot see anything in your link. Could you share your prototxt file again?

@apark263

For filter renormalization, is this just a matter of dividing the coefficients by a term such that the l2 norm of each filter is constant across all filters at each point during training?

@anguyen8

Could anyone update the status of this PR implementing the 2013 winning model?
I am wondering whether, with the layers currently available in Caffe, it is possible to replicate Zeiler's net.
I greatly appreciate any info.

@shelhamer (Member)

Caffe is missing a single operation in the ZF net [1] for filter regularization:

[...] we renormalize each filter in the convolutional layers whose RMS value exceeds a fixed radius of 10−1 to this fixed radius

If you add this filter RMS cap layer to Caffe the model can be trained as described in the paper for public reference.
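
A minimal sketch of that operation (my own illustration, not an existing Caffe layer; it would run over the convolutional weights after each parameter update):

```python
import math

def rms_renorm(filters, radius=0.1):
    """Rescale any filter whose RMS weight value exceeds `radius`
    back to exactly that radius; filters already inside are untouched."""
    out = []
    for w in filters:  # w: flat list of one filter's weights
        rms = math.sqrt(sum(x * x for x in w) / len(w))
        if rms > radius:
            w = [x * radius / rms for x in w]
        out.append(w)
    return out
```

Note this constrains the RMS (the L2 norm divided by the square root of the weight count), so it is a max-norm projection rather than the plain L2 normalization @apark263 asked about.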

[1] M. Zeiler and R. Fergus. Visualizing and Understanding Convolutional Networks. arXiv:1311.2901v3, 2013.

@anguyen8

@shelhamer Thanks much Evan! I really appreciate the pointer.
Also, according to Matt Zeiler, the actual 2013 winning model (Clarifai) will not be published, hence the model in this paper is the closest we can get.

@bhack (Contributor) commented Aug 25, 2014

Probably @Yangqing has some news from ImageNet 2014.

@shelhamer (Member)

This is superseded by VGG's devil models from BMVC14 now in the model zoo and readied for use by #1138. Thanks @ksimonyan and VGG for sharing the models!

@shelhamer shelhamer closed this Sep 25, 2014
mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
max iteration no. is 450,000 (= 90 epochs)

caught by @SWu
BVLC#33 (comment)
8 participants