Skip to content

Commit

Permalink
add SimSiam (#407)
Browse files Browse the repository at this point in the history
* simsiam init imp

* add doc

* fix indent

* black reformatted

* No grad fixes, detach in sim calc

* adjusted loss factor -2

* init dm similar to simclr implementation, revert loss to paper imp

* further simclr adjustment

* support resnet18 backbone

* knn online callback

* fit and eval knn on val epoch end

* simsiam tests

* added import

* gpus 0

* scikit-learn req

* flake8

* scikit-learn bump version 0.23

* isort

* isort

* fix detach

* rm deep copy

* Add types to knn_online.py

* Add types to models.py

* Fix types in models.py

* Fix tests

* Fix types in knn_online.py

* Add SimSiam

* Apply isort

* Import sklearn as optional package

* Fix flake8

* Add args via Trainer and make the tests work on cpu

* Fix flake8

* chlog

* yapf

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored Jan 17, 2021
1 parent 413b9df commit 043b557
Show file tree
Hide file tree
Showing 10 changed files with 652 additions and 7 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/code-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
pip --version
shell: bash
- name: PEP8
run: |
flake8 .
run: flake8 .

format-check-yapf:
runs-on: ubuntu-20.04
Expand All @@ -38,8 +37,7 @@ jobs:
pip --version
shell: bash
- name: yapf
run: |
yapf --diff --parallel --recursive .
run: yapf --diff --parallel --recursive .

imports-check-isort:
runs-on: ubuntu-20.04
Expand Down Expand Up @@ -67,5 +65,4 @@ jobs:
pip install mypy
pip list
- name: mypy
run: |
mypy
run: mypy
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added flags to datamodules ([#388](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/388))
- Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
- Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))

### Changed

Expand Down
121 changes: 121 additions & 0 deletions pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Optional, Tuple, Union

import numpy as np
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch.utils.data import DataLoader

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
from sklearn.neighbors import KNeighborsClassifier
else: # pragma: no cover
warn_missing_pkg("sklearn", pypi_name="scikit-learn")


class KNNOnlineEvaluator(Callback):  # pragma: no-cover
    """
    Evaluates self-supervised K nearest neighbors.

    At the end of every validation epoch a fresh ``KNeighborsClassifier`` is fitted on the
    representations of the module's training dataloader and scored on both the training and
    the validation dataloader. The two accuracies are logged on the module as
    ``online_knn_train_acc`` and ``online_knn_val_acc``.

    Example::

        # your model must have 1 attribute
        model = Model()
        model.num_classes = ... # the num of classes in the model

        online_eval = KNNOnlineEvaluator(
            num_classes=model.num_classes,
            dataset='imagenet'
        )

    """

    def __init__(
        self,
        dataset: str,
        num_classes: Optional[int] = None,
    ) -> None:
        """
        Args:
            dataset: if stl10, need to get the labeled batch
            num_classes: Number of classes; it is also passed to ``KNeighborsClassifier``
                as ``n_neighbors``

        Raises:
            ModuleNotFoundError: if ``scikit-learn`` is not installed
        """
        if not _SKLEARN_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "You want to use `KNeighborsClassifier` function from `scikit-learn` which is not installed yet."
            )

        super().__init__()

        self.num_classes = num_classes
        self.dataset = dataset

    def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> torch.Tensor:
        """Run ``pl_module`` on ``x`` without gradients and flatten each sample to one row.

        Args:
            pl_module: module whose forward pass produces the representations
            x: input batch

        Returns:
            Tensor of shape ``(batch_size, -1)``
        """
        with torch.no_grad():
            representations = pl_module(x)
        representations = representations.reshape(representations.size(0), -1)
        return representations

    def get_all_representations(
        self,
        pl_module: LightningModule,
        dataloader: DataLoader,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the representations and collect the labels of every batch in ``dataloader``.

        Args:
            pl_module: module used to compute the representations
            dataloader: source of batches in the layout expected by :meth:`to_device`

        Returns:
            ``(representations, labels)`` as numpy arrays, concatenated along the first axis

        Raises:
            ValueError: if ``dataloader`` yields no batches
        """
        representation_chunks = []
        label_chunks = []

        with torch.no_grad():
            for batch in dataloader:
                x, y = self.to_device(batch, pl_module.device)
                representation_chunks.append(self.get_representations(pl_module, x))
                label_chunks.append(y)

            if not representation_chunks:
                raise ValueError("Cannot compute representations: the dataloader yielded no batches.")

            # Concatenate once at the end; re-concatenating the running result on every
            # batch would copy all previously seen data again each iteration.
            all_representations = torch.cat(representation_chunks)
            ys = torch.cat(label_chunks)

        return all_representations.cpu().numpy(), ys.cpu().numpy()

    def to_device(self, batch: torch.Tensor, device: Union[str, torch.device]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the online-evaluation input and the labels out of ``batch`` and move them to ``device``.

        Despite the annotation, ``batch`` is unpacked as a pair ``(inputs, y)`` where ``inputs``
        is indexable and its last element is the input used for online evaluation. When
        ``self.dataset == 'stl10'`` the pair is first taken from ``batch[1]`` (the labeled batch).

        Args:
            batch: batch as yielded by the dataloader
            device: target device

        Returns:
            ``(x, y)`` on ``device``
        """
        # get the labeled batch
        if self.dataset == 'stl10':
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Fit a kNN classifier on the train representations and log train / val accuracy.

        The fitted classifier is also stored on the module as ``pl_module.knn_evaluator``.

        Args:
            trainer: the running trainer (unused)
            pl_module: module providing the dataloaders and the representations
        """
        knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes)
        pl_module.knn_evaluator = knn_evaluator

        train_dataloader = pl_module.train_dataloader()
        representations, y = self.get_all_representations(pl_module, train_dataloader)

        # knn fit
        knn_evaluator.fit(representations, y)
        train_acc = knn_evaluator.score(representations, y)

        val_dataloader = pl_module.val_dataloader()
        representations, y = self.get_all_representations(pl_module, val_dataloader)  # type: ignore[arg-type]

        # knn val acc
        val_acc = knn_evaluator.score(representations, y)

        # log metrics
        pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
2 changes: 2 additions & 0 deletions pl_bolts/models/self_supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator # noqa: F401
from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2 # noqa: F401
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401
from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401
from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401

Expand All @@ -34,6 +35,7 @@
"SSLEvaluator",
"MocoV2",
"SimCLR",
"SimSiam",
"SSLFineTuner",
"SwAV",
]
Empty file.
51 changes: 51 additions & 0 deletions pl_bolts/models/self_supervised/simsiam/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional, Tuple

import torch
from torch import nn

from pl_bolts.utils.self_supervised import torchvision_ssl_encoder


class MLP(nn.Module):
    """Small projection head: ``Linear -> BatchNorm1d -> ReLU -> Linear``.

    The first linear layer has no bias; the last one does.

    Args:
        input_dim: number of input features
        hidden_size: width of the hidden layer
        output_dim: number of output features
    """

    def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Layers are created in the order they are applied.
        hidden_block = [
            nn.Linear(input_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
        ]
        output_layer = nn.Linear(hidden_size, output_dim, bias=True)
        self.model = nn.Sequential(*hidden_block, output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and return the result."""
        return self.model(x)


class SiameseArm(nn.Module):
    """One arm of a siamese network: an encoder followed by a projector and a predictor MLP.

    Args:
        encoder: backbone module; when ``None`` a ``torchvision_ssl_encoder('resnet50')`` is created.
            Its output is indexed with ``[0]`` in :meth:`forward`, so it is expected to return
            an indexable collection whose first element is the feature tensor.
        input_dim: input width of the projector, i.e. the feature size produced by the encoder
        hidden_size: hidden width of both the projector and the predictor
        output_dim: output width of the projector; also the input and output width of the predictor
    """

    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        input_dim: int = 2048,
        hidden_size: int = 4096,
        output_dim: int = 256,
    ) -> None:
        super().__init__()

        # Encoder (falls back to the default backbone when none is supplied)
        self.encoder = torchvision_ssl_encoder('resnet50') if encoder is None else encoder
        # Projector
        self.projector = MLP(input_dim, hidden_size, output_dim)
        # Predictor
        self.predictor = MLP(output_dim, hidden_size, output_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(features, projection, prediction)`` for the input ``x``.

        ``features`` is the first element of the encoder output, ``projection`` is the projector
        applied to the features, and ``prediction`` is the predictor applied to the projection.
        """
        features = self.encoder(x)[0]
        projection = self.projector(features)
        prediction = self.predictor(projection)
        return features, projection, prediction
Loading

0 comments on commit 043b557

Please sign in to comment.