Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimSiam #407

Merged
merged 34 commits into from
Jan 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
45a708b
simsiam init imp
zlapp Nov 25, 2020
42b322f
add doc
zlapp Nov 25, 2020
790fe60
fix indent
zlapp Nov 26, 2020
58ffec2
black reformatted
zlapp Nov 28, 2020
6abbfed
No grad fixes, detach in sim calc
zlapp Nov 29, 2020
7d6737b
adjusted loss factor -2
zlapp Nov 30, 2020
0e18985
init dm similar to simclr implementation, revert loss to paper imp
zlapp Dec 7, 2020
467ac4d
further simclr adjustment
zlapp Dec 7, 2020
3fb3c48
support resnet18 backbone
zlapp Dec 7, 2020
28cfd69
knn online callback
zlapp Dec 7, 2020
c0e1f35
fit and eval knn on val epoch end
zlapp Dec 7, 2020
ade7ad5
simsiam tests
zlapp Dec 13, 2020
6cd4985
added import
zlapp Dec 13, 2020
9a091a8
gpus 0
zlapp Dec 14, 2020
3b8dada
scikit-learn req
zlapp Dec 14, 2020
5ab5f4b
flake8
zlapp Dec 14, 2020
81fe48e
scikit-learn bump version 0.23
zlapp Dec 14, 2020
2ceb2eb
isort
zlapp Dec 14, 2020
4b56fcd
isort
zlapp Dec 14, 2020
6c63f0b
fix detatch
zlapp Dec 21, 2020
696f271
rm deep copy
zlapp Dec 27, 2020
bc9c240
Add types to knn_online.py
akihironitta Jan 7, 2021
217b36f
Add types to models.py
akihironitta Jan 7, 2021
3764220
Fix types in models.py
akihironitta Jan 7, 2021
3214c85
Fix tests
akihironitta Jan 7, 2021
13f5af3
Fix types in knn_online.py
akihironitta Jan 7, 2021
0083276
Add SimSiam
akihironitta Jan 7, 2021
829274e
Apply isort
akihironitta Jan 7, 2021
7819820
Import sklearn as optional package
akihironitta Jan 7, 2021
75df5b9
Fix flake8
akihironitta Jan 9, 2021
adf184d
Add args via Trainer and make the tests work on cpu
akihironitta Jan 9, 2021
715d934
Fix flake8
akihironitta Jan 9, 2021
e648eec
chlog
Borda Jan 17, 2021
1717c0a
yapf
Borda Jan 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/code-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
pip --version
shell: bash
- name: PEP8
run: |
flake8 .
run: flake8 .

format-check-yapf:
runs-on: ubuntu-20.04
Expand All @@ -38,8 +37,7 @@ jobs:
pip --version
shell: bash
- name: yapf
run: |
yapf --diff --parallel --recursive .
run: yapf --diff --parallel --recursive .

imports-check-isort:
runs-on: ubuntu-20.04
Expand Down Expand Up @@ -67,5 +65,4 @@ jobs:
pip install mypy
pip list
- name: mypy
run: |
mypy
run: mypy
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added flags to datamodules ([#388](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/388))
- Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
- Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))

### Changed

Expand Down
121 changes: 121 additions & 0 deletions pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Optional, Tuple, Union

import numpy as np
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch.utils.data import DataLoader

from pl_bolts.utils import _SKLEARN_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _SKLEARN_AVAILABLE:
from sklearn.neighbors import KNeighborsClassifier
else: # pragma: no cover
warn_missing_pkg("sklearn", pypi_name="scikit-learn")


class KNNOnlineEvaluator(Callback): # pragma: no-cover
"""
Evaluates self-supervised K nearest neighbors.

Example::

# your model must have 1 attribute
model = Model()
model.num_classes = ... # the num of classes in the model

online_eval = KNNOnlineEvaluator(
num_classes=model.num_classes,
dataset='imagenet'
)

"""

def __init__(
self,
dataset: str,
num_classes: Optional[int] = None,
) -> None:
"""
Args:
dataset: if stl10, need to get the labeled batch
num_classes: Number of classes
"""
if not _SKLEARN_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `KNeighborsClassifier` function from `scikit-learn` which is not installed yet."
)

super().__init__()

self.num_classes = num_classes
self.dataset = dataset

def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
representations = pl_module(x)
representations = representations.reshape(representations.size(0), -1)
return representations

def get_all_representations(
self,
pl_module: LightningModule,
dataloader: DataLoader,
) -> Tuple[np.ndarray, np.ndarray]:
all_representations = None
ys = None

for batch in dataloader:
x, y = self.to_device(batch, pl_module.device)

with torch.no_grad():
representations = self.get_representations(pl_module, x)

if all_representations is None:
all_representations = representations.detach()
else:
all_representations = torch.cat([all_representations, representations.detach()])

if ys is None:
ys = y
else:
ys = torch.cat([ys, y])

return all_representations.cpu().numpy(), ys.cpu().numpy() # type: ignore[union-attr]

def to_device(self, batch: torch.Tensor, device: Union[str, torch.device]) -> Tuple[torch.Tensor, torch.Tensor]:
# get the labeled batch
if self.dataset == 'stl10':
labeled_batch = batch[1]
batch = labeled_batch

inputs, y = batch

# last input is for online eval
x = inputs[-1]
x = x.to(device)
y = y.to(device)

return x, y

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes)

train_dataloader = pl_module.train_dataloader()
representations, y = self.get_all_representations(pl_module, train_dataloader)

# knn fit
pl_module.knn_evaluator.fit(representations, y) # type: ignore[union-attr,operator]
train_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator]

# log metrics

val_dataloader = pl_module.val_dataloader()
representations, y = self.get_all_representations(pl_module, val_dataloader) # type: ignore[arg-type]

# knn val acc
val_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator]

# log metrics
pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True)
pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
2 changes: 2 additions & 0 deletions pl_bolts/models/self_supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator # noqa: F401
from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2 # noqa: F401
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401
from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401
from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401

Expand All @@ -34,6 +35,7 @@
"SSLEvaluator",
"MocoV2",
"SimCLR",
"SimSiam",
"SSLFineTuner",
"SwAV",
]
Empty file.
51 changes: 51 additions & 0 deletions pl_bolts/models/self_supervised/simsiam/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional, Tuple

import torch
from torch import nn

from pl_bolts.utils.self_supervised import torchvision_ssl_encoder


class MLP(nn.Module):

def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None:
super().__init__()
self.output_dim = output_dim
self.input_dim = input_dim
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_size, bias=False),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, output_dim, bias=True),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.model(x)
return x


class SiameseArm(nn.Module):

def __init__(
self,
encoder: Optional[nn.Module] = None,
input_dim: int = 2048,
hidden_size: int = 4096,
output_dim: int = 256,
) -> None:
super().__init__()

if encoder is None:
encoder = torchvision_ssl_encoder('resnet50')
# Encoder
self.encoder = encoder
# Projector
self.projector = MLP(input_dim, hidden_size, output_dim)
# Predictor
self.predictor = MLP(output_dim, hidden_size, output_dim)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
y = self.encoder(x)[0]
z = self.projector(y)
h = self.predictor(z)
return y, z, h
Loading