Minimal PyTorch Implementation of SimSiam from "Exploring Simple Siamese Representation Learning" by Chen & He (CVPR 2021).
import torch

from simsiam.models import SimSiam
from simsiam.losses import negative_cosine_similarity
model = SimSiam(
    backbone="resnet50",      # encoder network
    latent_dim=2048,          # projector/predictor output size
    proj_hidden_dim=2048,     # projection MLP hidden layer size
    pred_hidden_dim=512,      # predictor MLP hidden layer size
)
model = model.to("cuda")      # move the model to the GPU
model.train()
transforms = ...
dataset = ...
dataloader = ...
opt = ...
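# As an illustration (an assumption, not necessarily this repo's recipe),
# the augmentation pipeline reported in the SimSiam paper looks like:
import torchvision.transforms as T
transforms = T.Compose([
    T.RandomResizedCrop(224, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
])
# Note: in the loop below transforms(x) runs on a whole batch, so each
# batch shares a single random draw of the augmentation parameters.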
for epoch in range(epochs):
    for batch, (x, y) in enumerate(dataloader):
        x = x.to("cuda")
        opt.zero_grad()
        x1, x2 = transforms(x), transforms(x)          # Augment
        e1, e2 = model.encode(x1), model.encode(x2)    # Encode
        z1, z2 = model.project(e1), model.project(e2)  # Project
        p1, p2 = model.predict(z1), model.predict(z2)  # Predict
        # Symmetrized loss: each prediction is compared against the
        # *other* view's projection (Eq. 1 in the paper)
        loss1 = negative_cosine_similarity(p1, z2)
        loss2 = negative_cosine_similarity(p2, z1)
        loss = loss1 / 2 + loss2 / 2
        loss.backward()
        opt.step()
# Save encoder weights for later
torch.save(model.encoder.state_dict(), "pretrained.pt")
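The stop-gradient is what keeps SimSiam from collapsing; judging by the call sites above, it lives inside negative_cosine_similarity. Below is a minimal sketch of such a loss following Eq. 1 of the paper (an assumption about this repo's implementation, not a copy of it):

import torch.nn.functional as F

def negative_cosine_similarity(p, z):
    # Treat the projection z as a fixed target: the stop-gradient
    # (z.detach()) is what prevents representational collapse.
    z = z.detach()
    return -F.cosine_similarity(p, z, dim=-1).mean()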
import torch
from torch import nn, optim

from simsiam.models import ResNet

# ResNet here is just a wrapper around the encoder + a linear classifier
model = ResNet(
    backbone="resnet50",   # same backbone as during pretraining
    num_classes=10,        # number of output classes
    pretrained=False,      # whether to load pretrained ImageNet weights
    freeze=True,           # freeze the encoder weights (or not)
)
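# If you want to toggle freezing after construction, this is the standard
# PyTorch idiom (and presumably what freeze=True does internally):
for param in model.encoder.parameters():
    param.requires_grad = False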
# Load the pretrained weights from SimSiam
model.encoder.load_state_dict(torch.load("pretrained.pt"))
model = model.to("cuda")
model.train()
transforms = ...
dataset = ...
dataloader = ...
opt = optim.SGD(model.parameters(), lr=1e-2)  # lr is required; tune for your dataset
loss_func = nn.CrossEntropyLoss()
# Train on your small labeled train set
for epoch in range(epochs):
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to("cuda"), y.to("cuda")
        opt.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        loss.backward()
        opt.step()
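Afterwards you can sanity-check the learned representations with plain accuracy on a held-out split. A minimal sketch, assuming a test_dataloader (hypothetical name) built like the train one:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in test_dataloader:  # hypothetical held-out loader
        x, y = x.to("cuda"), y.to("cuda")
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
print(f"test accuracy: {correct / total:.4f}")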
Install the requirements:

pip install -r requirements.txt

Modify configs/pretrain.yaml to your liking and run

python pretrain.py --cfg configs/pretrain.yaml

Monitor training with TensorBoard:

tensorboard --logdir=logs