-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* simsiam init imp * add doc * fix indent * black reformatted * No grad fixes, detach in sim calc * adjusted loss factor -2 * init dm similar to simclr implementation, revert loss to paper imp * further simclr adjustment * support resnet18 backbone * knn online callback * fit and eval knn on val epoch end * simsiam tests * added import * gpus 0 * scikit-learn req * flake8 * scikit-learn bump version 0.23 * isort * isort * fix detatch * rm deep copy * Add types to knn_online.py * Add types to models.py * Fix types in models.py * Fix tests * Fix types in knn_online.py * Add SimSiam * Apply isort * Import sklearn as optional package * Fix flake8 * Add args via Trainer and make the tests work on cpu * Fix flake8 * chlog * yapf Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
413b9df
commit 043b557
Showing
10 changed files
with
652 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
from pytorch_lightning import Callback, LightningModule, Trainer | ||
from torch.utils.data import DataLoader | ||
|
||
from pl_bolts.utils import _SKLEARN_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _SKLEARN_AVAILABLE: | ||
from sklearn.neighbors import KNeighborsClassifier | ||
else: # pragma: no cover | ||
warn_missing_pkg("sklearn", pypi_name="scikit-learn") | ||
|
||
|
||
class KNNOnlineEvaluator(Callback):  # pragma: no-cover
    """
    Online evaluation of a self-supervised model with a K-nearest-neighbours classifier.

    At the end of every validation epoch the callback runs the module over its
    train dataloader, fits a ``KNeighborsClassifier`` on the flattened outputs,
    and logs the classifier's accuracy on the train and validation dataloaders
    as ``online_knn_train_acc`` and ``online_knn_val_acc``.

    Example::

        # your model must have 1 attribute
        model = Model()
        model.num_classes = ...  # the num of classes in the model

        online_eval = KNNOnlineEvaluator(
            num_classes=model.num_classes,
            dataset='imagenet'
        )
    """

    def __init__(
        self,
        dataset: str,
        num_classes: Optional[int] = None,
    ) -> None:
        """
        Args:
            dataset: if stl10, need to get the labeled batch
            num_classes: Number of classes. It is used as ``n_neighbors`` of the
                KNN classifier; when ``None``, ``pl_module.num_classes`` is used instead.

        Raises:
            ModuleNotFoundError: if ``scikit-learn`` is not installed.
        """
        if not _SKLEARN_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "You want to use `KNeighborsClassifier` function from `scikit-learn` which is not installed yet."
            )

        super().__init__()

        self.num_classes = num_classes
        self.dataset = dataset

    def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> torch.Tensor:
        """Run ``pl_module`` on ``x`` without gradients and flatten each sample's output to one row."""
        with torch.no_grad():
            representations = pl_module(x)
        # keep the leading (batch) dimension, collapse everything else
        return representations.reshape(representations.size(0), -1)

    def get_all_representations(
        self,
        pl_module: LightningModule,
        dataloader: DataLoader,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Collect the representations and labels of every batch in ``dataloader``.

        Returns:
            A ``(representations, labels)`` pair of numpy arrays, concatenated
            over all batches in iteration order.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        feature_chunks = []
        label_chunks = []

        with torch.no_grad():
            for batch in dataloader:
                x, y = self.to_device(batch, pl_module.device)
                representations = self.get_representations(pl_module, x)

                # Move each chunk to the CPU right away and concatenate once at
                # the end: re-concatenating the growing tensor on every batch
                # copies all previous data again and keeps the whole dataset on
                # the accelerator.
                feature_chunks.append(representations.detach().cpu())
                label_chunks.append(y.cpu())

        if not feature_chunks:
            raise ValueError("Cannot compute KNN representations: the dataloader yielded no batches.")

        return torch.cat(feature_chunks).numpy(), torch.cat(label_chunks).numpy()

    def to_device(
        self,
        batch: Union[list, tuple],
        device: Union[str, torch.device],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract the evaluation input and labels from ``batch`` and move both to ``device``.

        ``batch`` is unpacked as ``(inputs, y)``; for the ``'stl10'`` dataset the
        element at index 1 of the batch is unpacked that way instead.
        """
        # get the labeled batch
        if self.dataset == 'stl10':
            batch = batch[1]

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1].to(device)
        y = y.to(device)

        return x, y

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Fit a KNN classifier on the train representations and log train/val accuracy.

        Raises:
            ValueError: if the number of neighbours cannot be determined, i.e.
                ``num_classes`` was not given and ``pl_module`` has no ``num_classes``.
        """
        n_neighbors = self.num_classes
        if n_neighbors is None:
            # fall back to the attribute the class docstring asks the model to carry
            n_neighbors = getattr(pl_module, 'num_classes', None)
        if n_neighbors is None:
            raise ValueError(
                "`num_classes` must be passed to `KNNOnlineEvaluator` or be set on the module as `num_classes`."
            )

        pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=n_neighbors)

        train_dataloader = pl_module.train_dataloader()
        representations, y = self.get_all_representations(pl_module, train_dataloader)

        # knn fit
        pl_module.knn_evaluator.fit(representations, y)  # type: ignore[union-attr,operator]
        train_acc = pl_module.knn_evaluator.score(representations, y)  # type: ignore[union-attr,operator]

        val_dataloader = pl_module.val_dataloader()
        representations, y = self.get_all_representations(pl_module, val_dataloader)  # type: ignore[arg-type]

        # knn val acc
        val_acc = pl_module.knn_evaluator.score(representations, y)  # type: ignore[union-attr,operator]

        # log metrics
        pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from pl_bolts.utils.self_supervised import torchvision_ssl_encoder | ||
|
||
|
||
class MLP(nn.Module):
    """
    Two-layer perceptron: ``Linear -> BatchNorm1d -> ReLU -> Linear``.

    The first linear layer has no bias (it is followed directly by batch
    normalisation); the output layer has one.
    """

    def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None:
        """
        Args:
            input_dim: size of each input sample
            hidden_size: width of the hidden layer
            output_dim: size of each output sample
        """
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x`` and return the result."""
        return self.model(x)
|
||
|
||
class SiameseArm(nn.Module):
    """
    One arm of a siamese network: an encoder followed by a projector and a predictor MLP.
    """

    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        input_dim: int = 2048,
        hidden_size: int = 4096,
        output_dim: int = 256,
    ) -> None:
        """
        Args:
            encoder: backbone module; when ``None``, ``torchvision_ssl_encoder('resnet50')`` is used
            input_dim: size of the encoder features fed to the projector
            hidden_size: hidden width of both the projector and the predictor
            output_dim: output size of both the projector and the predictor
        """
        super().__init__()

        if encoder is None:
            encoder = torchvision_ssl_encoder('resnet50')
        # Encoder
        self.encoder = encoder
        # Projector
        self.projector = MLP(input_dim, hidden_size, output_dim)
        # Predictor
        self.predictor = MLP(output_dim, hidden_size, output_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            ``(y, z, h)``: the encoder features, their projection, and the
            prediction computed from the projection.
        """
        # the encoder's output is indexed; only its first element is used
        y = self.encoder(x)[0]
        z = self.projector(y)
        h = self.predictor(z)
        return y, z, h
Oops, something went wrong.