From 45a708b3a48aa4cc410ba7bad118c639d09a1ca6 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Wed, 25 Nov 2020 23:31:27 +0200 Subject: [PATCH 01/34] simsiam init imp --- .../self_supervised/simsiam/__init__.py | 0 .../models/self_supervised/simsiam/models.py | 39 ++++ .../self_supervised/simsiam/simsiam_module.py | 185 ++++++++++++++++++ 3 files changed, 224 insertions(+) create mode 100644 pl_bolts/models/self_supervised/simsiam/__init__.py create mode 100644 pl_bolts/models/self_supervised/simsiam/models.py create mode 100644 pl_bolts/models/self_supervised/simsiam/simsiam_module.py diff --git a/pl_bolts/models/self_supervised/simsiam/__init__.py b/pl_bolts/models/self_supervised/simsiam/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py new file mode 100644 index 0000000000..9e587d0455 --- /dev/null +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -0,0 +1,39 @@ +from torch import nn + +from pl_bolts.utils.self_supervised import torchvision_ssl_encoder + + +class MLP(nn.Module): + def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): + super().__init__() + self.output_dim = output_dim + self.input_dim = input_dim + self.model = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=False), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, output_dim, bias=True)) + + def forward(self, x): + x = self.model(x) + return x + + +class SiameseArm(nn.Module): + def __init__(self, encoder=None): + super().__init__() + + if encoder is None: + encoder = torchvision_ssl_encoder('resnet50') + # Encoder + self.encoder = encoder + # Projector + self.projector = MLP() + # Predictor + self.predictor = MLP(input_dim=256) + + def forward(self, x): + y = self.encoder(x)[0] + z = self.projector(y) + h = self.predictor(z) + return y, z, h diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py new file mode 100644 index 0000000000..d6bbb86bb1 --- /dev/null +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -0,0 +1,185 @@ +from argparse import ArgumentParser +from copy import deepcopy +from typing import Any + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from pytorch_lightning import seed_everything +from torch.optim import Adam + +from pl_bolts.models.self_supervised.simsiam.models import SiameseArm +from pl_bolts.optimizers.lars_scheduling import LARSWrapper +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + + +class SimSiam(pl.LightningModule): + def __init__(self, + num_classes, + learning_rate: float = 0.2, + weight_decay: float = 1.5e-6, + input_height: int = 32, + batch_size: int = 32, + num_workers: int = 0, + warmup_epochs: int = 10, + max_epochs: int = 1000, + **kwargs): + """ + Args: + datamodule: The datamodule + learning_rate: the learning rate + weight_decay: optimizer weight decay + input_height: image input height + batch_size: the batch size + num_workers: number of workers + warmup_epochs: num of epochs for scheduler warm up + max_epochs: max epochs for scheduler + """ + super().__init__() + self.save_hyperparameters() + + self.online_network = SiameseArm() + self._init_target_network() + + def _init_target_network(self): + self.target_network = deepcopy(self.online_network) + + def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) 
-> None: + self._init_target_network() + + def forward(self, x): + y, _, _ = self.online_network(x) + return y + + def cosine_similarity(self, a, b): + a = F.normalize(a, dim=-1) + b = F.normalize(b, dim=-1) + sim = (a * b).sum(-1).mean() + return sim + + def shared_step(self, batch, batch_idx): + (img_1, img_2, _), y = batch + + # Image 1 to image 2 loss + y1, z1, h1 = self.online_network(img_1) + with torch.no_grad(): + y2, z2, h2 = self.target_network(img_2) + loss_a = -1 * self.cosine_similarity(h1, z2) + + # Image 2 to image 1 loss + y1, z1, h1 = self.online_network(img_2) + with torch.no_grad(): + y2, z2, h2 = self.target_network(img_1) + # L2 normalize + loss_b = -1 * self.cosine_similarity(h1, z2) + + # Final loss + total_loss = loss_a + loss_b + + return loss_a, loss_b, total_loss + + def training_step(self, batch, batch_idx): + loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + + # log results + self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + + return total_loss + + def validation_step(self, batch, batch_idx): + loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + + # log results + self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + + return total_loss + + def configure_optimizers(self): + optimizer = Adam(self.online_network.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) + optimizer = LARSWrapper(optimizer) + scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + warmup_epochs=self.hparams.warmup_epochs, + max_epochs=self.hparams.max_epochs + ) + return [optimizer], [scheduler] + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--online_ft', action='store_true', help='run online finetuner') + parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10') + + (args, _) = parser.parse_known_args() + + # Data + parser.add_argument('--data_dir', type=str, default='.') + parser.add_argument('--num_workers', default=0, type=int) + + # optim + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--learning_rate', type=float, default=1e-3) + parser.add_argument('--weight_decay', type=float, default=1.5e-6) + parser.add_argument('--warmup_epochs', type=float, default=10) + + # Model + parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet') + + return parser + + +def cli_main(): + from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform + + seed_everything(1234) + + parser = ArgumentParser() + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = SimSiam.add_model_specific_args(parser) + args = parser.parse_args() + + # pick data + dm = None + + # init default datamodule + if args.dataset == 'cifar10': + dm = CIFAR10DataModule.from_argparse_args(args) + dm.train_transforms = SimCLRTrainDataTransform(32) + dm.val_transforms = SimCLREvalDataTransform(32) + args.num_classes = dm.num_classes + + elif args.dataset == 'stl10': + dm = STL10DataModule.from_argparse_args(args) + dm.train_dataloader = dm.train_dataloader_mixed + dm.val_dataloader = dm.val_dataloader_mixed + + (c, h, w) = dm.size() + dm.train_transforms = 
SimCLRTrainDataTransform(h) + dm.val_transforms = SimCLREvalDataTransform(h) + args.num_classes = dm.num_classes + + elif args.dataset == 'imagenet2012': + dm = ImagenetDataModule.from_argparse_args(args, image_size=196) + (c, h, w) = dm.size() + dm.train_transforms = SimCLRTrainDataTransform(h) + dm.val_transforms = SimCLREvalDataTransform(h) + args.num_classes = dm.num_classes + + model = SimSiam(**args.__dict__) + + # finetune in real-time + online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes) + + trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval]) + + trainer.fit(model, dm) + + +if __name__ == '__main__': + cli_main() From 42b322ff5b130e8037591a6cf7556262595931d1 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Wed, 25 Nov 2020 23:54:52 +0200 Subject: [PATCH 02/34] add doc --- .../self_supervised/simsiam/simsiam_module.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index d6bbb86bb1..15bd93335a 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -14,6 +14,51 @@ class SimSiam(pl.LightningModule): +""" + PyTorch Lightning implementation of `Exploring Simple Siamese Representation Learning (SimSiam) + `_ + + Paper authors: Xinlei Chen, Kaiming He. + + Model implemented by: + - `Zvi Lapp `_ + + .. warning:: Work in progress. This implementation is still being verified. + + TODOs: + - verify on CIFAR-10 + - verify on STL-10 + - pre-train on imagenet + + Example:: + + model = SimSiam(num_classes=10) + + dm = CIFAR10DataModule(num_workers=0) + dm.train_transforms = SimCLRTrainDataTransform(32) + dm.val_transforms = SimCLREvalDataTransform(32) + + trainer = pl.Trainer() + trainer.fit(model, dm) + + Train:: + + trainer = Trainer() + trainer.fit(model) + + CLI command:: + + # cifar10 + python simsiam_module.py --gpus 1 + + # imagenet + python simsiam_module.py + --gpus 8 + --dataset imagenet2012 + --data_dir /path/to/imagenet/ + --meta_dir /path/to/folder/with/meta.bin/ + --batch_size 32 + """ def __init__(self, num_classes, learning_rate: float = 0.2, From 790fe6098448aae9cf54371b7ca5f48a41e251ad Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Thu, 26 Nov 2020 02:43:48 +0200 Subject: [PATCH 03/34] fix indent --- pl_bolts/models/self_supervised/simsiam/simsiam_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 15bd93335a..7527e21c05 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -14,7 +14,7 @@ class SimSiam(pl.LightningModule): -""" + """ PyTorch Lightning implementation of `Exploring Simple Siamese Representation Learning (SimSiam) `_ From 58ffec24e8c9990badc67d3b276ce5a4716416e1 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Sat, 28 Nov 2020 18:26:28 +0200 Subject: [PATCH 04/34] black reformatted --- .../self_supervised/simsiam/simsiam_module.py | 97 +++++++++++++------ 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 7527e21c05..cff3477714 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ 
b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -59,16 +59,19 @@ class SimSiam(pl.LightningModule): --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32 """ - def __init__(self, - num_classes, - learning_rate: float = 0.2, - weight_decay: float = 1.5e-6, - input_height: int = 32, - batch_size: int = 32, - num_workers: int = 0, - warmup_epochs: int = 10, - max_epochs: int = 1000, - **kwargs): + + def __init__( + self, + num_classes, + learning_rate: float = 0.2, + weight_decay: float = 1.5e-6, + input_height: int = 32, + batch_size: int = 32, + num_workers: int = 0, + warmup_epochs: int = 10, + max_epochs: int = 1000, + **kwargs + ): """ Args: datamodule: The datamodule @@ -89,7 +92,9 @@ def __init__(self, def _init_target_network(self): self.target_network = deepcopy(self.online_network) - def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end( + self, outputs, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: self._init_target_network() def forward(self, x): @@ -127,7 +132,9 @@ def training_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + self.log_dict( + {"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss} + ) return total_loss @@ -135,48 +142,70 @@ def validation_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + self.log_dict( + {"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss} + ) return total_loss def configure_optimizers(self): - optimizer = Adam(self.online_network.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) + optimizer = Adam( + self.online_network.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) optimizer = LARSWrapper(optimizer) scheduler = LinearWarmupCosineAnnealingLR( optimizer, warmup_epochs=self.hparams.warmup_epochs, - max_epochs=self.hparams.max_epochs + max_epochs=self.hparams.max_epochs, ) return [optimizer], [scheduler] @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument('--online_ft', action='store_true', help='run online finetuner') - parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10') + parser.add_argument( + "--online_ft", action="store_true", help="run online finetuner" + ) + parser.add_argument( + "--dataset", + type=str, + default="cifar10", + help="cifar10, imagenet2012, stl10", + ) (args, _) = parser.parse_known_args() # Data - parser.add_argument('--data_dir', type=str, default='.') - parser.add_argument('--num_workers', default=0, type=int) + parser.add_argument("--data_dir", type=str, default=".") + parser.add_argument("--num_workers", default=0, type=int) # optim - parser.add_argument('--batch_size', type=int, default=256) - parser.add_argument('--learning_rate', type=float, default=1e-3) - parser.add_argument('--weight_decay', type=float, default=1.5e-6) - parser.add_argument('--warmup_epochs', type=float, default=10) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--learning_rate", type=float, default=1e-3) + parser.add_argument("--weight_decay", type=float, default=1.5e-6) + 
parser.add_argument("--warmup_epochs", type=float, default=10) # Model - parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet') + parser.add_argument( + "--meta_dir", default=".", type=str, help="path to meta.bin for imagenet" + ) return parser def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) + from pl_bolts.models.self_supervised.simclr import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, + ) seed_everything(1234) @@ -193,13 +222,13 @@ def cli_main(): dm = None # init default datamodule - if args.dataset == 'cifar10': + if args.dataset == "cifar10": dm = CIFAR10DataModule.from_argparse_args(args) dm.train_transforms = SimCLRTrainDataTransform(32) dm.val_transforms = SimCLREvalDataTransform(32) args.num_classes = dm.num_classes - elif args.dataset == 'stl10': + elif args.dataset == "stl10": dm = STL10DataModule.from_argparse_args(args) dm.train_dataloader = dm.train_dataloader_mixed dm.val_dataloader = dm.val_dataloader_mixed @@ -209,7 +238,7 @@ def cli_main(): dm.val_transforms = SimCLREvalDataTransform(h) args.num_classes = dm.num_classes - elif args.dataset == 'imagenet2012': + elif args.dataset == "imagenet2012": dm = ImagenetDataModule.from_argparse_args(args, image_size=196) (c, h, w) = dm.size() dm.train_transforms = SimCLRTrainDataTransform(h) @@ -219,12 +248,16 @@ def cli_main(): model = SimSiam(**args.__dict__) # finetune in real-time - online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes) + online_eval = SSLOnlineEvaluator( + dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes + ) - trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval]) + trainer = pl.Trainer.from_argparse_args( + args, max_steps=300000, callbacks=[online_eval] + ) trainer.fit(model, dm) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() From 6abbfed17a041a15fc6ca9c3e999f4a07e838afd Mon Sep 17 00:00:00 2001 From: zlapp <43241560+zlapp@users.noreply.github.com> Date: Sun, 29 Nov 2020 15:55:38 +0200 Subject: [PATCH 05/34] No grad fixes, detach in sim calc --- .../self_supervised/simsiam/simsiam_module.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index cff3477714..477e9a466b 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -102,6 +102,7 @@ def forward(self, x): return y def cosine_similarity(self, a, b): + b = b.detach() # stop gradient of backbone + projection mlp a = F.normalize(a, dim=-1) b = F.normalize(b, dim=-1) sim = (a * b).sum(-1).mean() @@ -111,20 +112,17 @@ def shared_step(self, batch, batch_idx): (img_1, img_2, _), y = batch # Image 1 to image 2 loss - y1, z1, h1 = self.online_network(img_1) - with torch.no_grad(): - y2, z2, h2 = self.target_network(img_2) - loss_a = -1 * self.cosine_similarity(h1, z2) + _, z1, h1 = self.online_network(img_1) + _, z2, h2 = self.target_network(img_2) + loss_a = -1.0 * self.cosine_similarity(h1, z2) # Image 2 to image 1 loss - y1, z1, h1 = 
self.online_network(img_2) - with torch.no_grad(): - y2, z2, h2 = self.target_network(img_1) - # L2 normalize - loss_b = -1 * self.cosine_similarity(h1, z2) + _, z1, h1 = self.online_network(img_2) + _, z2, h2 = self.target_network(img_1) + loss_b = -1.0 * self.cosine_similarity(h1, z2) # Final loss - total_loss = loss_a + loss_b + total_loss = loss_a / 2.0 + loss_b / 2.0 return loss_a, loss_b, total_loss From 7d6737bfe2e27a8934a7e288b39e657e268b17b2 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 30 Nov 2020 20:33:20 +0200 Subject: [PATCH 06/34] adjusted loss factor -2 --- pl_bolts/models/self_supervised/simsiam/simsiam_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 477e9a466b..e63926cfc3 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -114,15 +114,15 @@ def shared_step(self, batch, batch_idx): # Image 1 to image 2 loss _, z1, h1 = self.online_network(img_1) _, z2, h2 = self.target_network(img_2) - loss_a = -1.0 * self.cosine_similarity(h1, z2) + loss_a = -2.0 * self.cosine_similarity(h1, z2) # Image 2 to image 1 loss _, z1, h1 = self.online_network(img_2) _, z2, h2 = self.target_network(img_1) - loss_b = -1.0 * self.cosine_similarity(h1, z2) + loss_b = -2.0 * self.cosine_similarity(h1, z2) # Final loss - total_loss = loss_a / 2.0 + loss_b / 2.0 + total_loss = loss_a + loss_b return loss_a, loss_b, total_loss From 0e18985713e449c3b74d09a8c6fdbceb215bee2d Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 7 Dec 2020 08:28:53 +0200 Subject: [PATCH 07/34] init dm similar to simclr implementation, revert loss to paper imp --- .../self_supervised/simsiam/simsiam_module.py | 213 +++++++++++++----- 1 file changed, 151 insertions(+), 62 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index e63926cfc3..a7d14a834c 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -12,6 +12,12 @@ from pl_bolts.optimizers.lars_scheduling import LARSWrapper from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.transforms.dataset_normalizations import ( + cifar10_normalization, + imagenet_normalization, + stl10_normalization, +) + class SimSiam(pl.LightningModule): """ @@ -32,7 +38,7 @@ class SimSiam(pl.LightningModule): Example:: - model = SimSiam(num_classes=10) + model = SimSiam() dm = CIFAR10DataModule(num_workers=0) dm.train_transforms = SimCLRTrainDataTransform(32) @@ -62,7 +68,6 @@ class SimSiam(pl.LightningModule): def __init__( self, - num_classes, learning_rate: float = 0.2, weight_decay: float = 1.5e-6, input_height: int = 32, @@ -92,9 +97,7 @@ def __init__( def _init_target_network(self): self.target_network = deepcopy(self.online_network) - def on_train_batch_end( - self, outputs, batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: + def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: self._init_target_network() def forward(self, x): @@ -114,15 +117,15 @@ def shared_step(self, batch, batch_idx): # Image 1 to image 2 loss _, z1, h1 = self.online_network(img_1) _, z2, h2 = self.target_network(img_2) - loss_a = -2.0 * self.cosine_similarity(h1, z2) + loss_a = -1.0 * self.cosine_similarity(h1, 
z2) # Image 2 to image 1 loss _, z1, h1 = self.online_network(img_2) _, z2, h2 = self.target_network(img_1) - loss_b = -2.0 * self.cosine_similarity(h1, z2) + loss_b = -1.0 * self.cosine_similarity(h1, z2) # Final loss - total_loss = loss_a + loss_b + total_loss = loss_a / 2.0 + loss_b / 2.0 return loss_a, loss_b, total_loss @@ -130,9 +133,7 @@ def training_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - self.log_dict( - {"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss} - ) + self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss}) return total_loss @@ -140,9 +141,7 @@ def validation_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - self.log_dict( - {"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss} - ) + self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss}) return total_loss @@ -154,41 +153,59 @@ def configure_optimizers(self): ) optimizer = LARSWrapper(optimizer) scheduler = LinearWarmupCosineAnnealingLR( - optimizer, - warmup_epochs=self.hparams.warmup_epochs, - max_epochs=self.hparams.max_epochs, + optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs, ) return [optimizer], [scheduler] @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) + # model params + parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") + # specify flags to store false + parser.add_argument("--first_conv", action="store_false") + parser.add_argument("--maxpool1", action="store_false") parser.add_argument( - "--online_ft", action="store_true", help="run online finetuner" + "--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head" ) + parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") + parser.add_argument("--online_ft", action="store_true") + parser.add_argument("--fp32", action="store_true") + + # transform params + parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") + parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") + parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10") + parser.add_argument("--data_dir", type=str, default=".", help="path to download data") + + # training params + parser.add_argument("--fast_dev_run", action="store_true") + parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training") + parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") + parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") parser.add_argument( - "--dataset", - type=str, - default="cifar10", - help="cifar10, imagenet2012, stl10", + "--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used" ) + parser.add_argument( + "--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay" + ) + parser.add_argument( + "--max_epochs", default=100, type=int, help="number of total epochs to run" + ) + parser.add_argument("--max_steps", default=-1, type=int, help="max steps") + parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") + 
parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") - (args, _) = parser.parse_known_args() - - # Data - parser.add_argument("--data_dir", type=str, default=".") - parser.add_argument("--num_workers", default=0, type=int) - - # optim - parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--learning_rate", type=float, default=1e-3) - parser.add_argument("--weight_decay", type=float, default=1.5e-6) - parser.add_argument("--warmup_epochs", type=float, default=10) - - # Model parser.add_argument( - "--meta_dir", default=".", type=str, help="path to meta.bin for imagenet" + "--temperature", default=0.1, type=float, help="temperature parameter in training loss" + ) + parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") + parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") + parser.add_argument( + "--start_lr", default=0, type=float, help="initial warmup learning rate" ) + parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") return parser @@ -209,9 +226,6 @@ def cli_main(): parser = ArgumentParser() - # trainer args - parser = pl.Trainer.add_argparse_args(parser) - # model args parser = SimSiam.add_model_specific_args(parser) args = parser.parse_args() @@ -219,39 +233,114 @@ def cli_main(): # pick data dm = None - # init default datamodule - if args.dataset == "cifar10": - dm = CIFAR10DataModule.from_argparse_args(args) - dm.train_transforms = SimCLRTrainDataTransform(32) - dm.val_transforms = SimCLREvalDataTransform(32) - args.num_classes = dm.num_classes + # init datamodule + if args.dataset == "stl10": + dm = STL10DataModule( + data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers + ) - elif args.dataset == "stl10": - dm = STL10DataModule.from_argparse_args(args) dm.train_dataloader = dm.train_dataloader_mixed dm.val_dataloader = dm.val_dataloader_mixed + args.num_samples = dm.num_unlabeled_samples + + args.maxpool1 = False + args.first_conv = True + args.input_height = dm.size()[-1] + + normalization = stl10_normalization() + + args.gaussian_blur = True + args.jitter_strength = 1.0 + elif args.dataset == "cifar10": + val_split = 5000 + if args.nodes * args.gpus * args.batch_size > val_split: + val_split = args.nodes * args.gpus * args.batch_size + + dm = CIFAR10DataModule( + data_dir=args.data_dir, + batch_size=args.batch_size, + num_workers=args.num_workers, + val_split=val_split, + ) + + args.num_samples = dm.num_samples + + args.maxpool1 = False + args.first_conv = False + args.input_height = dm.size()[-1] + args.temperature = 0.5 + + normalization = cifar10_normalization() - (c, h, w) = dm.size() - dm.train_transforms = SimCLRTrainDataTransform(h) - dm.val_transforms = SimCLREvalDataTransform(h) - args.num_classes = dm.num_classes + args.gaussian_blur = False + args.jitter_strength = 0.5 + elif args.dataset == "imagenet": + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() - elif args.dataset == "imagenet2012": - dm = ImagenetDataModule.from_argparse_args(args, image_size=196) - (c, h, w) = dm.size() - dm.train_transforms = SimCLRTrainDataTransform(h) - dm.val_transforms = SimCLREvalDataTransform(h) - args.num_classes = dm.num_classes + args.gaussian_blur = True + args.jitter_strength = 1.0 + + args.batch_size = 64 + args.nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = "sgd" + args.lars_wrapper = True + args.learning_rate 
= 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + args.online_ft = True + + dm = ImagenetDataModule( + data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers + ) + + args.num_samples = dm.num_samples + args.input_height = dm.size()[-1] + else: + raise NotImplementedError("other datasets have not been implemented till now") + + dm.train_transforms = SimCLRTrainDataTransform( + input_height=args.input_height, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + normalize=normalization, + ) + + dm.val_transforms = SimCLREvalDataTransform( + input_height=args.input_height, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + normalize=normalization, + ) model = SimSiam(**args.__dict__) # finetune in real-time - online_eval = SSLOnlineEvaluator( - dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes - ) + online_evaluator = None + if args.online_ft: + # online eval + online_evaluator = SSLOnlineEvaluator( + drop_p=0.0, + hidden_dim=None, + z_dim=args.hidden_mlp, + num_classes=dm.num_classes, + dataset=args.dataset, + ) - trainer = pl.Trainer.from_argparse_args( - args, max_steps=300000, callbacks=[online_eval] + trainer = pl.Trainer( + max_epochs=args.max_epochs, + max_steps=None if args.max_steps == -1 else args.max_steps, + gpus=args.gpus, + num_nodes=args.nodes, + distributed_backend="ddp" if args.gpus > 1 else None, + sync_batchnorm=True if args.gpus > 1 else False, + precision=32 if args.fp32 else 16, + callbacks=[online_evaluator] if args.online_ft else None, + fast_dev_run=args.fast_dev_run, ) trainer.fit(model, dm) From 467ac4db3e91f8e1aefa0e230e7cbbac5cff28b7 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 7 Dec 2020 08:47:12 +0200 Subject: [PATCH 08/34] further simclr adjustment --- .../self_supervised/simsiam/simsiam_module.py | 160 ++++++++++++++++-- 1 file changed, 143 insertions(+), 17 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index a7d14a834c..de551daa03 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -1,12 +1,16 @@ +import math from argparse import ArgumentParser from copy import deepcopy from typing import Any +from typing import Callable, Optional import pytorch_lightning as pl import torch +import numpy as np import torch.nn.functional as F +from pytorch_lightning.utilities import AMPType from pytorch_lightning import seed_everything -from torch.optim import Adam +from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.simsiam.models import SiameseArm from pl_bolts.optimizers.lars_scheduling import LARSWrapper @@ -68,13 +72,26 @@ class SimSiam(pl.LightningModule): def __init__( self, - learning_rate: float = 0.2, - weight_decay: float = 1.5e-6, - input_height: int = 32, - batch_size: int = 32, - num_workers: int = 0, + gpus: int, + num_samples: int, + batch_size: int, + dataset: str, + nodes: int = 1, + arch: str = 'resnet50', + hidden_mlp: int = 2048, + feat_dim: int = 128, warmup_epochs: int = 10, - max_epochs: int = 1000, + max_epochs: int = 100, + temperature: float = 0.1, + first_conv: bool = True, + maxpool1: bool = True, + optimizer: str = 'adam', + lars_wrapper: bool = True, + exclude_bn_bias: bool = False, + start_lr: float = 0., + learning_rate: float = 1e-3, + final_lr: float = 0., + weight_decay: float = 1e-6, **kwargs ): """ @@ -91,9 +108,48 @@ def __init__( 
super().__init__() self.save_hyperparameters() + self.gpus = gpus + self.nodes = nodes + self.arch = arch + self.dataset = dataset + self.num_samples = num_samples + self.batch_size = batch_size + + self.hidden_mlp = hidden_mlp + self.feat_dim = feat_dim + self.first_conv = first_conv + self.maxpool1 = maxpool1 + + self.optim = optimizer + self.lars_wrapper = lars_wrapper + self.exclude_bn_bias = exclude_bn_bias + self.weight_decay = weight_decay + self.temperature = temperature + + self.start_lr = start_lr + self.final_lr = final_lr + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.online_network = SiameseArm() self._init_target_network() + # compute iters per epoch + global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + self.train_iters_per_epoch = self.num_samples // global_batch_size + + # define LR schedule + warmup_lr_schedule = np.linspace( + self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs + ) + iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) + cosine_lr_schedule = np.array([self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * ( + 1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))) + ) for t in iters]) + + self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + def _init_target_network(self): self.target_network = deepcopy(self.online_network) @@ -145,17 +201,87 @@ def validation_step(self, batch, batch_idx): return total_loss + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + elif any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [ + {'params': params, 'weight_decay': weight_decay}, + {'params': excluded_params, 'weight_decay': 0.} + ] + def configure_optimizers(self): - optimizer = Adam( - self.online_network.parameters(), - lr=self.hparams.learning_rate, - weight_decay=self.hparams.weight_decay, - ) - optimizer = LARSWrapper(optimizer) - scheduler = LinearWarmupCosineAnnealingLR( - optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs, - ) - return [optimizer], [scheduler] + if self.exclude_bn_bias: + params = self.exclude_from_wt_decay( + self.named_parameters(), + weight_decay=self.weight_decay + ) + else: + params = self.parameters() + + if self.optim == 'sgd': + optimizer = torch.optim.SGD( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay + ) + elif self.optim == 'adam': + optimizer = torch.optim.Adam( + params, + lr=self.learning_rate, + weight_decay=self.weight_decay + ) + + if self.lars_wrapper: + optimizer = LARSWrapper( + optimizer, + eta=0.001, # trust coefficient + clip=False + ) + + return optimizer + + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer: Optimizer, + optimizer_idx: int, + optimizer_closure: Optional[Callable] = None, + on_tpu: bool = False, + using_native_amp: bool = False, + using_lbfgs: bool = False, + ) -> None: + # warm-up + decay schedule placed here since LARSWrapper is not optimizer class + # adjust LR of optim contained within LARSWrapper + if self.lars_wrapper: + for param_group in optimizer.optim.param_groups: + param_group["lr"] = 
self.lr_schedule[self.trainer.global_step] + else: + for param_group in optimizer.param_groups: + param_group["lr"] = self.lr_schedule[self.trainer.global_step] + + # log LR (LearningRateLogger callback doesn't work with LARSWrapper) + self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) + + # from lightning + if self.trainer.amp_backend == AMPType.NATIVE: + optimizer_closure() + self.trainer.scaler.step(optimizer) + elif self.trainer.amp_backend == AMPType.APEX: + optimizer_closure() + optimizer.step() + else: + optimizer.step(closure=optimizer_closure) @staticmethod def add_model_specific_args(parent_parser): From 3fb3c48cedba97aa8203c2d30cea8a4b21c29a86 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 7 Dec 2020 09:15:36 +0200 Subject: [PATCH 09/34] support resnet18 backbone --- .../models/self_supervised/simsiam/models.py | 6 ++--- .../self_supervised/simsiam/simsiam_module.py | 22 ++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index 9e587d0455..14c55012a9 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -20,7 +20,7 @@ def forward(self, x): class SiameseArm(nn.Module): - def __init__(self, encoder=None): + def __init__(self, encoder=None, input_dim=2048, hidden_size=4096, output_dim=256): super().__init__() if encoder is None: @@ -28,9 +28,9 @@ def __init__(self, encoder=None): # Encoder self.encoder = encoder # Projector - self.projector = MLP() + self.projector = MLP(input_dim, hidden_size, output_dim) # Predictor - self.predictor = MLP(input_dim=256) + self.predictor = MLP(output_dim, hidden_size, output_dim) def forward(self, x): y = self.encoder(x)[0] diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index de551daa03..fc3e0ee758 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -13,6 +13,7 @@ from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.simsiam.models import SiameseArm +from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.optimizers.lars_scheduling import LARSWrapper from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR @@ -132,9 +133,8 @@ def __init__( self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs - self.online_network = SiameseArm() - self._init_target_network() - + self.init_model() + # compute iters per epoch global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size @@ -150,11 +150,23 @@ def __init__( self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) - def _init_target_network(self): + def init_model(self): + if self.arch == 'resnet18': + backbone = resnet18 + elif self.arch == 'resnet50': + backbone = resnet50 + + encoder =backbone( + first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False + ) + self.online_network = SiameseArm(encoder,input_dim=self.hidden_mlp, hidden_size=self.hidden_mlp, output_dim=self.feat_dim) + self.init_target_network() + + def init_target_network(self): self.target_network = deepcopy(self.online_network) def on_train_batch_end(self, outputs, batch: Any, batch_idx: 
int, dataloader_idx: int) -> None: - self._init_target_network() + self.init_target_network() def forward(self, x): y, _, _ = self.online_network(x) From 28cfd69eb347021091cb7f8d0d3c9a1ff7d997bb Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 7 Dec 2020 11:45:29 +0200 Subject: [PATCH 10/34] knn online callback --- pl_bolts/callbacks/knn_online.py | 90 ++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 pl_bolts/callbacks/knn_online.py diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py new file mode 100644 index 0000000000..41a50909d1 --- /dev/null +++ b/pl_bolts/callbacks/knn_online.py @@ -0,0 +1,90 @@ +from typing import Optional + +import torch +from pytorch_lightning import Callback + +from sklearn.neighbors import KNeighborsClassifier + +class KNNOnlineEvaluator(Callback): # pragma: no-cover + """ + Evaluates self-supervised K nearest neighbors. + + Example:: + + # your model must have 1 attribute + model = Model() + model.num_classes = ... # the num of classes in the model + + online_eval = KNNOnlineEvaluator( + num_classes=model.num_classes, + dataset='imagenet' + ) + + """ + def __init__( + self, + dataset: str, + num_classes: int = None, + ): + """ + Args: + dataset: if stl10, need to get the labeled batch + num_classes: Number of classes + """ + super().__init__() + + self.num_classes = num_classes + self.dataset = dataset + + def on_pretrain_routine_start(self, trainer, pl_module): + pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes) + + + def get_representations(self, pl_module, x): + representations = pl_module(x) + representations = representations.reshape(representations.size(0), -1) + return representations + + def to_device(self, batch, device): + # get the labeled batch + if self.dataset == 'stl10': + labeled_batch = batch[1] + batch = labeled_batch + + inputs, y = batch + + # last input is for online eval + x = inputs[-1] + x = x.to(device) + y = y.to(device) + + return x, y + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + x, y = self.to_device(batch, pl_module.device) + + with torch.no_grad(): + representations = self.get_representations(pl_module, x) + + representations = representations.detach() + + # knn fit + pl_module.knn_evaluator.fit(representations, y) + train_acc = pl_module.knn_evaluator.score(representations, y) + + # log metrics + pl_module.log('online_knn_train_acc', train_acc, on_step=True, on_epoch=False) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + x, y = self.to_device(batch, pl_module.device) + + with torch.no_grad(): + representations = self.get_representations(pl_module, x) + + representations = representations.detach() + + # train knn + val_acc = pl_module.knn_evaluator.score(representations, y) + + # log metrics + pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) From c0e1f3569d47828cbd9903da19d60a6fe6ab6c52 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 7 Dec 2020 12:34:11 +0200 Subject: [PATCH 11/34] fit and eval knn on val epoch end --- pl_bolts/callbacks/knn_online.py | 48 ++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 41a50909d1..fc475cc52e 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -36,15 +36,34 @@ def __init__( self.num_classes = 
num_classes self.dataset = dataset - def on_pretrain_routine_start(self, trainer, pl_module): - pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes) - def get_representations(self, pl_module, x): representations = pl_module(x) representations = representations.reshape(representations.size(0), -1) return representations + def get_all_representations(self, pl_module, dataloader): + all_representations = None + ys = None + + for batch in dataloader: + x, y = self.to_device(batch, pl_module.device) + + with torch.no_grad(): + representations = self.get_representations(pl_module, x) + + if all_representations is None: + all_representations = representations.detach() + else: + all_representations = torch.cat([all_representations,representations]) + + if ys is None: + ys = y + else: + ys = torch.cat([ys,y]) + + return all_representations.numpy(), ys.numpy() + def to_device(self, batch, device): # get the labeled batch if self.dataset == 'stl10': @@ -60,31 +79,24 @@ def to_device(self, batch, device): return x, y - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - x, y = self.to_device(batch, pl_module.device) - - with torch.no_grad(): - representations = self.get_representations(pl_module, x) + def on_validation_epoch_end(self, trainer, pl_module): + pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes) - representations = representations.detach() + train_dataloader = pl_module.train_dataloader() + representations, y = self.get_all_representations(pl_module, train_dataloader) # knn fit pl_module.knn_evaluator.fit(representations, y) train_acc = pl_module.knn_evaluator.score(representations, y) # log metrics - pl_module.log('online_knn_train_acc', train_acc, on_step=True, on_epoch=False) - - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - x, y = self.to_device(batch, pl_module.device) - - with torch.no_grad(): - representations = self.get_representations(pl_module, x) - representations = representations.detach() + val_dataloader = pl_module.val_dataloader() + representations, y = self.get_all_representations(pl_module, val_dataloader) - # train knn + # knn val acc val_acc = pl_module.knn_evaluator.score(representations, y) # log metrics + pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True) pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) From ade7ad5a6b517fd72b39ad840581edccde774c7d Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Sun, 13 Dec 2020 10:19:38 +0200 Subject: [PATCH 12/34] simsiam tests --- tests/models/self_supervised/test_models.py | 16 +++++++++++++++- tests/models/self_supervised/test_scripts.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 6770f83d61..e12f0eee79 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -4,7 +4,7 @@ from pytorch_lightning import seed_everything from pl_bolts.datamodules import CIFAR10DataModule -from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SwAV +from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SwAV, SimSiam from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10 from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler from 
pl_bolts.models.self_supervised.moco.transforms import Moco2EvalCIFAR10Transforms, Moco2TrainCIFAR10Transforms @@ -125,3 +125,17 @@ def test_swav(tmpdir, datadir): loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0 + +def test_simsiam(tmpdir, datadir): + seed_everything() + + datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2) + datamodule.train_transforms = SimCLRTrainDataTransform(32) + datamodule.val_transforms = SimCLREvalDataTransform(32) + + model = SimSiam(batch_size=2, num_samples=datamodule.num_samples, gpus=0, nodes=1, dataset='cifar10') + trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) + trainer.fit(model, datamodule) + loss = trainer.progress_bar_dict['loss'] + + assert float(loss) < 0 diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index e7cf0b19f4..9f64b826ad 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -86,3 +86,15 @@ def test_cli_run_self_supervised_swav(cli_args): cli_args = cli_args.split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() + + +@pytest.mark.parametrize('cli_args', [ + f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2 --online_ft' +]) +def test_cli_run_self_supervised_simsiam(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.self_supervised.simsiam.simsiam_module import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() \ No newline at end of file From 6cd4985ea4eb0c860ba785cba0384afacc3c2981 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Sun, 13 Dec 2020 10:34:58 +0200 Subject: [PATCH 13/34] added import --- pl_bolts/models/self_supervised/__init__.py | 1 + tests/models/self_supervised/test_models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pl_bolts/models/self_supervised/__init__.py b/pl_bolts/models/self_supervised/__init__.py index 0960e8e297..11e61d1b10 100644 --- a/pl_bolts/models/self_supervised/__init__.py +++ b/pl_bolts/models/self_supervised/__init__.py @@ -26,6 +26,7 @@ from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401 from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401 from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401 +from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401 __all__ = [ "AMDIM", diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index e12f0eee79..47b28bdc7a 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -139,3 +139,4 @@ def test_simsiam(tmpdir, datadir): loss = trainer.progress_bar_dict['loss'] assert float(loss) < 0 + From 9a091a83200e4db3df776424799bc971781561d9 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 18:50:34 +0200 Subject: [PATCH 14/34] gpus 0 --- tests/models/self_supervised/test_models.py | 2 +- tests/models/self_supervised/test_scripts.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 47b28bdc7a..ac6ce41fdd 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -134,7 +134,7 
@@ def test_simsiam(tmpdir, datadir): datamodule.val_transforms = SimCLREvalDataTransform(32) model = SimSiam(batch_size=2, num_samples=datamodule.num_samples, gpus=0, nodes=1, dataset='cifar10') - trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) + trainer = pl.Trainer(gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) trainer.fit(model, datamodule) loss = trainer.progress_bar_dict['loss'] diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index 9f64b826ad..4aeaa694ad 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -89,7 +89,7 @@ def test_cli_run_self_supervised_swav(cli_args): @pytest.mark.parametrize('cli_args', [ - f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2 --online_ft' + f'--data_dir {DATASETS_PATH} --gpus 0 --max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2 --online_ft' ]) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" From 3b8dada3327a79719e7f6402bf266e5faf512a36 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 19:05:39 +0200 Subject: [PATCH 15/34] scikit-learn req --- requirements/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/test.txt b/requirements/test.txt index 7554b5c88a..de9e5b9514 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -13,3 +13,5 @@ mypy yapf atari-py==0.2.6 # needed for RL + +scikit-learn>=0.20.0 \ No newline at end of file From 5ab5f4bd0fc59c431fcd420662e405e3e102a713 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 19:05:50 +0200 Subject: [PATCH 16/34] flake8 --- pl_bolts/callbacks/knn_online.py | 11 ++++++----- .../models/self_supervised/simsiam/simsiam_module.py | 9 +++++---- tests/models/self_supervised/test_models.py | 2 +- tests/models/self_supervised/test_scripts.py | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index fc475cc52e..aeb0255267 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -5,6 +5,7 @@ from sklearn.neighbors import KNeighborsClassifier + class KNNOnlineEvaluator(Callback): # pragma: no-cover """ Evaluates self-supervised K nearest neighbors. 
@@ -21,6 +22,7 @@ class KNNOnlineEvaluator(Callback): # pragma: no-cover ) """ + def __init__( self, dataset: str, @@ -36,7 +38,6 @@ def __init__( self.num_classes = num_classes self.dataset = dataset - def get_representations(self, pl_module, x): representations = pl_module(x) representations = representations.reshape(representations.size(0), -1) @@ -51,16 +52,16 @@ def get_all_representations(self, pl_module, dataloader): with torch.no_grad(): representations = self.get_representations(pl_module, x) - + if all_representations is None: all_representations = representations.detach() else: - all_representations = torch.cat([all_representations,representations]) + all_representations = torch.cat([all_representations, representations]) if ys is None: ys = y else: - ys = torch.cat([ys,y]) + ys = torch.cat([ys, y]) return all_representations.numpy(), ys.numpy() @@ -96,7 +97,7 @@ def on_validation_epoch_end(self, trainer, pl_module): # knn val acc val_acc = pl_module.knn_evaluator.score(representations, y) - + # log metrics pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True) pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index fc3e0ee758..9a45f6f3e3 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -134,7 +134,7 @@ def __init__( self.max_epochs = max_epochs self.init_model() - + # compute iters per epoch global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size @@ -155,11 +155,12 @@ def init_model(self): backbone = resnet18 elif self.arch == 'resnet50': backbone = resnet50 - - encoder =backbone( + + encoder = backbone( first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False ) - self.online_network = SiameseArm(encoder,input_dim=self.hidden_mlp, hidden_size=self.hidden_mlp, output_dim=self.feat_dim) + self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, + hidden_size=self.hidden_mlp, output_dim=self.feat_dim) self.init_target_network() def init_target_network(self): diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index ac6ce41fdd..3911ab2253 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -126,6 +126,7 @@ def test_swav(tmpdir, datadir): assert float(loss) > 0 + def test_simsiam(tmpdir, datadir): seed_everything() @@ -139,4 +140,3 @@ def test_simsiam(tmpdir, datadir): loss = trainer.progress_bar_dict['loss'] assert float(loss) < 0 - diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index 4aeaa694ad..a0e58e8276 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -97,4 +97,4 @@ def test_cli_run_self_supervised_simsiam(cli_args): cli_args = cli_args.split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - cli_main() \ No newline at end of file + cli_main() From 81fe48e63281d0f2482a7c176a8e673215f4d549 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 19:09:39 +0200 Subject: [PATCH 17/34] scikit-learn bump version 0.23 --- requirements/test.txt | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index de9e5b9514..c8257bd5e3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,4 +14,4 @@ yapf atari-py==0.2.6 # needed for RL -scikit-learn>=0.20.0 \ No newline at end of file +scikit-learn>=0.23 \ No newline at end of file From 2ceb2eb0907c615c954172cf1ce38b1a835fa93a Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 19:24:30 +0200 Subject: [PATCH 18/34] isort --- pl_bolts/callbacks/knn_online.py | 1 - pl_bolts/models/self_supervised/__init__.py | 2 +- .../self_supervised/simsiam/simsiam_module.py | 17 ++++++----------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index aeb0255267..82077acab7 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -2,7 +2,6 @@ import torch from pytorch_lightning import Callback - from sklearn.neighbors import KNeighborsClassifier diff --git a/pl_bolts/models/self_supervised/__init__.py b/pl_bolts/models/self_supervised/__init__.py index 11e61d1b10..c05ee1030f 100644 --- a/pl_bolts/models/self_supervised/__init__.py +++ b/pl_bolts/models/self_supervised/__init__.py @@ -24,9 +24,9 @@ from pl_bolts.models.self_supervised.evaluator import SSLEvaluator # noqa: F401 from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2 # noqa: F401 from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401 +from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401 from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401 from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401 -from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401 __all__ = [ "AMDIM", diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 9a45f6f3e3..70742d86be 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -1,27 +1,22 @@ import math from argparse import ArgumentParser from copy import deepcopy -from typing import Any -from typing import Callable, Optional +from typing import Any, Callable, Optional +import numpy as np import pytorch_lightning as pl import torch -import numpy as np import torch.nn.functional as F -from pytorch_lightning.utilities import AMPType from pytorch_lightning import seed_everything +from pytorch_lightning.utilities import AMPType from torch.optim.optimizer import Optimizer -from pl_bolts.models.self_supervised.simsiam.models import SiameseArm from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 +from pl_bolts.models.self_supervised.simsiam.models import SiameseArm from pl_bolts.optimizers.lars_scheduling import LARSWrapper from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR - -from pl_bolts.transforms.dataset_normalizations import ( - cifar10_normalization, - imagenet_normalization, - stl10_normalization, -) +from pl_bolts.transforms.dataset_normalizations import (cifar10_normalization, imagenet_normalization, + stl10_normalization) class SimSiam(pl.LightningModule): From 4b56fcd4963e75a2d0a7e9b9a066f4b6bf17ead5 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 14 Dec 2020 19:29:04 +0200 Subject: [PATCH 19/34] isort --- tests/models/self_supervised/test_models.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 3911ab2253..edaa6f6902 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -4,7 +4,7 @@ from pytorch_lightning import seed_everything from pl_bolts.datamodules import CIFAR10DataModule -from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SwAV, SimSiam +from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SimSiam, SwAV from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10 from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler from pl_bolts.models.self_supervised.moco.transforms import Moco2EvalCIFAR10Transforms, Moco2TrainCIFAR10Transforms From 6c63f0b9e7b4a3da674b2fc83a36c7d376dd1855 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Mon, 21 Dec 2020 15:50:47 +0200 Subject: [PATCH 20/34] fix detatch --- pl_bolts/callbacks/knn_online.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 82077acab7..d290ec61ef 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -38,7 +38,8 @@ def __init__( self.dataset = dataset def get_representations(self, pl_module, x): - representations = pl_module(x) + with torch.no_grad(): + representations = pl_module(x) representations = representations.reshape(representations.size(0), -1) return representations @@ -55,14 +56,14 @@ def get_all_representations(self, pl_module, dataloader): if all_representations is None: all_representations = representations.detach() else: - all_representations = torch.cat([all_representations, representations]) + all_representations = torch.cat([all_representations, representations.detach()]) if ys is None: ys = y else: ys = torch.cat([ys, y]) - return all_representations.numpy(), ys.numpy() + return all_representations.cpu().numpy(), ys.cpu().numpy() def to_device(self, batch, device): # get the labeled batch From 696f27112e160c66ce633b21e8403ff21779dfa9 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Sun, 27 Dec 2020 18:41:39 +0200 Subject: [PATCH 21/34] rm deep copy --- .../self_supervised/simsiam/simsiam_module.py | 72 +++++++------------ 1 file changed, 26 insertions(+), 46 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 70742d86be..a24529b78a 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -1,7 +1,6 @@ import math from argparse import ArgumentParser -from copy import deepcopy -from typing import Any, Callable, Optional +from typing import Callable, Optional import numpy as np import pytorch_lightning as pl @@ -15,8 +14,11 @@ from pl_bolts.models.self_supervised.simsiam.models import SiameseArm from pl_bolts.optimizers.lars_scheduling import LARSWrapper from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from pl_bolts.transforms.dataset_normalizations import (cifar10_normalization, imagenet_normalization, - stl10_normalization) +from pl_bolts.transforms.dataset_normalizations import ( + cifar10_normalization, + imagenet_normalization, + stl10_normalization, +) class SimSiam(pl.LightningModule): @@ -129,7 +131,7 @@ def __init__( self.max_epochs = max_epochs self.init_model() - + # 
compute iters per epoch global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size @@ -150,19 +152,12 @@ def init_model(self): backbone = resnet18 elif self.arch == 'resnet50': backbone = resnet50 - + encoder = backbone( first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False ) - self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, - hidden_size=self.hidden_mlp, output_dim=self.feat_dim) - self.init_target_network() - - def init_target_network(self): - self.target_network = deepcopy(self.online_network) - - def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - self.init_target_network() + self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, + hidden_size=self.hidden_mlp, output_dim=self.feat_dim) def forward(self, x): y, _, _ = self.online_network(x) @@ -172,42 +167,34 @@ def cosine_similarity(self, a, b): b = b.detach() # stop gradient of backbone + projection mlp a = F.normalize(a, dim=-1) b = F.normalize(b, dim=-1) - sim = (a * b).sum(-1).mean() + sim = -1 * (a * b).sum(-1).mean() return sim - def shared_step(self, batch, batch_idx): + def training_step(self, batch, batch_idx): (img_1, img_2, _), y = batch # Image 1 to image 2 loss _, z1, h1 = self.online_network(img_1) - _, z2, h2 = self.target_network(img_2) - loss_a = -1.0 * self.cosine_similarity(h1, z2) - - # Image 2 to image 1 loss - _, z1, h1 = self.online_network(img_2) - _, z2, h2 = self.target_network(img_1) - loss_b = -1.0 * self.cosine_similarity(h1, z2) - - # Final loss - total_loss = loss_a / 2.0 + loss_b / 2.0 - - return loss_a, loss_b, total_loss - - def training_step(self, batch, batch_idx): - loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + _, z2, h2 = self.online_network(img_2) + loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 # log results - self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss}) + self.log_dict({"loss": loss}) - return total_loss + return loss def validation_step(self, batch, batch_idx): - loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + (img_1, img_2, _), y = batch + + # Image 1 to image 2 loss + _, z1, h1 = self.online_network(img_1) + _, z2, h2 = self.online_network(img_2) + loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 # log results - self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss}) + self.log_dict({"loss": loss}) - return total_loss + return loss def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): params = [] @@ -346,15 +333,8 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import ( - CIFAR10DataModule, - ImagenetDataModule, - STL10DataModule, - ) - from pl_bolts.models.self_supervised.simclr import ( - SimCLREvalDataTransform, - SimCLRTrainDataTransform, - ) + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform seed_everything(1234) From bc9c24015eec5f8c43c2186cc71fa0bec5a25035 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:03:24 +0900 Subject: [PATCH 22/34] Add types to knn_online.py --- pl_bolts/callbacks/knn_online.py | 18 
++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index d290ec61ef..0f49e717dd 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -1,7 +1,9 @@ -from typing import Optional +from typing import Optional, Tuple +import numpy as np import torch -from pytorch_lightning import Callback +from torch.utils.data import DataLoader +from pytorch_lightning import Callback, LightningModule, Trainer from sklearn.neighbors import KNeighborsClassifier @@ -25,8 +27,8 @@ class KNNOnlineEvaluator(Callback): # pragma: no-cover def __init__( self, dataset: str, - num_classes: int = None, - ): + num_classes: Optional[int] = None, + ) -> None: """ Args: dataset: if stl10, need to get the labeled batch @@ -37,13 +39,13 @@ def __init__( self.num_classes = num_classes self.dataset = dataset - def get_representations(self, pl_module, x): + def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): representations = pl_module(x) representations = representations.reshape(representations.size(0), -1) return representations - def get_all_representations(self, pl_module, dataloader): + def get_all_representations(self, pl_module: LightningModule, dataloader: DataLoader) -> Tuple[np.ndarray, np.ndarray]: all_representations = None ys = None @@ -65,7 +67,7 @@ def get_all_representations(self, pl_module, dataloader): return all_representations.cpu().numpy(), ys.cpu().numpy() - def to_device(self, batch, device): + def to_device(self, batch: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: # get the labeled batch if self.dataset == 'stl10': labeled_batch = batch[1] @@ -80,7 +82,7 @@ def to_device(self, batch, device): return x, y - def on_validation_epoch_end(self, trainer, pl_module): + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes) train_dataloader = pl_module.train_dataloader() From 217b36f8b14a255ab727c6136c8fa9c925bb4134 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:07:57 +0900 Subject: [PATCH 23/34] Add types to models.py --- pl_bolts/models/self_supervised/simsiam/models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index 14c55012a9..569a4a4ef7 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -1,10 +1,13 @@ +from typing import Optional + +import torch from torch import nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder class MLP(nn.Module): - def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): + def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: super().__init__() self.output_dim = output_dim self.input_dim = input_dim @@ -14,13 +17,13 @@ def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): nn.ReLU(inplace=True), nn.Linear(hidden_size, output_dim, bias=True)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.model(x) return x class SiameseArm(nn.Module): - def __init__(self, encoder=None, input_dim=2048, hidden_size=4096, output_dim=256): + def __init__(self, encoder: Optional[nn.Module] = None, input_dim: int = 2048, 
hidden_size: int = 4096, output_dim: int = 256) -> None: super().__init__() if encoder is None: @@ -32,7 +35,7 @@ def __init__(self, encoder=None, input_dim=2048, hidden_size=4096, output_dim=25 # Predictor self.predictor = MLP(output_dim, hidden_size, output_dim) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: y = self.encoder(x)[0] z = self.projector(y) h = self.predictor(z) From 37642208f7345271e26aa5495ab7b4d5aff21004 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:08:22 +0900 Subject: [PATCH 24/34] Fix types in models.py --- pl_bolts/models/self_supervised/simsiam/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index 569a4a4ef7..d581b340af 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch from torch import nn From 3214c8514ec90d5010e4179dc20786533198ae58 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:14:40 +0900 Subject: [PATCH 25/34] Fix tests --- tests/models/self_supervised/test_scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index a0e58e8276..b669f359e7 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -89,7 +89,7 @@ def test_cli_run_self_supervised_swav(cli_args): @pytest.mark.parametrize('cli_args', [ - f'--data_dir {DATASETS_PATH} --gpus 0 --max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2 --online_ft' + f'--data_dir {DATASETS_PATH} --gpus 0 --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft' ]) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" From 13f5af3e8c5cbce9599aece4ae094d6258a92cdb Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:51:26 +0900 Subject: [PATCH 26/34] Fix types in knn_online.py --- pl_bolts/callbacks/knn_online.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 0f49e717dd..b0d86de7ee 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import numpy as np import torch @@ -65,9 +65,9 @@ def get_all_representations(self, pl_module: LightningModule, dataloader: DataLo else: ys = torch.cat([ys, y]) - return all_representations.cpu().numpy(), ys.cpu().numpy() + return all_representations.cpu().numpy(), ys.cpu().numpy() # type: ignore[union-attr] - def to_device(self, batch: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + def to_device(self, batch: torch.Tensor, device: Union[str, torch.device]) -> Tuple[torch.Tensor, torch.Tensor]: # get the labeled batch if self.dataset == 'stl10': labeled_batch = batch[1] @@ -89,16 +89,16 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) representations, y = self.get_all_representations(pl_module, train_dataloader) # knn fit - pl_module.knn_evaluator.fit(representations, y) - train_acc = pl_module.knn_evaluator.score(representations, y) + 
pl_module.knn_evaluator.fit(representations, y) # type: ignore[union-attr,operator] + train_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator] # log metrics val_dataloader = pl_module.val_dataloader() - representations, y = self.get_all_representations(pl_module, val_dataloader) + representations, y = self.get_all_representations(pl_module, val_dataloader) # type: ignore[arg-type] # knn val acc - val_acc = pl_module.knn_evaluator.score(representations, y) + val_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator] # log metrics pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True) From 0083276c06a02336c00a42b0eef9afcc15d29eec Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:56:58 +0900 Subject: [PATCH 27/34] Add SimSiam --- pl_bolts/models/self_supervised/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pl_bolts/models/self_supervised/__init__.py b/pl_bolts/models/self_supervised/__init__.py index c05ee1030f..c88421a3ec 100644 --- a/pl_bolts/models/self_supervised/__init__.py +++ b/pl_bolts/models/self_supervised/__init__.py @@ -35,6 +35,7 @@ "SSLEvaluator", "MocoV2", "SimCLR", + "SimSiam", "SSLFineTuner", "SwAV", ] From 829274e8f36354c7cbef550fab1d09f85ec65c7a Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 00:57:17 +0900 Subject: [PATCH 28/34] Apply isort --- pl_bolts/callbacks/knn_online.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index b0d86de7ee..2cb324e211 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -2,9 +2,9 @@ import numpy as np import torch -from torch.utils.data import DataLoader from pytorch_lightning import Callback, LightningModule, Trainer from sklearn.neighbors import KNeighborsClassifier +from torch.utils.data import DataLoader class KNNOnlineEvaluator(Callback): # pragma: no-cover From 781982018de6f6236a8a6237124c93115498cc9b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jan 2021 01:05:42 +0900 Subject: [PATCH 29/34] Import sklearn as optional package --- pl_bolts/callbacks/knn_online.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 2cb324e211..be332b5845 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -3,9 +3,16 @@ import numpy as np import torch from pytorch_lightning import Callback, LightningModule, Trainer -from sklearn.neighbors import KNeighborsClassifier from torch.utils.data import DataLoader +from pl_bolts.utils import _SKLEARN_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _SKLEARN_AVAILABLE: + from sklearn.neighbors import KNeighborsClassifier +else: # pragma: no cover + warn_missing_pkg("sklearn", pypi_name="scikit-learn") + class KNNOnlineEvaluator(Callback): # pragma: no-cover """ @@ -34,6 +41,11 @@ def __init__( dataset: if stl10, need to get the labeled batch num_classes: Number of classes """ + if not _SKLEARN_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( + "You want to use `KNeighborsClassifier` function from `scikit-learn` which is not installed yet." 
+ ) + super().__init__() self.num_classes = num_classes From 75df5b98eb981a4b2db6a4cec4813f9682a62c88 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 9 Jan 2021 22:15:14 +0900 Subject: [PATCH 30/34] Fix flake8 --- pl_bolts/callbacks/knn_online.py | 6 +++++- pl_bolts/models/self_supervised/simsiam/models.py | 8 +++++++- .../models/self_supervised/simsiam/simsiam_module.py | 9 ++++----- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index be332b5845..ff8cd35af6 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -57,7 +57,11 @@ def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> to representations = representations.reshape(representations.size(0), -1) return representations - def get_all_representations(self, pl_module: LightningModule, dataloader: DataLoader) -> Tuple[np.ndarray, np.ndarray]: + def get_all_representations( + self, + pl_module: LightningModule, + dataloader: DataLoader + ) -> Tuple[np.ndarray, np.ndarray]: all_representations = None ys = None diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index d581b340af..61ebfaceb3 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -23,7 +23,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SiameseArm(nn.Module): - def __init__(self, encoder: Optional[nn.Module] = None, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: + def __init__( + self, + encoder: Optional[nn.Module] = None, + input_dim: int = 2048, + hidden_size: int = 4096, + output_dim: int = 256, + ) -> None: super().__init__() if encoder is None: diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index a24529b78a..cbff17ae82 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -13,7 +13,6 @@ from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.models.self_supervised.simsiam.models import SiameseArm from pl_bolts.optimizers.lars_scheduling import LARSWrapper -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, imagenet_normalization, @@ -131,7 +130,7 @@ def __init__( self.max_epochs = max_epochs self.init_model() - + # compute iters per epoch global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size @@ -152,12 +151,12 @@ def init_model(self): backbone = resnet18 elif self.arch == 'resnet50': backbone = resnet50 - + encoder = backbone( first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False ) - self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, - hidden_size=self.hidden_mlp, output_dim=self.feat_dim) + self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, + hidden_size=self.hidden_mlp, output_dim=self.feat_dim) def forward(self, x): y, _, _ = self.online_network(x) From adf184df2ff52e394601c76b756cd750961afe0f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 9 Jan 2021 22:46:27 +0900 Subject: [PATCH 31/34] Add args via Trainer and make the tests work on cpu --- 
.../models/self_supervised/simsiam/simsiam_module.py | 9 +++------ tests/models/self_supervised/test_scripts.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index cbff17ae82..1d7f933b13 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -299,9 +299,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--data_dir", type=str, default=".", help="path to download data") # training params - parser.add_argument("--fast_dev_run", action="store_true") parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training") - parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") parser.add_argument( @@ -310,10 +308,6 @@ def add_model_specific_args(parent_parser): parser.add_argument( "--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay" ) - parser.add_argument( - "--max_epochs", default=100, type=int, help="number of total epochs to run" - ) - parser.add_argument("--max_steps", default=-1, type=int, help="max steps") parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") @@ -339,6 +333,9 @@ def cli_main(): parser = ArgumentParser() + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + # model args parser = SimSiam.add_model_specific_args(parser) args = parser.parse_args() diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index b669f359e7..919c42f23f 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -89,7 +89,7 @@ def test_cli_run_self_supervised_swav(cli_args): @pytest.mark.parametrize('cli_args', [ - f'--data_dir {DATASETS_PATH} --gpus 0 --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft' + f'--dataset cifar10 --data_dir {DATASETS_PATH} --gpus 0 --fp32 --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft' ]) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" From 715d934dcb23e41be0fb004f0f30c2f1f7a15281 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 9 Jan 2021 22:47:59 +0900 Subject: [PATCH 32/34] Fix flake8 --- tests/models/self_supervised/test_scripts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index 919c42f23f..0ad0ef4139 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -89,7 +89,8 @@ def test_cli_run_self_supervised_swav(cli_args): @pytest.mark.parametrize('cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --gpus 0 --fp32 --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft' + f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' + ' --gpus 0 --fp32 --online_ft' ]) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" 
From e648eec7c6b27e8f438704e20f8bcc9d42479322 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 17 Jan 2021 21:46:20 +0100 Subject: [PATCH 33/34] chlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1df0f340a..6195660283 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added flags to datamodules ([#388](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/388)) - Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) - Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469)) +- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407)) ### Changed From 1717c0aac12195914edb9b616b3de42a8d2ba297 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 17 Jan 2021 21:55:39 +0100 Subject: [PATCH 34/34] yapf --- .github/workflows/code-format.yml | 9 +-- pl_bolts/callbacks/knn_online.py | 2 +- .../models/self_supervised/simsiam/models.py | 5 +- .../self_supervised/simsiam/simsiam_module.py | 71 +++++++------------ tests/models/self_supervised/test_scripts.py | 10 +-- 5 files changed, 40 insertions(+), 57 deletions(-) diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index e799af8c01..972b2e2452 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -21,8 +21,7 @@ jobs: pip --version shell: bash - name: PEP8 - run: | - flake8 . + run: flake8 . format-check-yapf: runs-on: ubuntu-20.04 @@ -38,8 +37,7 @@ jobs: pip --version shell: bash - name: yapf - run: | - yapf --diff --parallel --recursive . + run: yapf --diff --parallel --recursive . 
imports-check-isort: runs-on: ubuntu-20.04 @@ -67,5 +65,4 @@ jobs: pip install mypy pip list - name: mypy - run: | - mypy + run: mypy diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index ff8cd35af6..32eda875bb 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -60,7 +60,7 @@ def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> to def get_all_representations( self, pl_module: LightningModule, - dataloader: DataLoader + dataloader: DataLoader, ) -> Tuple[np.ndarray, np.ndarray]: all_representations = None ys = None diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index 61ebfaceb3..ad020ad9fc 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -7,6 +7,7 @@ class MLP(nn.Module): + def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: super().__init__() self.output_dim = output_dim @@ -15,7 +16,8 @@ def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: i nn.Linear(input_dim, hidden_size, bias=False), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), - nn.Linear(hidden_size, output_dim, bias=True)) + nn.Linear(hidden_size, output_dim, bias=True), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.model(x) @@ -23,6 +25,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SiameseArm(nn.Module): + def __init__( self, encoder: Optional[nn.Module] = None, diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 1d7f933b13..cc69d4169c 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -140,9 +140,11 @@ def __init__( self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs ) iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) - cosine_lr_schedule = np.array([self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * ( - 1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))) - ) for t in iters]) + cosine_lr_schedule = np.array([ + self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * + (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) + for t in iters + ]) self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) @@ -152,11 +154,10 @@ def init_model(self): elif self.arch == 'resnet50': backbone = resnet50 - encoder = backbone( - first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False + encoder = backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False) + self.online_network = SiameseArm( + encoder, input_dim=self.hidden_mlp, hidden_size=self.hidden_mlp, output_dim=self.feat_dim ) - self.online_network = SiameseArm(encoder, input_dim=self.hidden_mlp, - hidden_size=self.hidden_mlp, output_dim=self.feat_dim) def forward(self, x): y, _, _ = self.online_network(x) @@ -208,32 +209,26 @@ def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', ' params.append(param) return [ - {'params': params, 'weight_decay': weight_decay}, - {'params': excluded_params, 'weight_decay': 0.} + { + 'params': params, + 'weight_decay': weight_decay + }, + { + 'params': excluded_params, 
+ 'weight_decay': 0. + }, ] def configure_optimizers(self): if self.exclude_bn_bias: - params = self.exclude_from_wt_decay( - self.named_parameters(), - weight_decay=self.weight_decay - ) + params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) else: params = self.parameters() if self.optim == 'sgd': - optimizer = torch.optim.SGD( - params, - lr=self.learning_rate, - momentum=0.9, - weight_decay=self.weight_decay - ) + optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay) elif self.optim == 'adam': - optimizer = torch.optim.Adam( - params, - lr=self.learning_rate, - weight_decay=self.weight_decay - ) + optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) if self.lars_wrapper: optimizer = LARSWrapper( @@ -285,9 +280,7 @@ def add_model_specific_args(parent_parser): # specify flags to store false parser.add_argument("--first_conv", action="store_false") parser.add_argument("--maxpool1", action="store_false") - parser.add_argument( - "--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head" - ) + parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") parser.add_argument("--online_ft", action="store_true") parser.add_argument("--fp32", action="store_true") @@ -302,23 +295,15 @@ def add_model_specific_args(parent_parser): parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training") parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") - parser.add_argument( - "--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used" - ) - parser.add_argument( - "--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay" - ) + parser.add_argument("--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used") + parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") - parser.add_argument( - "--temperature", default=0.1, type=float, help="temperature parameter in training loss" - ) + parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") - parser.add_argument( - "--start_lr", default=0, type=float, help="initial warmup learning rate" - ) + parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") return parser @@ -345,9 +330,7 @@ def cli_main(): # init datamodule if args.dataset == "stl10": - dm = STL10DataModule( - data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers - ) + dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) dm.train_dataloader = dm.train_dataloader_mixed dm.val_dataloader = dm.val_dataloader_mixed @@ -404,9 +387,7 @@ def cli_main(): 
args.start_lr = 0.3 args.online_ft = True - dm = ImagenetDataModule( - data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers - ) + dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) args.num_samples = dm.num_samples args.input_height = dm.size()[-1] diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index 0ad0ef4139..76ef8d6053 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -88,10 +88,12 @@ def test_cli_run_self_supervised_swav(cli_args): cli_main() -@pytest.mark.parametrize('cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' - ' --gpus 0 --fp32 --online_ft' -]) +@pytest.mark.parametrize( + 'cli_args', [ + f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' + ' --gpus 0 --fp32 --online_ft' + ] +) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.simsiam.simsiam_module import cli_main