Minimal implementation of SimSiam (Exploring Simple Siamese Representation Learning by Xinlei Chen & Kaiming He) in TensorFlow 2. For an introduction, please see this blog post: Self-supervised contrastive learning with SimSiam.
The purpose of this repository is to demonstrate the workflow of SimSiam, NOT to implement it note for note. At the same time, I will try not to miss the major bits discussed in the paper. For this purpose, I'll be using the Flowers dataset.
The following figure depicts the workflow of SimSiam (taken from the paper) -
The authors have also provided PyTorch-like pseudocode in the paper (how cool!) -

    # f: backbone + projection mlp
    # h: prediction mlp
    for x in loader: # load a minibatch x with n samples
        x1, x2 = aug(x), aug(x) # random augmentation
        z1, z2 = f(x1), f(x2) # projections, n-by-d
        p1, p2 = h(z1), h(z2) # predictions, n-by-d
        L = D(p1, z2)/2 + D(p2, z1)/2 # loss
        L.backward() # back-propagate
        update(f, h) # SGD update

    def D(p, z): # negative cosine similarity
        z = z.detach() # stop gradient
        p = normalize(p, dim=1) # l2-normalize
        z = normalize(z, dim=1) # l2-normalize
        return -(p*z).sum(dim=1).mean()
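
To make the loss in the pseudocode concrete, here is a minimal NumPy sketch of its forward computation. It is not the code from the notebooks: the helper names (`l2_normalize`, `negative_cosine_similarity`, `simsiam_loss`) and the toy inputs are assumptions made for illustration. NumPy has no automatic differentiation, so the stop-gradient step is only noted in a comment. In TensorFlow 2 the target branch would be wrapped in `tf.stop_gradient`.

```python
import numpy as np


def l2_normalize(x, axis=1, eps=1e-12):
    """Scale each row of x to unit L2 norm (eps guards against division by zero)."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def negative_cosine_similarity(p, z):
    """D(p, z) from the pseudocode: batch mean of -cos(p_i, z_i).

    With autograd, z is treated as a constant at this point
    (z.detach() in PyTorch, tf.stop_gradient(z) in TensorFlow 2).
    Only the forward value is reproduced here.
    """
    p = l2_normalize(p)
    z = l2_normalize(z)
    return -np.mean(np.sum(p * z, axis=1))


def simsiam_loss(p1, p2, z1, z2):
    """Symmetrized loss: L = D(p1, z2) / 2 + D(p2, z1) / 2."""
    return 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)


# Toy inputs (hypothetical values): n = 2 samples, d = 2 dimensions.
z1 = np.array([[1.0, 0.0], [0.0, 2.0]])
z2 = np.array([[3.0, 0.0], [0.0, 1.0]])    # same directions as z1, different lengths
orth = np.array([[0.0, 1.0], [1.0, 0.0]])  # row-wise orthogonal to z1

# Predictions that point in the same direction as the targets.
loss_aligned = simsiam_loss(z1, z2, z1, z2)
# Predictions that are orthogonal to the targets.
loss_orthogonal = simsiam_loss(orth, orth, z1, z1)
```

Because both inputs are L2-normalized, the loss depends only on direction and is bounded between -1 (perfectly aligned) and 1 (opposite).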
The authors emphasize the `stop_gradient` operation, which helps the network avoid collapsing solutions. Further details are available in the paper. SimSiam eliminates the need for large batch sizes, momentum encoders, memory banks, negative samples, etc., which are important components of modern self-supervised learning frameworks for visual recognition. This makes SimSiam an easily approachable framework for practical problems.
- `SimSiam_Pre_training.ipynb`: Pre-trains a ResNet50 using SimSiam.
- `SimSiam_Evaluation.ipynb`: Evaluates (linear evaluation) ResNet50 as pre-trained in `SimSiam_Pre_training.ipynb`.
Pre-training Schedule | Validation Accuracy (Linear Evaluation) |
---|---|
50 epochs | 45.64% |
75 epochs | 44.91% |
I think these scores can be improved with further hyperparameter tuning and regularization.
Supervised training (results are taken from here and here):
Training Type | Validation Accuracy (Linear Evaluation) |
---|---|
Supervised ImageNet-trained ResNet50 Features | 48.36% |
From Scratch Training with ResNet50 | 63.64% |
The figure below shows the training loss plots from two different pre-training schedules (50 epochs and 75 epochs) -
We see that the loss plateaus after 35 epochs. We can experiment with the following components to improve this further -
- data augmentation pipeline
- architectures of the two MLP heads
- learning rate schedule used during pre-training
and so on.
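
As an example of the last item, one schedule worth trying is cosine decay. The sketch below is a plain-Python illustration, not the schedule used in the notebooks: the function name `cosine_decay_lr` and the values of `base_lr` and `total_steps` are assumptions chosen for the example. In TensorFlow 2, `tf.keras.optimizers.schedules.CosineDecay` provides a ready-made schedule of this kind.

```python
import math


def cosine_decay_lr(step, total_steps, base_lr):
    """Learning rate at `step`, decayed from base_lr to 0 along a half cosine."""
    # Clamp so that steps past the end of the schedule stay at the final value.
    progress = min(step, total_steps) / total_steps
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# Hypothetical settings, for illustration only.
base_lr = 0.05
total_steps = 1000

lr_start = cosine_decay_lr(0, total_steps, base_lr)                # equals base_lr
lr_mid = cosine_decay_lr(total_steps // 2, total_steps, base_lr)   # half of base_lr
lr_end = cosine_decay_lr(total_steps, total_steps, base_lr)        # decays to zero
```

The rate falls slowly at first, fastest around the midpoint, and slowly again near the end, which may or may not help with the plateau seen above.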
- 50 epochs
- 75 epochs
Thanks to Connor Shorten for his video on the paper, which helped me understand it quickly. Thanks to the ML-GDE program for providing GCP credits that helped in preparing the experiments.