simsiam-pytorch

Minimal PyTorch Implementation of SimSiam from "Exploring Simple Siamese Representation Learning" by Chen et al.

Load and train on a custom dataset

import torch

from simsiam.models import SimSiam
from simsiam.losses import negative_cosine_similarity

model = SimSiam(
    backbone="resnet50",    # encoder network
    latent_dim=2048,        # predictor network output size
    proj_hidden_dim=2048,   # projection mlp hidden layer size
    pred_hidden_dim=512     # predictor mlp hidden layer size
)
model = model.to("cuda")    # use all the parallels
model.train()

transforms = ...
dataset = ...
dataloader = ...
opt = ...

for epoch in range(epochs):
    for batch, (x, y) in enumerate(dataloader):
        opt.zero_grad()
        x = x.to("cuda")                                # move batch to the model's device

        x1, x2 = transforms(x), transforms(x)           # Augment
        e1, e2 = model.encode(x1), model.encode(x2)     # Encode
        z1, z2 = model.project(e1), model.project(e2)   # Project
        p1, p2 = model.predict(z1), model.predict(z2)   # Predict

        # Compute the symmetrized loss from the paper: each prediction is
        # matched against the projection of the *other* view
        loss1 = negative_cosine_similarity(p1, z2)
        loss2 = negative_cosine_similarity(p2, z1)
        loss = loss1 / 2 + loss2 / 2
        loss.backward()
        opt.step()

# Save encoder weights for later
torch.save(model.encoder.state_dict(), "pretrained.pt")
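
For reference, the loss is the negative cosine similarity with a stop-gradient on the projection, as in the paper. A minimal sketch of what negative_cosine_similarity computes (the repo's actual implementation may differ in details):

import torch
import torch.nn.functional as F

def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # D(p, z) = -cos(p, stopgrad(z)), averaged over the batch.
    # detach() is the stop-gradient: gradients never flow through z.
    z = z.detach()
    return -F.cosine_similarity(p, z, dim=-1).mean()

The stop-gradient is what keeps the siamese network from collapsing to a constant representation, which is the central finding of the paper.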

Use pretrained weights in a classifier

import torch
from torch import nn, optim

from simsiam.models import ResNet

# just a wrapper around encoder + linear classifier networks
model = ResNet(
    backbone="resnet50",    # Same as during pretraining
    num_classes=10,         # number of output neurons 
    pretrained=False,       # Whether to load pretrained imagenet weights
    freeze=True             # Freeze the encoder weights (or not)
)

# Load the pretrained weights from SimSiam
model.encoder.load_state_dict(torch.load("pretrained.pt"))

model = model.to("cuda")
model.train()

transforms = ...
dataset = ...
dataloader = ...
opt = optim.SGD(model.parameters(), lr=1e-3)    # choose a learning rate for your dataset
loss_func = nn.CrossEntropyLoss()

# Train on your small labeled train set
for epoch in range(epochs):
    for batch, (x, y) in enumerate(dataloader):
        opt.zero_grad()
        x, y = x.to("cuda"), y.to("cuda")   # move batch to the model's device
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        loss.backward()
        opt.step()
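
Once trained, the classifier can be evaluated as usual. A minimal sketch, assuming you define a val_dataloader over your labeled validation split:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in val_dataloader:
        x, y = x.to("cuda"), y.to("cuda")
        preds = model(x).argmax(dim=1)      # predicted class per sample
        correct += (preds == y).sum().item()
        total += y.numel()
print(f"val accuracy: {correct / total:.4f}")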

Install dependencies

pip install -r requirements.txt

Train on STL-10 dataset

Modify configs/pretrain.yaml to your liking and run

python pretrain.py --cfg configs/pretrain.yaml
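
If you prefer to drive the custom loop above yourself, the unlabeled STL-10 split is available through torchvision. A sketch of a dataset and batch-level augmentation pipeline in the spirit of the paper (pretrain.py may use different augmentations and hyperparameters):

import torchvision.transforms as T
from torchvision.datasets import STL10
from torch.utils.data import DataLoader

# Tensor-level augmentations, applied twice per batch in the pretraining loop.
# Note: applied to a whole batch, each call samples one set of random parameters
# shared across the batch; per-sample augmentation in the Dataset is the usual alternative.
transforms = T.Compose([
    T.RandomResizedCrop(96, scale=(0.2, 1.0)),  # STL-10 images are 96x96
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=9)], p=0.5),
])

dataset = STL10(root="data", split="unlabeled", download=True, transform=T.ToTensor())
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)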

View logs in tensorboard

tensorboard --logdir=logs
