Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimSiam #407

Merged
merged 34 commits into from
Jan 17, 2021
Merged

SimSiam #407

merged 34 commits into from
Jan 17, 2021

Conversation

zlapp
Copy link
Contributor

@zlapp zlapp commented Nov 25, 2020

What does this PR do?

Implement https://arxiv.org/pdf/2011.10566v1.pdf
Largely based on https://github.com/lucidrains/byol-pytorch extension of BYOL to support SimSiam.
I used pl-bolts BYOL implementation as a reference.
Colab gist for testing on cifar-10 https://gist.github.com/zlapp/c35b8c97d4f6537f21aa07bbc37959c9
Discussed on slack channel https://pytorch-lightning.slack.com/archives/C010PRC9M2R/p1606329394008100

Also adds KNN online evaluation callback

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks
Copy link

pep8speaks commented Nov 25, 2020

Hello @zlapp! Thanks for updating this PR.

Line 144:72: W504 line break after binary operator

Comment last updated at 2021-01-17 20:56:04 UTC

@codecov
Copy link

codecov bot commented Nov 26, 2020

Codecov Report

Merging #407 (1717c0a) into master (413b9df) will decrease coverage by 0.54%.
The diff coverage is 76.55%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #407      +/-   ##
==========================================
- Coverage   79.49%   78.95%   -0.55%     
==========================================
  Files         102      105       +3     
  Lines        5912     6121     +209     
