diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index e799af8c01..972b2e2452 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -21,8 +21,7 @@ jobs: pip --version shell: bash - name: PEP8 - run: | - flake8 . + run: flake8 . format-check-yapf: runs-on: ubuntu-20.04 @@ -38,8 +37,7 @@ jobs: pip --version shell: bash - name: yapf - run: | - yapf --diff --parallel --recursive . + run: yapf --diff --parallel --recursive . imports-check-isort: runs-on: ubuntu-20.04 @@ -67,5 +65,4 @@ jobs: pip install mypy pip list - name: mypy - run: | - mypy + run: mypy diff --git a/CHANGELOG.md b/CHANGELOG.md index b1df0f340a..6195660283 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added flags to datamodules ([#388](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/388)) - Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) - Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469)) +- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407)) ### Changed diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py new file mode 100644 index 0000000000..32eda875bb --- /dev/null +++ b/pl_bolts/callbacks/knn_online.py @@ -0,0 +1,121 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from pytorch_lightning import Callback, LightningModule, Trainer +from torch.utils.data import DataLoader + +from pl_bolts.utils import _SKLEARN_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _SKLEARN_AVAILABLE: + from sklearn.neighbors import KNeighborsClassifier +else: # pragma: no cover + warn_missing_pkg("sklearn", pypi_name="scikit-learn") + + +class KNNOnlineEvaluator(Callback): # pragma: no-cover + """ + Evaluates self-supervised K nearest neighbors. + + Example:: + + # your model must have 1 attribute + model = Model() + model.num_classes = ... # the num of classes in the model + + online_eval = KNNOnlineEvaluator( + num_classes=model.num_classes, + dataset='imagenet' + ) + + """ + + def __init__( + self, + dataset: str, + num_classes: Optional[int] = None, + ) -> None: + """ + Args: + dataset: if stl10, need to get the labeled batch + num_classes: Number of classes + """ + if not _SKLEARN_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( + "You want to use `KNeighborsClassifier` function from `scikit-learn` which is not installed yet." 
+ ) + + super().__init__() + + self.num_classes = num_classes + self.dataset = dataset + + def get_representations(self, pl_module: LightningModule, x: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + representations = pl_module(x) + representations = representations.reshape(representations.size(0), -1) + return representations + + def get_all_representations( + self, + pl_module: LightningModule, + dataloader: DataLoader, + ) -> Tuple[np.ndarray, np.ndarray]: + all_representations = None + ys = None + + for batch in dataloader: + x, y = self.to_device(batch, pl_module.device) + + with torch.no_grad(): + representations = self.get_representations(pl_module, x) + + if all_representations is None: + all_representations = representations.detach() + else: + all_representations = torch.cat([all_representations, representations.detach()]) + + if ys is None: + ys = y + else: + ys = torch.cat([ys, y]) + + return all_representations.cpu().numpy(), ys.cpu().numpy() # type: ignore[union-attr] + + def to_device(self, batch: torch.Tensor, device: Union[str, torch.device]) -> Tuple[torch.Tensor, torch.Tensor]: + # get the labeled batch + if self.dataset == 'stl10': + labeled_batch = batch[1] + batch = labeled_batch + + inputs, y = batch + + # last input is for online eval + x = inputs[-1] + x = x.to(device) + y = y.to(device) + + return x, y + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + pl_module.knn_evaluator = KNeighborsClassifier(n_neighbors=self.num_classes) + + train_dataloader = pl_module.train_dataloader() + representations, y = self.get_all_representations(pl_module, train_dataloader) + + # knn fit + pl_module.knn_evaluator.fit(representations, y) # type: ignore[union-attr,operator] + train_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator] + + # log metrics + + val_dataloader = pl_module.val_dataloader() + representations, y = self.get_all_representations(pl_module, val_dataloader) # type: ignore[arg-type] + + # knn val acc + val_acc = pl_module.knn_evaluator.score(representations, y) # type: ignore[union-attr,operator] + + # log metrics + pl_module.log('online_knn_train_acc', train_acc, on_step=False, on_epoch=True, sync_dist=True) + pl_module.log('online_knn_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) diff --git a/pl_bolts/models/self_supervised/__init__.py b/pl_bolts/models/self_supervised/__init__.py index 0960e8e297..c88421a3ec 100644 --- a/pl_bolts/models/self_supervised/__init__.py +++ b/pl_bolts/models/self_supervised/__init__.py @@ -24,6 +24,7 @@ from pl_bolts.models.self_supervised.evaluator import SSLEvaluator # noqa: F401 from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2 # noqa: F401 from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401 +from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam # noqa: F401 from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401 from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401 @@ -34,6 +35,7 @@ "SSLEvaluator", "MocoV2", "SimCLR", + "SimSiam", "SSLFineTuner", "SwAV", ] diff --git a/pl_bolts/models/self_supervised/simsiam/__init__.py b/pl_bolts/models/self_supervised/simsiam/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py new file mode 100644 index 0000000000..ad020ad9fc --- /dev/null 
+++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -0,0 +1,51 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + +from pl_bolts.utils.self_supervised import torchvision_ssl_encoder + + +class MLP(nn.Module): + + def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: + super().__init__() + self.output_dim = output_dim + self.input_dim = input_dim + self.model = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=False), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, output_dim, bias=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.model(x) + return x + + +class SiameseArm(nn.Module): + + def __init__( + self, + encoder: Optional[nn.Module] = None, + input_dim: int = 2048, + hidden_size: int = 4096, + output_dim: int = 256, + ) -> None: + super().__init__() + + if encoder is None: + encoder = torchvision_ssl_encoder('resnet50') + # Encoder + self.encoder = encoder + # Projector + self.projector = MLP(input_dim, hidden_size, output_dim) + # Predictor + self.predictor = MLP(output_dim, hidden_size, output_dim) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + y = self.encoder(x)[0] + z = self.projector(y) + h = self.predictor(z) + return y, z, h diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py new file mode 100644 index 0000000000..cc69d4169c --- /dev/null +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -0,0 +1,441 @@ +import math +from argparse import ArgumentParser +from typing import Callable, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from pytorch_lightning import seed_everything +from pytorch_lightning.utilities import AMPType +from torch.optim.optimizer import Optimizer + +from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 +from pl_bolts.models.self_supervised.simsiam.models import SiameseArm +from pl_bolts.optimizers.lars_scheduling import LARSWrapper +from pl_bolts.transforms.dataset_normalizations import ( + cifar10_normalization, + imagenet_normalization, + stl10_normalization, +) + + +class SimSiam(pl.LightningModule): + """ + PyTorch Lightning implementation of `Exploring Simple Siamese Representation Learning (SimSiam) + `_ + + Paper authors: Xinlei Chen, Kaiming He. + + Model implemented by: + - `Zvi Lapp `_ + + .. warning:: Work in progress. This implementation is still being verified. 
+ + TODOs: + - verify on CIFAR-10 + - verify on STL-10 + - pre-train on imagenet + + Example:: + + model = SimSiam() + + dm = CIFAR10DataModule(num_workers=0) + dm.train_transforms = SimCLRTrainDataTransform(32) + dm.val_transforms = SimCLREvalDataTransform(32) + + trainer = pl.Trainer() + trainer.fit(model, dm) + + Train:: + + trainer = Trainer() + trainer.fit(model) + + CLI command:: + + # cifar10 + python simsiam_module.py --gpus 1 + + # imagenet + python simsiam_module.py + --gpus 8 + --dataset imagenet2012 + --data_dir /path/to/imagenet/ + --meta_dir /path/to/folder/with/meta.bin/ + --batch_size 32 + """ + + def __init__( + self, + gpus: int, + num_samples: int, + batch_size: int, + dataset: str, + nodes: int = 1, + arch: str = 'resnet50', + hidden_mlp: int = 2048, + feat_dim: int = 128, + warmup_epochs: int = 10, + max_epochs: int = 100, + temperature: float = 0.1, + first_conv: bool = True, + maxpool1: bool = True, + optimizer: str = 'adam', + lars_wrapper: bool = True, + exclude_bn_bias: bool = False, + start_lr: float = 0., + learning_rate: float = 1e-3, + final_lr: float = 0., + weight_decay: float = 1e-6, + **kwargs + ): + """ + Args: + datamodule: The datamodule + learning_rate: the learning rate + weight_decay: optimizer weight decay + input_height: image input height + batch_size: the batch size + num_workers: number of workers + warmup_epochs: num of epochs for scheduler warm up + max_epochs: max epochs for scheduler + """ + super().__init__() + self.save_hyperparameters() + + self.gpus = gpus + self.nodes = nodes + self.arch = arch + self.dataset = dataset + self.num_samples = num_samples + self.batch_size = batch_size + + self.hidden_mlp = hidden_mlp + self.feat_dim = feat_dim + self.first_conv = first_conv + self.maxpool1 = maxpool1 + + self.optim = optimizer + self.lars_wrapper = lars_wrapper + self.exclude_bn_bias = exclude_bn_bias + self.weight_decay = weight_decay + self.temperature = temperature + + self.start_lr = start_lr + self.final_lr = final_lr + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + + self.init_model() + + # compute iters per epoch + global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + self.train_iters_per_epoch = self.num_samples // global_batch_size + + # define LR schedule + warmup_lr_schedule = np.linspace( + self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs + ) + iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) + cosine_lr_schedule = np.array([ + self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * + (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) + for t in iters + ]) + + self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + + def init_model(self): + if self.arch == 'resnet18': + backbone = resnet18 + elif self.arch == 'resnet50': + backbone = resnet50 + + encoder = backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False) + self.online_network = SiameseArm( + encoder, input_dim=self.hidden_mlp, hidden_size=self.hidden_mlp, output_dim=self.feat_dim + ) + + def forward(self, x): + y, _, _ = self.online_network(x) + return y + + def cosine_similarity(self, a, b): + b = b.detach() # stop gradient of backbone + projection mlp + a = F.normalize(a, dim=-1) + b = F.normalize(b, dim=-1) + sim = -1 * (a * b).sum(-1).mean() + return sim + + def training_step(self, 
batch, batch_idx): + (img_1, img_2, _), y = batch + + # Image 1 to image 2 loss + _, z1, h1 = self.online_network(img_1) + _, z2, h2 = self.online_network(img_2) + loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 + + # log results + self.log_dict({"loss": loss}) + + return loss + + def validation_step(self, batch, batch_idx): + (img_1, img_2, _), y = batch + + # Image 1 to image 2 loss + _, z1, h1 = self.online_network(img_1) + _, z2, h2 = self.online_network(img_2) + loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 + + # log results + self.log_dict({"loss": loss}) + + return loss + + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + elif any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [ + { + 'params': params, + 'weight_decay': weight_decay + }, + { + 'params': excluded_params, + 'weight_decay': 0. + }, + ] + + def configure_optimizers(self): + if self.exclude_bn_bias: + params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) + else: + params = self.parameters() + + if self.optim == 'sgd': + optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay) + elif self.optim == 'adam': + optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.lars_wrapper: + optimizer = LARSWrapper( + optimizer, + eta=0.001, # trust coefficient + clip=False + ) + + return optimizer + + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer: Optimizer, + optimizer_idx: int, + optimizer_closure: Optional[Callable] = None, + on_tpu: bool = False, + using_native_amp: bool = False, + using_lbfgs: bool = False, + ) -> None: + # warm-up + decay schedule placed here since LARSWrapper is not optimizer class + # adjust LR of optim contained within LARSWrapper + if self.lars_wrapper: + for param_group in optimizer.optim.param_groups: + param_group["lr"] = self.lr_schedule[self.trainer.global_step] + else: + for param_group in optimizer.param_groups: + param_group["lr"] = self.lr_schedule[self.trainer.global_step] + + # log LR (LearningRateLogger callback doesn't work with LARSWrapper) + self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) + + # from lightning + if self.trainer.amp_backend == AMPType.NATIVE: + optimizer_closure() + self.trainer.scaler.step(optimizer) + elif self.trainer.amp_backend == AMPType.APEX: + optimizer_closure() + optimizer.step() + else: + optimizer.step(closure=optimizer_closure) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + # model params + parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") + # specify flags to store false + parser.add_argument("--first_conv", action="store_false") + parser.add_argument("--maxpool1", action="store_false") + parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") + parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") + parser.add_argument("--online_ft", action="store_true") + parser.add_argument("--fp32", action="store_true") + + # transform params + 
parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") + parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") + parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10") + parser.add_argument("--data_dir", type=str, default=".", help="path to download data") + + # training params + parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training") + parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") + parser.add_argument("--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used") + parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") + parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") + parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") + + parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") + parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") + parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") + parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") + parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") + + return parser + + +def cli_main(): + from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform + + seed_everything(1234) + + parser = ArgumentParser() + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = SimSiam.add_model_specific_args(parser) + args = parser.parse_args() + + # pick data + dm = None + + # init datamodule + if args.dataset == "stl10": + dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + dm.train_dataloader = dm.train_dataloader_mixed + dm.val_dataloader = dm.val_dataloader_mixed + args.num_samples = dm.num_unlabeled_samples + + args.maxpool1 = False + args.first_conv = True + args.input_height = dm.size()[-1] + + normalization = stl10_normalization() + + args.gaussian_blur = True + args.jitter_strength = 1.0 + elif args.dataset == "cifar10": + val_split = 5000 + if args.nodes * args.gpus * args.batch_size > val_split: + val_split = args.nodes * args.gpus * args.batch_size + + dm = CIFAR10DataModule( + data_dir=args.data_dir, + batch_size=args.batch_size, + num_workers=args.num_workers, + val_split=val_split, + ) + + args.num_samples = dm.num_samples + + args.maxpool1 = False + args.first_conv = False + args.input_height = dm.size()[-1] + args.temperature = 0.5 + + normalization = cifar10_normalization() + + args.gaussian_blur = False + args.jitter_strength = 0.5 + elif args.dataset == "imagenet": + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() + + args.gaussian_blur = True + args.jitter_strength = 1.0 + + args.batch_size = 64 + args.nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = "sgd" + args.lars_wrapper = True + args.learning_rate = 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + args.online_ft = 
True + + dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + args.num_samples = dm.num_samples + args.input_height = dm.size()[-1] + else: + raise NotImplementedError("other datasets have not been implemented till now") + + dm.train_transforms = SimCLRTrainDataTransform( + input_height=args.input_height, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + normalize=normalization, + ) + + dm.val_transforms = SimCLREvalDataTransform( + input_height=args.input_height, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + normalize=normalization, + ) + + model = SimSiam(**args.__dict__) + + # finetune in real-time + online_evaluator = None + if args.online_ft: + # online eval + online_evaluator = SSLOnlineEvaluator( + drop_p=0.0, + hidden_dim=None, + z_dim=args.hidden_mlp, + num_classes=dm.num_classes, + dataset=args.dataset, + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + max_steps=None if args.max_steps == -1 else args.max_steps, + gpus=args.gpus, + num_nodes=args.nodes, + distributed_backend="ddp" if args.gpus > 1 else None, + sync_batchnorm=True if args.gpus > 1 else False, + precision=32 if args.fp32 else 16, + callbacks=[online_evaluator] if args.online_ft else None, + fast_dev_run=args.fast_dev_run, + ) + + trainer.fit(model, dm) + + +if __name__ == "__main__": + cli_main() diff --git a/requirements/test.txt b/requirements/test.txt index 7554b5c88a..c8257bd5e3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -13,3 +13,5 @@ mypy yapf atari-py==0.2.6 # needed for RL + +scikit-learn>=0.23 \ No newline at end of file diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 6770f83d61..edaa6f6902 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -4,7 +4,7 @@ from pytorch_lightning import seed_everything from pl_bolts.datamodules import CIFAR10DataModule -from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SwAV +from pl_bolts.models.self_supervised import AMDIM, BYOL, CPCV2, MocoV2, SimCLR, SimSiam, SwAV from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10 from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler from pl_bolts.models.self_supervised.moco.transforms import Moco2EvalCIFAR10Transforms, Moco2TrainCIFAR10Transforms @@ -125,3 +125,18 @@ def test_swav(tmpdir, datadir): loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0 + + +def test_simsiam(tmpdir, datadir): + seed_everything() + + datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2) + datamodule.train_transforms = SimCLRTrainDataTransform(32) + datamodule.val_transforms = SimCLREvalDataTransform(32) + + model = SimSiam(batch_size=2, num_samples=datamodule.num_samples, gpus=0, nodes=1, dataset='cifar10') + trainer = pl.Trainer(gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) + trainer.fit(model, datamodule) + loss = trainer.progress_bar_dict['loss'] + + assert float(loss) < 0 diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index e7cf0b19f4..76ef8d6053 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -86,3 +86,18 @@ def test_cli_run_self_supervised_swav(cli_args): cli_args = cli_args.split(' ') if cli_args else [] with 
mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() + + +@pytest.mark.parametrize( + 'cli_args', [ + f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' + ' --gpus 0 --fp32 --online_ft' + ] +) +def test_cli_run_self_supervised_simsiam(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.self_supervised.simsiam.simsiam_module import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main()
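
A minimal sketch of the symmetric stop-gradient loss that `training_step` computes above, assuming only `torch`; the helper name `negative_cosine_similarity` is illustrative (the module's own method is called `cosine_similarity`)::

    import torch
    import torch.nn.functional as F


    def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # stop gradient through the target projection z, as in the SimSiam paper
        z = z.detach()
        p = F.normalize(p, dim=-1)
        z = F.normalize(z, dim=-1)
        # negative cosine similarity averaged over the batch; each value lies in [-1, 0]
        return -(p * z).sum(dim=-1).mean()


    # symmetric loss over the two augmented views (shapes are illustrative)
    h1, z2 = torch.randn(4, 256), torch.randn(4, 256)
    h2, z1 = torch.randn(4, 256), torch.randn(4, 256)
    loss = negative_cosine_similarity(h1, z2) / 2 + negative_cosine_similarity(h2, z1) / 2

Because each term is bounded in [-1, 0], the total loss is negative, which is why `test_simsiam` asserts `float(loss) < 0` rather than `> 0` as the other self-supervised tests do.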