From 760033c7616433e6a0b538fdde35c183e7f506b7 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 12:12:22 +0000 Subject: [PATCH 1/9] Adding SSL-EY --- docs/source/solo/losses/ssley.rst | 5 + docs/source/solo/methods/ssley.rst | 26 ++++ docs/source/start/available.rst | 1 + scripts/linear/imagenet-100/ssley.yaml | 45 ++++++ scripts/pretrain/cifar/ssley.yaml | 80 ++++++++++ scripts/pretrain/imagenet-100/ssley.yaml | 81 ++++++++++ solo/losses/__init__.py | 2 + solo/losses/ssley.py | 53 +++++++ solo/methods/__init__.py | 3 + solo/methods/ssley.py | 147 ++++++++++++++++++ tests/methods/test_ssley.py | 90 +++++++++++ .../linear/test_imagenet100_scripts.sh | 3 +- tests/scripts/pretrain/test_cifar_scripts.sh | 2 +- .../pretrain/test_imagenet100_scripts.sh | 2 +- 14 files changed, 537 insertions(+), 3 deletions(-) create mode 100644 docs/source/solo/losses/ssley.rst create mode 100644 docs/source/solo/methods/ssley.rst create mode 100644 scripts/linear/imagenet-100/ssley.yaml create mode 100644 scripts/pretrain/cifar/ssley.yaml create mode 100644 scripts/pretrain/imagenet-100/ssley.yaml create mode 100644 solo/losses/ssley.py create mode 100644 solo/methods/ssley.py create mode 100644 tests/methods/test_ssley.py diff --git a/docs/source/solo/losses/ssley.rst b/docs/source/solo/losses/ssley.rst new file mode 100644 index 000000000..bb9bc4d50 --- /dev/null +++ b/docs/source/solo/losses/ssley.rst @@ -0,0 +1,5 @@ +SSL-EY +------- + +.. autofunction:: solo.losses.ssley.ssley_loss_func + :noindex: diff --git a/docs/source/solo/methods/ssley.rst b/docs/source/solo/methods/ssley.rst new file mode 100644 index 000000000..f90152b88 --- /dev/null +++ b/docs/source/solo/methods/ssley.rst @@ -0,0 +1,26 @@ +SSL-EY +======= + + +.. automethod:: solo.methods.ssley.SSLEY.__init__ + :noindex: + +add_model_specific_args +~~~~~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.ssley.SSLEY.add_model_specific_args + :noindex: + +learnable_params +~~~~~~~~~~~~~~~~ +.. 
autoattribute:: solo.methods.ssley.SSLEY.learnable_params + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.methods.ssley.SSLEY.forward + :noindex: + +training_step +~~~~~~~~~~~~~ +.. automethod:: solo.methods.ssley.SSLEY.training_step + :noindex: diff --git a/docs/source/start/available.rst b/docs/source/start/available.rst index b2af14549..2ce058641 100644 --- a/docs/source/start/available.rst +++ b/docs/source/start/available.rst @@ -11,6 +11,7 @@ Methods available * `SwAV `_ * `VICReg `_ * `W-MSE `_ +* `SSL-EY `_ ************ Extra flavor diff --git a/scripts/linear/imagenet-100/ssley.yaml b/scripts/linear/imagenet-100/ssley.yaml new file mode 100644 index 000000000..6dc201638 --- /dev/null +++ b/scripts/linear/imagenet-100/ssley.yaml @@ -0,0 +1,45 @@ +defaults: + - _self_ + - wandb: private.yaml + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# disable hydra outputs +hydra: + output_subdir: null + run: + dir: . + +name: "ssley-imagenet100-linear" +pretrained_feature_extractor: None +backbone: + name: "resnet18" +pretrain_method: "ssley" +data: + dataset: imagenet100 + train_path: "./datasets/imagenet-100/train" + val_path: "./datasets/imagenet-100/val" + format: "dali" + num_workers: 4 +optimizer: + name: "sgd" + batch_size: 256 + lr: 0.3 + weight_decay: 0 +scheduler: + name: "step" + lr_decay_steps: [60, 80] +checkpoint: + enabled: True + dir: "trained_models" + frequency: 1 +auto_resume: + enabled: True + +# overwrite PL stuff +max_epochs: 100 +devices: [0] +sync_batchnorm: True +accelerator: "gpu" +strategy: "ddp" +precision: 16 diff --git a/scripts/pretrain/cifar/ssley.yaml b/scripts/pretrain/cifar/ssley.yaml new file mode 100644 index 000000000..b558ec068 --- /dev/null +++ b/scripts/pretrain/cifar/ssley.yaml @@ -0,0 +1,80 @@ +defaults: + - _self_ + - wandb: private.yaml + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# disable hydra outputs +hydra: + output_subdir: null + 
run: + dir: . + +name: "ssley-cifar10" # change here for cifar100 +method: "ssley" +backbone: + name: "resnet18" +method_kwargs: + proj_hidden_dim: 2048 + proj_output_dim: 2048 +data: + dataset: cifar10 # change here for cifar100 + train_path: "./datasets" + val_path: "datasets/imagenet100/val" + format: "image_folder" + num_workers: 4 +augmentations: + - rrc: + enabled: True + crop_min_scale: 0.2 + crop_max_scale: 1.0 + color_jitter: + enabled: True + brightness: 0.4 + contrast: 0.4 + saturation: 0.2 + hue: 0.1 + prob: 0.8 + grayscale: + enabled: True + prob: 0.2 + gaussian_blur: + enabled: False + prob: 0.0 + solarization: + enabled: True + prob: 0.1 + equalization: + enabled: False + prob: 0.0 + horizontal_flip: + enabled: True + prob: 0.5 + crop_size: 32 + num_crops: 2 +optimizer: + name: "lars" + batch_size: 256 + lr: 0.3 + classifier_lr: 0.1 + weight_decay: 1e-4 + kwargs: + clip_lr: True + eta: 0.02 + exclude_bias_n_norm: True +scheduler: + name: "warmup_cosine" +checkpoint: + enabled: True + dir: "trained_models" + frequency: 1 +auto_resume: + enabled: True + +# overwrite PL stuff +max_epochs: 1000 +devices: [0] +sync_batchnorm: True +accelerator: "gpu" +strategy: "ddp" +precision: 16-mixed diff --git a/scripts/pretrain/imagenet-100/ssley.yaml b/scripts/pretrain/imagenet-100/ssley.yaml new file mode 100644 index 000000000..7429d92b3 --- /dev/null +++ b/scripts/pretrain/imagenet-100/ssley.yaml @@ -0,0 +1,81 @@ +defaults: + - _self_ + - augmentations: ssley.yaml + - wandb: private.yaml + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# disable hydra outputs +hydra: + output_subdir: null + run: + dir: . 
+ +name: "ssley-imagenet100" +method: "ssley" +backbone: + name: "resnet18" +method_kwargs: + proj_hidden_dim: 2048 + proj_output_dim: 2048 +data: + dataset: imagenet100 + train_path: "datasets/imagenet100/train" + val_path: "datasets/imagenet100/val" + format: "dali" + num_workers: 4 +augmentations: + - rrc: + enabled: True + crop_min_scale: 0.2 + crop_max_scale: 1.0 + color_jitter: + enabled: True + brightness: 0.4 + contrast: 0.4 + saturation: 0.2 + hue: 0.1 + prob: 0.8 + grayscale: + enabled: True + prob: 0.2 + gaussian_blur: + enabled: True + prob: 0.5 + solarization: + enabled: True + prob: 0.1 + equalization: + enabled: False + prob: 0.0 + horizontal_flip: + enabled: True + prob: 0.5 + crop_size: 224 + num_crops: 2 +optimizer: + name: "lars" + batch_size: 128 + lr: 0.3 + classifier_lr: 0.1 + weight_decay: 1e-4 + kwargs: + clip_lr: True + eta: 0.02 + exclude_bias_n_norm: True +scheduler: + name: "warmup_cosine" +checkpoint: + enabled: True + dir: "trained_models" + frequency: 1 +auto_resume: + enabled: True + +# overwrite PL stuff +max_epochs: 400 +devices: [0, 1] +sync_batchnorm: True +accelerator: "gpu" +strategy: "ddp" +precision: 16-mixed diff --git a/solo/losses/__init__.py b/solo/losses/__init__.py index 23a365d63..c5253f14d 100644 --- a/solo/losses/__init__.py +++ b/solo/losses/__init__.py @@ -32,6 +32,7 @@ from solo.losses.vibcreg import vibcreg_loss_func from solo.losses.vicreg import vicreg_loss_func from solo.losses.wmse import wmse_loss_func +from solo.losses.ssley import ssley_loss_func __all__ = [ "barlow_loss_func", @@ -49,4 +50,5 @@ "vibcreg_loss_func", "vicreg_loss_func", "wmse_loss_func", + "ssley_loss_func" ] diff --git a/solo/losses/ssley.py b/solo/losses/ssley.py new file mode 100644 index 000000000..2586f37ab --- /dev/null +++ b/solo/losses/ssley.py @@ -0,0 +1,53 @@ +# Copyright 2023 solo-learn development team. 
+ +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch +import torch.nn.functional as F +from solo.utils.misc import gather + + +def ssley_loss_func( + z1: torch.Tensor, + z2: torch.Tensor, +) -> torch.Tensor: + """Computes SSL-EY's loss given batch of projected features z1 from view 1 and + projected features z2 from view 2. + + Args: + z1 (torch.Tensor): NxD Tensor containing projected features from view 1. + z2 (torch.Tensor): NxD Tensor containing projected features from view 2. + + Returns: + torch.Tensor: VICReg loss. 
+ """ + + sim_loss = invariance_loss(z1, z2) + + N, D = z1.size() + B = torch.cov(torch.hstack((z1, z2)).T) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(B) + world_size = dist.get_world_size() + B /= world_size + + A = B[:D, D:] + B[D:, :D] + B = B[:D, :D] + B[D:, D:] + + return -torch.trace(2 * A - B @ B) \ No newline at end of file diff --git a/solo/methods/__init__.py b/solo/methods/__init__.py index 64c4af1e4..87dd6458c 100644 --- a/solo/methods/__init__.py +++ b/solo/methods/__init__.py @@ -37,6 +37,7 @@ from solo.methods.vibcreg import VIbCReg from solo.methods.vicreg import VICReg from solo.methods.wmse import WMSE +from solo.methods.ssley import SSLEY METHODS = { # base classes @@ -61,6 +62,7 @@ "vibcreg": VIbCReg, "vicreg": VICReg, "wmse": WMSE, + "ssley": SSLEY, } __all__ = [ "BarlowTwins", @@ -83,4 +85,5 @@ "VIbCReg", "VICReg", "WMSE", + "SSLEY", ] diff --git a/solo/methods/ssley.py b/solo/methods/ssley.py new file mode 100644 index 000000000..9071de3ba --- /dev/null +++ b/solo/methods/ssley.py @@ -0,0 +1,147 @@ +# Copyright 2023 solo-learn development team. + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Any, Dict, List, Sequence + +import omegaconf +import torch +import torch.nn as nn +from solo.losses.ssley import ssley_loss_func +from solo.methods.base import BaseMethod +from solo.utils.misc import omegaconf_select + + +class SSLEY(BaseMethod): + def __init__(self, cfg: omegaconf.DictConfig): + """Implements SSL-EY (https://neurips.cc/virtual/2023/80864) + + Extra cfg settings: + method_kwargs: + proj_output_dim (int): number of dimensions of the projected features. + proj_hidden_dim (int): number of neurons in the hidden layers of the projector. + """ + + super().__init__(cfg) + + self.sim_loss_weight: float = cfg.method_kwargs.sim_loss_weight + self.var_loss_weight: float = cfg.method_kwargs.var_loss_weight + self.cov_loss_weight: float = cfg.method_kwargs.cov_loss_weight + + proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim + proj_output_dim: int = cfg.method_kwargs.proj_output_dim + + # projector + self.projector = nn.Sequential( + nn.Linear(self.features_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_output_dim), + ) + + @staticmethod + def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig: + """Adds method specific default values/checks for config. + + Args: + cfg (omegaconf.DictConfig): DictConfig object. + + Returns: + omegaconf.DictConfig: same as the argument, used to avoid errors. 
+ """ + + cfg = super(SSLEY, SSLEY).add_and_assert_specific_cfg(cfg) + + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim") + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim") + + cfg.method_kwargs.sim_loss_weight = omegaconf_select( + cfg, + "method_kwargs.sim_loss_weight", + 25.0, + ) + cfg.method_kwargs.var_loss_weight = omegaconf_select( + cfg, + "method_kwargs.var_loss_weight", + 25.0, + ) + cfg.method_kwargs.cov_loss_weight = omegaconf_select( + cfg, + "method_kwargs.cov_loss_weight", + 1.0, + ) + + return cfg + + @property + def learnable_params(self) -> List[dict]: + """Adds projector parameters to the parent's learnable parameters. + + Returns: + List[dict]: list of learnable parameters. + """ + + extra_learnable_params = [{"name": "projector", "params": self.projector.parameters()}] + return super().learnable_params + extra_learnable_params + + def forward(self, X: torch.Tensor) -> Dict[str, Any]: + """Performs the forward pass of the backbone and the projector. + + Args: + X (torch.Tensor): a batch of images in the tensor format. + + Returns: + Dict[str, Any]: a dict containing the outputs of the parent and the projected features. + """ + + out = super().forward(X) + z = self.projector(out["feats"]) + out.update({"z": z}) + return out + + def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: + """Training step for SSL-EY reusing BaseMethod training step. + + Args: + batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where + [X] is a list of size num_crops containing batches of images. + batch_idx (int): index of the batch. + + Returns: + torch.Tensor: total loss composed of SSL-EY loss and classification loss. 
+ """ + + out = super().training_step(batch, batch_idx) + class_loss = out["loss"] + z1, z2 = out["z"] + + # ------- ssley loss ------- + ssley_loss = ssley_loss_func( + z1, + z2, + sim_loss_weight=self.sim_loss_weight, + var_loss_weight=self.var_loss_weight, + cov_loss_weight=self.cov_loss_weight, + ) + + self.log("train_ssley_loss", ssley_loss, on_epoch=True, sync_dist=True) + + return ssley_loss + class_loss \ No newline at end of file diff --git a/tests/methods/test_ssley.py b/tests/methods/test_ssley.py new file mode 100644 index 000000000..f77309ac6 --- /dev/null +++ b/tests/methods/test_ssley.py @@ -0,0 +1,90 @@ +# Copyright 2023 solo-learn development team. + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
+ +import torch +from solo.methods import SSLEY + +from .utils import gen_base_cfg, gen_batch, gen_trainer, prepare_dummy_dataloaders + + +def test_ssley(): + method_kwargs = { + "proj_hidden_dim": 2048, + "proj_output_dim": 2048, + "sim_loss_weight": 25.0, + "var_loss_weight": 25.0, + "cov_loss_weight": 1.0, + } + + cfg = gen_base_cfg("ssley", batch_size=2, num_classes=100, momentum=True) + cfg.method_kwargs = method_kwargs + model = SSLEY(cfg) + + # test arguments + model.add_and_assert_specific_cfg(cfg) + + # test parameters + assert model.learnable_params is not None + + # test forward + batch, _ = gen_batch(cfg.optimizer.batch_size, cfg.data.num_classes, "imagenet100") + out = model(batch[1][0]) + assert ( + "logits" in out + and isinstance(out["logits"], torch.Tensor) + and out["logits"].size() == (cfg.optimizer.batch_size, cfg.data.num_classes) + ) + assert ( + "feats" in out + and isinstance(out["feats"], torch.Tensor) + and out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim) + ) + assert ( + "z" in out + and isinstance(out["z"], torch.Tensor) + and out["z"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"]) + ) + + # imagenet + model = SSLEY(cfg) + + trainer = gen_trainer(cfg) + train_dl, val_dl = prepare_dummy_dataloaders( + "imagenet100", + num_large_crops=cfg.data.num_large_crops, + num_small_crops=0, + num_classes=cfg.data.num_classes, + batch_size=cfg.optimizer.batch_size, + ) + trainer.fit(model, train_dl, val_dl) + + # cifar + cfg.data.dataset = "cifar10" + cfg.data.num_classes = 10 + model = SSLEY(cfg) + + trainer = gen_trainer(cfg) + train_dl, val_dl = prepare_dummy_dataloaders( + "cifar10", + num_large_crops=cfg.data.num_large_crops, + num_small_crops=0, + num_classes=cfg.data.num_classes, + batch_size=cfg.optimizer.batch_size, + ) + trainer.fit(model, train_dl, val_dl) diff --git a/tests/scripts/linear/test_imagenet100_scripts.sh b/tests/scripts/linear/test_imagenet100_scripts.sh index 683e13cb6..ed36a2923 
100644 --- a/tests/scripts/linear/test_imagenet100_scripts.sh +++ b/tests/scripts/linear/test_imagenet100_scripts.sh @@ -1,7 +1,7 @@ TRAIN_PATH=$1 VAL_PATH=$2 FORMAT=$3 -METHODS=("barlow" "byol" "dino" "mocov2plus" "mocov3_vit" "mocov3" "nnclr" "ressl" "simclr" "simsiam" "swav" "vibcreg" "vicreg") +METHODS=("barlow" "byol" "dino" "mocov2plus" "mocov3_vit" "mocov3" "nnclr" "ressl" "simclr" "simsiam" "swav" "vibcreg" "vicreg" "ssley") # first run ../pretrain/test_imagenet_scripts.sh and then fill the paths here # escape path with \"PATH-HERE\" @@ -20,6 +20,7 @@ PRETRAINED_PATHS=( \"PATH-TO-SWAV-MODEL\" \"PATH-TO-VIBCREG-MODEL\" \"PATH-TO-VICREG-MODEL\" + \"PATH-TO-SSLEY-MODEL\" ) for i in ${!METHODS[@]}; do diff --git a/tests/scripts/pretrain/test_cifar_scripts.sh b/tests/scripts/pretrain/test_cifar_scripts.sh index 772dbd2db..472ea541b 100644 --- a/tests/scripts/pretrain/test_cifar_scripts.sh +++ b/tests/scripts/pretrain/test_cifar_scripts.sh @@ -1,4 +1,4 @@ -METHODS=("barlow" "byol" "dino" "mae" "mocov2plus" "mocov3" "nnbyol" "nnclr" "nnsiam" "ressl" "simclr" "simsiam" "supcon" "swav" "vibcreg" "vicreg") +METHODS=("barlow" "byol" "dino" "mae" "mocov2plus" "mocov3" "nnbyol" "nnclr" "nnsiam" "ressl" "simclr" "simsiam" "supcon" "swav" "vibcreg" "vicreg" "ssley") DATASETS=("cifar10") for dataset in ${DATASETS[@]}; do diff --git a/tests/scripts/pretrain/test_imagenet100_scripts.sh b/tests/scripts/pretrain/test_imagenet100_scripts.sh index 26c71c958..41b66041b 100644 --- a/tests/scripts/pretrain/test_imagenet100_scripts.sh +++ b/tests/scripts/pretrain/test_imagenet100_scripts.sh @@ -1,7 +1,7 @@ TRAIN_PATH=$1 VAL_PATH=$2 FORMAT=$3 -METHODS=("barlow" "byol" "dino" "dino_vit" "mae" "mocov2plus" "mocov3_vit" "mocov3" "nnclr" "ressl" "simclr" "simsiam" "supcon" "swav" "vibcreg" "vicreg") +METHODS=("barlow" "byol" "dino" "dino_vit" "mae" "mocov2plus" "mocov3_vit" "mocov3" "nnclr" "ressl" "simclr" "simsiam" "supcon" "swav" "vibcreg" "vicreg" "ssley") for method in 
${METHODS[@]}; do echo Running $method From f2f18deda2cc82db9cbeaac4e9f4c1962bfd7185 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 12:17:22 +0000 Subject: [PATCH 2/9] Adding SSL-EY to README.md --- README.md | 110 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 41e2c18aa..8a1ca97cd 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ The library is self-contained, but it is possible to use the models outside of s * [SimCLR](https://arxiv.org/abs/2002.05709) * [SimSiam](https://arxiv.org/abs/2011.10566) * [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362) +* [SSL-EY](https://arxiv.org/abs/2310.01012) * [SwAV](https://arxiv.org/abs/2006.09882) * [VIbCReg](https://arxiv.org/abs/2109.00783) * [VICReg](https://arxiv.org/abs/2105.04906) @@ -214,66 +215,69 @@ All pretrained models avaiable can be downloaded directly via the tables below o ### CIFAR-10 -| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | -|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:| -| Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) | -| BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) | -|DeepCluster V2| ResNet18 | 1000 | :x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) | -| DINO | ResNet18 | 1000 | :x: | 89.52 | 99.71 | [:link:](https://drive.google.com/drive/folders/1vyqZKUyP8sQyEyf2cqonxlGMbQC-D1Gi?usp=sharing) | -| MoCo V2+ | ResNet18 | 1000 | :x: | 92.94 | 99.79 | [:link:](https://drive.google.com/drive/folders/1ruNFEB3F-Otxv2Y0p62wrjA4v5Fr2cKC?usp=sharing) | -| MoCo V3 | ResNet18 | 1000 | :x: | 93.10 | 99.80 | 
[:link:](https://drive.google.com/drive/folders/1KwZTshNEpmqnYJcmyYPvfIJ_DNwqtAVj?usp=sharing) | -| NNCLR | ResNet18 | 1000 | :x: | 91.88 | 99.78 | [:link:](https://drive.google.com/drive/folders/1xdCzhvRehPmxinphuiZqFlfBwfwWDcLh?usp=sharing) | -| ReSSL | ResNet18 | 1000 | :x: | 90.63 | 99.62 | [:link:](https://drive.google.com/drive/folders/1jrFcztY2eO_fG98xPshqOD15pDIhLXp-?usp=sharing) | -| SimCLR | ResNet18 | 1000 | :x: | 90.74 | 99.75 | [:link:](https://drive.google.com/drive/folders/1mcvWr8P2WNJZ7TVpdLHA_Q91q4VK3y8O?usp=sharing) | -| Simsiam | ResNet18 | 1000 | :x: | 90.51 | 99.72 | [:link:](https://drive.google.com/drive/folders/1OO_igM3IK5oDw7GjQTNmdfg2I1DH3xOk?usp=sharing) | -| SupCon | ResNet18 | 1000 | :x: | 93.82 | 99.65 | [:link:](https://drive.google.com/drive/folders/1VwZ9TrJXCpnxyo7P_l397yGrGH-DAUv-?usp=sharing) | -| SwAV | ResNet18 | 1000 | :x: | 89.17 | 99.68 | [:link:](https://drive.google.com/drive/folders/1nlJH4Ljm8-5fOIeAaKppQT6gtsmmW1T0?usp=sharing) | -| VIbCReg | ResNet18 | 1000 | :x: | 91.18 | 99.74 | [:link:](https://drive.google.com/drive/folders/1XvxUOnLPZlC_-OkeuO7VqXT7z9_tNVk7?usp=sharing) | -| VICReg | ResNet18 | 1000 | :x: | 92.07 | 99.74 | [:link:](https://drive.google.com/drive/folders/159ZgCxocB7aaHxwNDubnAWU71zXV9hn-?usp=sharing) | -| W-MSE | ResNet18 | 1000 | :x: | 88.67 | 99.68 | [:link:](https://drive.google.com/drive/folders/1xPCiULzQ4JCmhrTsbxBp9S2jRZ01KiVM?usp=sharing) | +| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | +|--------------|:--------:|:------:|:----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:| +| Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) | +| BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) | +|DeepCluster V2| ResNet18 | 1000 | 
:x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) | +| DINO | ResNet18 | 1000 | :x: | 89.52 | 99.71 | [:link:](https://drive.google.com/drive/folders/1vyqZKUyP8sQyEyf2cqonxlGMbQC-D1Gi?usp=sharing) | +| MoCo V2+ | ResNet18 | 1000 | :x: | 92.94 | 99.79 | [:link:](https://drive.google.com/drive/folders/1ruNFEB3F-Otxv2Y0p62wrjA4v5Fr2cKC?usp=sharing) | +| MoCo V3 | ResNet18 | 1000 | :x: | 93.10 | 99.80 | [:link:](https://drive.google.com/drive/folders/1KwZTshNEpmqnYJcmyYPvfIJ_DNwqtAVj?usp=sharing) | +| NNCLR | ResNet18 | 1000 | :x: | 91.88 | 99.78 | [:link:](https://drive.google.com/drive/folders/1xdCzhvRehPmxinphuiZqFlfBwfwWDcLh?usp=sharing) | +| ReSSL | ResNet18 | 1000 | :x: | 90.63 | 99.62 | [:link:](https://drive.google.com/drive/folders/1jrFcztY2eO_fG98xPshqOD15pDIhLXp-?usp=sharing) | +| SimCLR | ResNet18 | 1000 | :x: | 90.74 | 99.75 | [:link:](https://drive.google.com/drive/folders/1mcvWr8P2WNJZ7TVpdLHA_Q91q4VK3y8O?usp=sharing) | +| Simsiam | ResNet18 | 1000 | :x: | 90.51 | 99.72 | [:link:](https://drive.google.com/drive/folders/1OO_igM3IK5oDw7GjQTNmdfg2I1DH3xOk?usp=sharing) | +| SSL-EY | ResNet18 | 1000 | :x: | TODO | TODO | [:link:](TODO)| +| SupCon | ResNet18 | 1000 | :x: | 93.82 | 99.65 | [:link:](https://drive.google.com/drive/folders/1VwZ9TrJXCpnxyo7P_l397yGrGH-DAUv-?usp=sharing) | +| SwAV | ResNet18 | 1000 | :x: | 89.17 | 99.68 | [:link:](https://drive.google.com/drive/folders/1nlJH4Ljm8-5fOIeAaKppQT6gtsmmW1T0?usp=sharing) | +| VIbCReg | ResNet18 | 1000 | :x: | 91.18 | 99.74 | [:link:](https://drive.google.com/drive/folders/1XvxUOnLPZlC_-OkeuO7VqXT7z9_tNVk7?usp=sharing) | +| VICReg | ResNet18 | 1000 | :x: | 92.07 | 99.74 | [:link:](https://drive.google.com/drive/folders/159ZgCxocB7aaHxwNDubnAWU71zXV9hn-?usp=sharing) | +| W-MSE | ResNet18 | 1000 | :x: | 88.67 | 99.68 | [:link:](https://drive.google.com/drive/folders/1xPCiULzQ4JCmhrTsbxBp9S2jRZ01KiVM?usp=sharing) | ### CIFAR-100 -| Method | 
Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | -|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:| -| Barlow Twins | ResNet18 | 1000 | :x: | 70.90 | 91.91 | [:link:](https://drive.google.com/drive/folders/1hDLSApF3zSMAKco1Ck4DMjyNxhsIR2yq?usp=sharing) | -| BYOL | ResNet18 | 1000 | :x: | 70.46 | 91.96 | [:link:](https://drive.google.com/drive/folders/1hwsEdsfsUulD2tAwa4epKK9pkSuvFv6m?usp=sharing) | -|DeepCluster V2| ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) | -| DINO | ResNet18 | 1000 | :x: | 66.76 | 90.34 | [:link:](https://drive.google.com/drive/folders/1TxeZi2YLprDDtbt_y5m29t4euroWr1Fy?usp=sharing) | -| MoCo V2+ | ResNet18 | 1000 | :x: | 69.89 | 91.65 | [:link:](https://drive.google.com/drive/folders/15oWNM16vO6YVYmk_yOmw2XUrFivRXam4?usp=sharing) | -| MoCo V3 | ResNet18 | 1000 | :x: | 68.83 | 90.57 | [:link:](https://drive.google.com/drive/folders/1Hcf9kMIADKydfxvXLquY9nv7sfNaJ3v6?usp=sharing) | -| NNCLR | ResNet18 | 1000 | :x: | 69.62 | 91.52 | [:link:](https://drive.google.com/drive/folders/1Dz72o0-5hugYPW1kCCQDBb0Xi3kzMLzu?usp=sharing) | -| ReSSL | ResNet18 | 1000 | :x: | 65.92 | 89.73 | [:link:](https://drive.google.com/drive/folders/1aVZs9cHAu6Ccz8ILyWkp6NhTsJGBGfjr?usp=sharing) | -| SimCLR | ResNet18 | 1000 | :x: | 65.78 | 89.04 | [:link:](https://drive.google.com/drive/folders/13pGPcOO9Y3rBoeRVWARgbMFEp8OXxZa0?usp=sharing) | -| Simsiam | ResNet18 | 1000 | :x: | 66.04 | 89.62 | [:link:](https://drive.google.com/drive/folders/1AJUPmsIHh_nqEcFe-Vcz2o4ruEibFHWO?usp=sharing) | -| SupCon | ResNet18 | 1000 | :x: | 70.38 | 89.57 | [:link:](https://drive.google.com/drive/folders/15C68oHPDMAOPtmBAm_Xw6YI6GgOW00gM?usp=sharing) | -| SwAV | ResNet18 | 1000 | :x: | 64.88 | 88.78 | [:link:](https://drive.google.com/drive/folders/1U_bmyhlPEN941hbx0SdRGOT4ivCarQB9?usp=sharing) | -| VIbCReg | ResNet18 | 1000 | :x: | 67.37 | 90.07 | 
[:link:](https://drive.google.com/drive/folders/19u3p1maX3xqwoCHNrqSDb98J5fRvd_6v?usp=sharing) | -| VICReg | ResNet18 | 1000 | :x: | 68.54 | 90.83 | [:link:](https://drive.google.com/drive/folders/1AHmVf_Zl5fikkmR4X3NWlmMOnRzfv0aT?usp=sharing) | -| W-MSE | ResNet18 | 1000 | :x: | 61.33 | 87.26 | [:link:](https://drive.google.com/drive/folders/1vc9j3RLpVCbECh6o-44oMiE5snNyKPlF?usp=sharing) | +| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | +|----------------|:--------:|:------:|:----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:| +| Barlow Twins | ResNet18 | 1000 | :x: | 70.90 | 91.91 | [:link:](https://drive.google.com/drive/folders/1hDLSApF3zSMAKco1Ck4DMjyNxhsIR2yq?usp=sharing) | +| BYOL | ResNet18 | 1000 | :x: | 70.46 | 91.96 | [:link:](https://drive.google.com/drive/folders/1hwsEdsfsUulD2tAwa4epKK9pkSuvFv6m?usp=sharing) | +| DeepCluster V2 | ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) | +| DINO | ResNet18 | 1000 | :x: | 66.76 | 90.34 | [:link:](https://drive.google.com/drive/folders/1TxeZi2YLprDDtbt_y5m29t4euroWr1Fy?usp=sharing) | +| MoCo V2+ | ResNet18 | 1000 | :x: | 69.89 | 91.65 | [:link:](https://drive.google.com/drive/folders/15oWNM16vO6YVYmk_yOmw2XUrFivRXam4?usp=sharing) | +| MoCo V3 | ResNet18 | 1000 | :x: | 68.83 | 90.57 | [:link:](https://drive.google.com/drive/folders/1Hcf9kMIADKydfxvXLquY9nv7sfNaJ3v6?usp=sharing) | +| NNCLR | ResNet18 | 1000 | :x: | 69.62 | 91.52 | [:link:](https://drive.google.com/drive/folders/1Dz72o0-5hugYPW1kCCQDBb0Xi3kzMLzu?usp=sharing) | +| ReSSL | ResNet18 | 1000 | :x: | 65.92 | 89.73 | [:link:](https://drive.google.com/drive/folders/1aVZs9cHAu6Ccz8ILyWkp6NhTsJGBGfjr?usp=sharing) | +| SimCLR | ResNet18 | 1000 | :x: | 65.78 | 89.04 | [:link:](https://drive.google.com/drive/folders/13pGPcOO9Y3rBoeRVWARgbMFEp8OXxZa0?usp=sharing) | +| Simsiam | ResNet18 | 
1000 | :x: | 66.04 | 89.62 | [:link:](https://drive.google.com/drive/folders/1AJUPmsIHh_nqEcFe-Vcz2o4ruEibFHWO?usp=sharing) | +| SSL-EY | ResNet18 | 1000 | :x: | TODO | TODO | [:link:](TODO)| +| SupCon | ResNet18 | 1000 | :x: | 70.38 | 89.57 | [:link:](https://drive.google.com/drive/folders/15C68oHPDMAOPtmBAm_Xw6YI6GgOW00gM?usp=sharing) | +| SwAV | ResNet18 | 1000 | :x: | 64.88 | 88.78 | [:link:](https://drive.google.com/drive/folders/1U_bmyhlPEN941hbx0SdRGOT4ivCarQB9?usp=sharing) | +| VIbCReg | ResNet18 | 1000 | :x: | 67.37 | 90.07 | [:link:](https://drive.google.com/drive/folders/19u3p1maX3xqwoCHNrqSDb98J5fRvd_6v?usp=sharing) | +| VICReg | ResNet18 | 1000 | :x: | 68.54 | 90.83 | [:link:](https://drive.google.com/drive/folders/1AHmVf_Zl5fikkmR4X3NWlmMOnRzfv0aT?usp=sharing) | +| W-MSE | ResNet18 | 1000 | :x: | 61.33 | 87.26 | [:link:](https://drive.google.com/drive/folders/1vc9j3RLpVCbECh6o-44oMiE5snNyKPlF?usp=sharing) | ### ImageNet-100 -| Method | Backbone | Epochs | Dali | Acc@1 (online) | Acc@1 (offline) | Acc@5 (online) | Acc@5 (offline) | Checkpoint | -|-------------------------|:--------:|:------:|:------------------:|:--------------:|:---------------:|:--------------:|:---------------:|:----------:| -| Barlow Twins :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.38 | 80.16 | 95.28 | 95.14 | [:link:](https://drive.google.com/drive/folders/1rj8RbER9E71mBlCHIZEIhKPUFn437D5O?usp=sharing) | -| BYOL :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.16 | 80.32 | 95.02 | 94.94 | [:link:](https://drive.google.com/drive/folders/1riOLjMawD_znO4HYj8LBN2e1X4jXpDE1?usp=sharing) | -| DeepCluster V2 | ResNet18 | 400 | :x: | 75.36 | 75.4 | 93.22 | 93.10 | [:link:](https://drive.google.com/drive/folders/1d5jPuavrQ7lMlQZn5m2KnN5sPMGhHFo8?usp=sharing) | -| DINO | ResNet18 | 400 | :heavy_check_mark: | 74.84 | 74.92 | 92.92 | 92.78 | [:link:](https://drive.google.com/drive/folders/1NtVvRj-tQJvrMxRlMtCJSAecQnYZYkqs?usp=sharing) | -| DINO :sleepy: | ViT Tiny | 400 | :x: | 
63.04 | TODO | 87.72 | TODO | [:link:](https://drive.google.com/drive/folders/16AfsM-UpKky43kdSMlqj4XRe69pRdJLc?usp=sharing) | -| MoCo V2+ :rocket: | ResNet18 | 400 | :heavy_check_mark: | 78.20 | 79.28 | 95.50 | 95.18 | [:link:](https://drive.google.com/drive/folders/1ItYBtMJ23Yh-Rhrvwjm4w1waFfUGSoKX?usp=sharing) | -| MoCo V3 :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.36 | 80.36 | 95.18 | 94.96 | [:link:](https://drive.google.com/drive/folders/15J0JiZsQAsrQler8mbbio-desb_nVoD1?usp=sharing) | -| MoCo V3 :rocket: | ResNet50 | 400 | :heavy_check_mark: | 85.48 | 84.58 | 96.82 | 96.70 | [:link:](https://drive.google.com/drive/folders/1a1VRXGlP50COZ57DPUA_doBmpaxGKpQE?usp=sharing) | -| NNCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 79.80 | 80.16 | 95.28 | 95.30 | [:link:](https://drive.google.com/drive/folders/1QMkq8w3UsdcZmoNUIUPgfSCAZl_LSNjZ?usp=sharing) | -| ReSSL | ResNet18 | 400 | :heavy_check_mark: | 76.92 | 78.48 | 94.20 | 94.24 | [:link:](https://drive.google.com/drive/folders/1urWIFACLont4GAduis6l0jcEbl080c9U?usp=sharing) | -| SimCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 77.64 | TODO | 94.06 | TODO | [:link:](https://drive.google.com/drive/folders/1yxAVKnc8Vf0tDfkixSB5mXe7dsA8Ll37?usp=sharing) | -| Simsiam | ResNet18 | 400 | :heavy_check_mark: | 74.54 | 78.72 | 93.16 | 94.78 | [:link:](https://drive.google.com/drive/folders/1Bc8Xj-Z7ILmspsiEQHyQsTOn4M99F_f5?usp=sharing) | -| SupCon | ResNet18 | 400 | :heavy_check_mark: | 84.40 | TODO | 95.72 | TODO | [:link:](https://drive.google.com/drive/folders/1BzR0nehkCKpnLhi-oeDynzzUcCYOCUJi?usp=sharing) | -| SwAV | ResNet18 | 400 | :heavy_check_mark: | 74.04 | 74.28 | 92.70 | 92.84 | [:link:](https://drive.google.com/drive/folders/1VWCMM69sokzjVoPzPSLIsUy5S2Rrm1xJ?usp=sharing) | -| VIbCReg | ResNet18 | 400 | :heavy_check_mark: | 79.86 | 79.38 | 94.98 | 94.60 | [:link:](https://drive.google.com/drive/folders/1Q06hH18usvRwj2P0bsmoCkjNUX_0syCK?usp=sharing) | -| VICReg :rocket: | ResNet18 | 400 | 
:heavy_check_mark: | 79.22 | 79.40 | 95.06 | 95.02 | [:link:](https://drive.google.com/drive/folders/1uWWR5VBUru8vaHaGeLicS6X3R4CfZsr2?usp=sharing) | -| W-MSE | ResNet18 | 400 | :heavy_check_mark: | 67.60 | 69.06 | 90.94 | 91.22 | [:link:](https://drive.google.com/drive/folders/1TxubagNV4z5Qs7SqbBcyRHWGKevtFO5l?usp=sharing) | +| Method | Backbone | Epochs | Dali | Acc@1 (online) | Acc@1 (offline) | Acc@5 (online) | Acc@5 (offline) | Checkpoint | +|-----------------------|:--------:|:------:|:------------------:|:--------------:|:---------------:|:--------------:|:---------------:|:----------------------------------------------------------------------------------------------:| +| Barlow Twins :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.38 | 80.16 | 95.28 | 95.14 | [:link:](https://drive.google.com/drive/folders/1rj8RbER9E71mBlCHIZEIhKPUFn437D5O?usp=sharing) | +| BYOL :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.16 | 80.32 | 95.02 | 94.94 | [:link:](https://drive.google.com/drive/folders/1riOLjMawD_znO4HYj8LBN2e1X4jXpDE1?usp=sharing) | +| DeepCluster V2 | ResNet18 | 400 | :x: | 75.36 | 75.4 | 93.22 | 93.10 | [:link:](https://drive.google.com/drive/folders/1d5jPuavrQ7lMlQZn5m2KnN5sPMGhHFo8?usp=sharing) | +| DINO | ResNet18 | 400 | :heavy_check_mark: | 74.84 | 74.92 | 92.92 | 92.78 | [:link:](https://drive.google.com/drive/folders/1NtVvRj-tQJvrMxRlMtCJSAecQnYZYkqs?usp=sharing) | +| DINO :sleepy: | ViT Tiny | 400 | :x: | 63.04 | TODO | 87.72 | TODO | [:link:](https://drive.google.com/drive/folders/16AfsM-UpKky43kdSMlqj4XRe69pRdJLc?usp=sharing) | +| MoCo V2+ :rocket: | ResNet18 | 400 | :heavy_check_mark: | 78.20 | 79.28 | 95.50 | 95.18 | [:link:](https://drive.google.com/drive/folders/1ItYBtMJ23Yh-Rhrvwjm4w1waFfUGSoKX?usp=sharing) | +| MoCo V3 :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.36 | 80.36 | 95.18 | 94.96 | [:link:](https://drive.google.com/drive/folders/15J0JiZsQAsrQler8mbbio-desb_nVoD1?usp=sharing) | +| MoCo V3 :rocket: | ResNet50 | 400 | 
:heavy_check_mark: | 85.48 | 84.58 | 96.82 | 96.70 | [:link:](https://drive.google.com/drive/folders/1a1VRXGlP50COZ57DPUA_doBmpaxGKpQE?usp=sharing) | +| NNCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 79.80 | 80.16 | 95.28 | 95.30 | [:link:](https://drive.google.com/drive/folders/1QMkq8w3UsdcZmoNUIUPgfSCAZl_LSNjZ?usp=sharing) | +| ReSSL | ResNet18 | 400 | :heavy_check_mark: | 76.92 | 78.48 | 94.20 | 94.24 | [:link:](https://drive.google.com/drive/folders/1urWIFACLont4GAduis6l0jcEbl080c9U?usp=sharing) | +| SimCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 77.64 | TODO | 94.06 | TODO | [:link:](https://drive.google.com/drive/folders/1yxAVKnc8Vf0tDfkixSB5mXe7dsA8Ll37?usp=sharing) | +| Simsiam | ResNet18 | 400 | :heavy_check_mark: | 74.54 | 78.72 | 93.16 | 94.78 | [:link:](https://drive.google.com/drive/folders/1Bc8Xj-Z7ILmspsiEQHyQsTOn4M99F_f5?usp=sharing) | +| SSLEY :rocket: | ResNet18 | 400 | :heavy_check_mark: | TODO | TODO | TODO | TODO | TODO | +| SupCon | ResNet18 | 400 | :heavy_check_mark: | 84.40 | TODO | 95.72 | TODO | [:link:](https://drive.google.com/drive/folders/1BzR0nehkCKpnLhi-oeDynzzUcCYOCUJi?usp=sharing) | +| SwAV | ResNet18 | 400 | :heavy_check_mark: | 74.04 | 74.28 | 92.70 | 92.84 | [:link:](https://drive.google.com/drive/folders/1VWCMM69sokzjVoPzPSLIsUy5S2Rrm1xJ?usp=sharing) | +| VIbCReg | ResNet18 | 400 | :heavy_check_mark: | 79.86 | 79.38 | 94.98 | 94.60 | [:link:](https://drive.google.com/drive/folders/1Q06hH18usvRwj2P0bsmoCkjNUX_0syCK?usp=sharing) | +| VICReg :rocket: | ResNet18 | 400 | :heavy_check_mark: | 79.22 | 79.40 | 95.06 | 95.02 | [:link:](https://drive.google.com/drive/folders/1uWWR5VBUru8vaHaGeLicS6X3R4CfZsr2?usp=sharing) | +| W-MSE | ResNet18 | 400 | :heavy_check_mark: | 67.60 | 69.06 | 90.94 | 91.22 | [:link:](https://drive.google.com/drive/folders/1TxubagNV4z5Qs7SqbBcyRHWGKevtFO5l?usp=sharing) | :rocket: methods where hyperparameters were heavily tuned. 
From 2dda063ab790f058ea89809aa506a7512f74fec3 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 14:32:28 +0000 Subject: [PATCH 3/9] Adding SSL-EY to README.md --- scripts/pretrain/imagenet-100/ssley.yaml | 1 - solo/losses/ssley.py | 2 +- solo/methods/ssley.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/pretrain/imagenet-100/ssley.yaml b/scripts/pretrain/imagenet-100/ssley.yaml index 7429d92b3..7974072b5 100644 --- a/scripts/pretrain/imagenet-100/ssley.yaml +++ b/scripts/pretrain/imagenet-100/ssley.yaml @@ -1,6 +1,5 @@ defaults: - _self_ - - augmentations: ssley.yaml - wandb: private.yaml - override hydra/hydra_logging: disabled - override hydra/job_logging: disabled diff --git a/solo/losses/ssley.py b/solo/losses/ssley.py index 2586f37ab..cf32509fd 100644 --- a/solo/losses/ssley.py +++ b/solo/losses/ssley.py @@ -50,4 +50,4 @@ def ssley_loss_func( A = B[:D, D:] + B[D:, :D] B = B[:D, :D] + B[D:, D:] - return -torch.trace(2 * A - B @ B) \ No newline at end of file + return -torch.trace(2 * A - B @ B) diff --git a/solo/methods/ssley.py b/solo/methods/ssley.py index 9071de3ba..4c2eb4e07 100644 --- a/solo/methods/ssley.py +++ b/solo/methods/ssley.py @@ -144,4 +144,4 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: self.log("train_ssley_loss", ssley_loss, on_epoch=True, sync_dist=True) - return ssley_loss + class_loss \ No newline at end of file + return ssley_loss + class_loss From 39290edea0a4be1f87130f10fdbd06f68c73ad1b Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 14:32:42 +0000 Subject: [PATCH 4/9] Adding SSL-EY to README.md --- solo/methods/ssley.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/solo/methods/ssley.py b/solo/methods/ssley.py index 4c2eb4e07..bdbad17c4 100644 --- a/solo/methods/ssley.py +++ b/solo/methods/ssley.py @@ -137,9 +137,6 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: ssley_loss = ssley_loss_func( 
z1, z2, - sim_loss_weight=self.sim_loss_weight, - var_loss_weight=self.var_loss_weight, - cov_loss_weight=self.cov_loss_weight, ) self.log("train_ssley_loss", ssley_loss, on_epoch=True, sync_dist=True) From ffb592898fe207e956369cb2af5dc3989b2bc4df Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 14:33:03 +0000 Subject: [PATCH 5/9] Adding SSL-EY to README.md --- solo/methods/ssley.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/solo/methods/ssley.py b/solo/methods/ssley.py index bdbad17c4..85ceb4f83 100644 --- a/solo/methods/ssley.py +++ b/solo/methods/ssley.py @@ -24,7 +24,6 @@ import torch.nn as nn from solo.losses.ssley import ssley_loss_func from solo.methods.base import BaseMethod -from solo.utils.misc import omegaconf_select class SSLEY(BaseMethod): @@ -39,10 +38,6 @@ def __init__(self, cfg: omegaconf.DictConfig): super().__init__(cfg) - self.sim_loss_weight: float = cfg.method_kwargs.sim_loss_weight - self.var_loss_weight: float = cfg.method_kwargs.var_loss_weight - self.cov_loss_weight: float = cfg.method_kwargs.cov_loss_weight - proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim proj_output_dim: int = cfg.method_kwargs.proj_output_dim @@ -73,22 +68,6 @@ def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConf assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim") assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim") - cfg.method_kwargs.sim_loss_weight = omegaconf_select( - cfg, - "method_kwargs.sim_loss_weight", - 25.0, - ) - cfg.method_kwargs.var_loss_weight = omegaconf_select( - cfg, - "method_kwargs.var_loss_weight", - 25.0, - ) - cfg.method_kwargs.cov_loss_weight = omegaconf_select( - cfg, - "method_kwargs.cov_loss_weight", - 1.0, - ) - return cfg @property From 9e8ad7edd9451d93f90ec44271f148472d9d6583 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 14:33:57 +0000 Subject: [PATCH 6/9] Adding SSL-EY to 
README.md --- solo/losses/ssley.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/solo/losses/ssley.py b/solo/losses/ssley.py index cf32509fd..6826013b3 100644 --- a/solo/losses/ssley.py +++ b/solo/losses/ssley.py @@ -18,9 +18,8 @@ # DEALINGS IN THE SOFTWARE. import torch -import torch.nn.functional as F -from solo.utils.misc import gather +import torch.distributed as dist def ssley_loss_func( z1: torch.Tensor, @@ -37,8 +36,6 @@ def ssley_loss_func( torch.Tensor: VICReg loss. """ - sim_loss = invariance_loss(z1, z2) - N, D = z1.size() B = torch.cov(torch.hstack((z1, z2)).T) From b2ad189f1476d2fbf98b8afff51682fca04d74c8 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 6 Dec 2023 14:36:20 +0000 Subject: [PATCH 7/9] Adding SSL-EY to README.md --- tests/losses/test_ssley.py | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/losses/test_ssley.py diff --git a/tests/losses/test_ssley.py b/tests/losses/test_ssley.py new file mode 100644 index 000000000..a29c0dd6e --- /dev/null +++ b/tests/losses/test_ssley.py @@ -0,0 +1,43 @@ +# Copyright 2023 solo-learn development team. + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch +from solo.losses import ssley_loss_func + + +def test_ssley_loss(): + b, f = 32, 128 + z1 = torch.randn(b, f).requires_grad_() + z2 = torch.randn(b, f).requires_grad_() + + loss = ssley_loss_func(z1, z2) + initial_loss = loss.item() + assert loss != 0 + + for _ in range(20): + loss = ssley_loss_func( + z1, z2 + ) + loss.backward() + z1.data.add_(-0.5 * z1.grad) + z2.data.add_(-0.5 * z2.grad) + + z1.grad = z2.grad = None + + assert loss < initial_loss From 4a2e8376df8a4e0a59932441fc1f584dbae4019c Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Fri, 8 Dec 2023 11:32:36 +0000 Subject: [PATCH 8/9] Adding SSL-EY to README.md --- solo/losses/ssley.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/solo/losses/ssley.py b/solo/losses/ssley.py index 6826013b3..05438bdd5 100644 --- a/solo/losses/ssley.py +++ b/solo/losses/ssley.py @@ -33,18 +33,17 @@ def ssley_loss_func( z2 (torch.Tensor): NxD Tensor containing projected features from view 2. Returns: - torch.Tensor: VICReg loss. + torch.Tensor: SSL-EY loss. 
""" - N, D = z1.size() - B = torch.cov(torch.hstack((z1, z2)).T) + z1, z2 = gather(z1), gather(z2) - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(B) - world_size = dist.get_world_size() - B /= world_size + z1 = z1 - z1.mean(dim=0) + z2 = z2 - z2.mean(dim=0) - A = B[:D, D:] + B[D:, :D] - B = B[:D, :D] + B[D:, D:] + C = 2 * (z1.T @ z2) / (self.args.batch_size - 1) + V = (z1.T @ z1) / (self.args.batch_size - 1) + (z2.T @ z2) / (self.args.batch_size - 1) - return -torch.trace(2 * A - B @ B) + loss = torch.trace(C) - torch.trace(V @ V) + + return loss From 11eac929a3cfb8fe28cefa573ef87e08fcf1eb27 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Fri, 8 Dec 2023 11:36:06 +0000 Subject: [PATCH 9/9] Adding SSL-EY to README.md --- solo/losses/ssley.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/solo/losses/ssley.py b/solo/losses/ssley.py index 05438bdd5..df5ed9c57 100644 --- a/solo/losses/ssley.py +++ b/solo/losses/ssley.py @@ -18,8 +18,7 @@ # DEALINGS IN THE SOFTWARE. import torch - -import torch.distributed as dist +from solo.utils.misc import gather def ssley_loss_func( z1: torch.Tensor, @@ -35,15 +34,15 @@ def ssley_loss_func( Returns: torch.Tensor: SSL-EY loss. """ - + N, D = z1.size() z1, z2 = gather(z1), gather(z2) z1 = z1 - z1.mean(dim=0) z2 = z2 - z2.mean(dim=0) - C = 2 * (z1.T @ z2) / (self.args.batch_size - 1) - V = (z1.T @ z1) / (self.args.batch_size - 1) + (z2.T @ z2) / (self.args.batch_size - 1) + C = 2 * (z1.T @ z2) / (N - 1) + V = (z1.T @ z1) / (N - 1) + (z2.T @ z2) / (N - 1) - loss = torch.trace(C) - torch.trace(V @ V) + loss = -2*torch.trace(C) + torch.trace(V @ V) return loss