==========================================
+ Hits         4700     4833     +133     
- Misses       1212     1288      +76     
Flag Coverage Δ
cpu 25.66% <22.00%> (-0.12%) ⬇️
pytest 25.66% <22.00%> (-0.12%) ⬇️
unittests 78.48% <76.55%> (-0.53%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
...s/models/self_supervised/simsiam/simsiam_module.py 72.67% <72.67%> (ø)
pl_bolts/callbacks/knn_online.py 90.00% <90.00%> (ø)
pl_bolts/models/self_supervised/simsiam/models.py 96.15% <96.15%> (ø)
pl_bolts/models/self_supervised/__init__.py 100.00% <100.00%> (ø)
pl_bolts/datasets/cifar10_dataset.py 71.73% <0.00%> (-26.09%) ⬇️
pl_bolts/datasets/base_dataset.py 81.81% <0.00%> (-13.64%) ⬇️
...l_bolts/models/rl/vanilla_policy_gradient_model.py 96.36% <0.00%> (+2.72%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 413b9df...1717c0a. Read the comment docs.

@zlapp
Copy link
Contributor Author

zlapp commented Nov 28, 2020

Initial training results on cifar10 comparing SimSiam (orange) to BYOL (blue).

image
image

I am guessing there is a bug in the initial implementation.

@zlapp
Copy link
Contributor Author

zlapp commented Nov 30, 2020

Better online accuracies after fix of detach SimSiam (red) to BYOL (blue):
image
image

Noticed SimSiam uses a factor of -1 times cosine similarity vs BYOL which uses a factor of -2

SimSiam loss:
image

BYOL loss:
image

@zlapp
Copy link
Contributor Author

zlapp commented Nov 30, 2020

Changing the loss factor for SimSiam to -2 (and not dividing by 2) to be like BYOL improved results (nearly identical).
Seems like the loss factor needs further investigation.
SimSiam (light blue) to BYOL (blue):
image
image

@zlapp
Copy link
Contributor Author

zlapp commented Dec 9, 2020

Reached 88% accuracy on CIFAR10 after 800 epochs:

online_simsiam

Command:
python simsiam_module.py --dataset cifar10 --optimizer sgd --batch_size 512 --learning_rate 0.03 --max_epochs 800 --weight_decay 0.0005 --arch resnet18 --hidden_mlp 512 --online_ft

@ananyahjha93 ananyahjha93 self-requested a review December 14, 2020 17:21
Comment on lines 181 to 189
# Image 1 to image 2 loss
_, z1, h1 = self.online_network(img_1)
_, z2, h2 = self.target_network(img_2)
loss_a = -1.0 * self.cosine_similarity(h1, z2)

# Image 2 to image 1 loss
_, z1, h1 = self.online_network(img_2)
_, z2, h2 = self.target_network(img_1)
loss_b = -1.0 * self.cosine_similarity(h1, z2)
Copy link

@haideraltahan haideraltahan Dec 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on SimSiam's pseudocode and other implementations (1, 2), shouldn't there be just one network without a deep copy target network?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Believe it is equivalent. For example if you refer to BYOL implementation in pl bolts there is a deep copy here there isn't https://github.com/lucidrains/byol-pytorch

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nonetheless, great job on the implementation. It might be worth investigating in the future if there is any difference in performance between the two methods. As with this approach, maybe there is more memory usage?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a deep copy instead of using the same network twice you use ~3GB gpu extra memory with resnet18. The two versions are about equally fast on cifar10, but I think that's because the gpu is memorybound on the task.
Given that the final performane ends up the same, I guess removing the copied network would be a good idea as it's less wasteful and scales better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MikkelAntonsen.
Could you please share the version you ran with plot of results on CIFAR10?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing @MikkelAntonsen.
I just pushed a commit with the changes you suggested in the gist.
Great job making the improvements in efficiency while maintaining accuracy.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the SSLOnlineEvaluator does an additional forward pass per train_batch to get embeddings. It would be possible to return the encoder output from training_step() and which would be accessible in the outputs argument in SSLOnlineEvaluator, AFAIK. This is unfortunate because it couples the implementation of SSLOnlineEvaluator with the networks that uses it. But if we are to reproduce results on imagenet, is reducing the number of forward passes by 1/3 negligible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MikkelAntonsen I believe this is related more to SSL capabilities in general of pl_bolts so might be better to open a separate issue since this isn't only effecting SimSiam (tagging @ananyahjha93).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MikkelAntonsen Maybe the byol implementation can also benefit from using only one network (without deepcopy) and using detach() to control the gradient flow?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just read the BYOL paper abstract and it seems like the online network and target network use a different set of weights, so I'm not sure how you could share a network. If you look at figure 2 in the paper, it does seem like they use the stop gradient trick, but only for the target network. Do you see any ways to incorporate simsiam ideas into BYOL without the implemention just ending up as simsiam?

@Borda
Copy link
Member

Borda commented Jan 2, 2021

@zlapp how is it going here, is it still WIP?

@zlapp
Copy link
Contributor Author

zlapp commented Jan 3, 2021

@zlapp how is it going here, is it still WIP?

Hi @Borda, based on the results here #407 (comment) I believe the PR is ready to be merged.

@akihironitta akihironitta changed the title [wip] SimSiam SimSiam Jan 4, 2021
Copy link
Contributor

@akihironitta akihironitta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wjn0
Copy link
Contributor

wjn0 commented Jan 7, 2021

@zlapp as a potential user thanks so much for your (+ the PR reviewers!) work, this looks great. I do have one question: the paper notes in appendix B that one difference between SimSiam and BYOL is the bottleneck structure in the predictor. Namely, the hidden dimension of the predictor MLP should be 1/4 of the output dimension. So, for example, with the default prediction space dimension of 256, I think the hidden dim should be 64. This apparently helps with training stability. Do you agree with my reading of the paper? If so, I'm wondering if this should be the default in the bolt as well?

@akihironitta
Copy link
Contributor

@zlapp Would you mind having a look at zlapp#1?

Could you also have a look at https://github.com/zlapp/pytorch-lightning-bolts/pull/2?

@Borda Borda added the Priority High priority task label Jan 17, 2021
@Borda Borda requested a review from akihironitta January 17, 2021 20:41
Copy link
Contributor Author

@zlapp zlapp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@summelon
Copy link

The same question as @wjn0.
Any update?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
model Priority High priority task
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants