From 39d6941d48b785988561c4f75f1132ac2049be88 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Wed, 29 Mar 2023 05:06:39 -0500
Subject: [PATCH] SimCLR: add new trainer (#1195)

* SimCLR: add new trainer

* Add tests

* Support custom number of MLP layers

* Change default params, add TODOs

* Fix mypy

* Fix docs and most of tests

* Fix all tests

* Fix support for older Kornia versions

* Fix support for older Kornia versions

* Crop should be 224, not 96
---
 docs/conf.py                                  |   1 +
 ...r.yaml => chesapeake_cvpr_prior_byol.yaml} |   4 +-
 .../conf/chesapeake_cvpr_prior_simclr_1.yaml  |  23 ++
 .../conf/chesapeake_cvpr_prior_simclr_2.yaml  |  23 ++
 tests/conf/{seco_1.yaml => seco_byol_1.yaml}  |   0
 tests/conf/{seco_2.yaml => seco_byol_2.yaml}  |   0
 tests/conf/seco_simclr_1.yaml                 |  14 ++
 tests/conf/seco_simclr_2.yaml                 |  14 ++
 tests/trainers/test_byol.py                   |   6 +-
 tests/trainers/test_simclr.py                 |  65 ++++++
 torchgeo/trainers/__init__.py                 |   2 +
 torchgeo/trainers/simclr.py                   | 198 ++++++++++++++++++
 12 files changed, 345 insertions(+), 5 deletions(-)
 rename tests/conf/{chesapeake_cvpr_prior.yaml => chesapeake_cvpr_prior_byol.yaml} (92%)
 create mode 100644 tests/conf/chesapeake_cvpr_prior_simclr_1.yaml
 create mode 100644 tests/conf/chesapeake_cvpr_prior_simclr_2.yaml
 rename tests/conf/{seco_1.yaml => seco_byol_1.yaml} (100%)
 rename tests/conf/{seco_2.yaml => seco_byol_2.yaml} (100%)
 create mode 100644 tests/conf/seco_simclr_1.yaml
 create mode 100644 tests/conf/seco_simclr_2.yaml
 create mode 100644 tests/trainers/test_simclr.py
 create mode 100644 torchgeo/trainers/simclr.py

diff --git a/docs/conf.py b/docs/conf.py
index c0c13b5f5bb..46e4365195e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -62,6 +62,7 @@
     ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
     ("py:class", "timm.models.resnet.ResNet"),
     ("py:class", "timm.models.vision_transformer.VisionTransformer"),
+    ("py:class", "torch.optim.lr_scheduler.LRScheduler"),
     ("py:class", "torchvision.models._api.WeightsEnum"),
     ("py:class", "torchvision.models.resnet.ResNet"),
 ]
diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml
similarity index 92%
rename from tests/conf/chesapeake_cvpr_prior.yaml
rename to tests/conf/chesapeake_cvpr_prior_byol.yaml
index 3e9713fbb59..a5947035137 100644
--- a/tests/conf/chesapeake_cvpr_prior.yaml
+++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml
@@ -3,7 +3,7 @@ experiment:
   module:
     loss: "ce"
     model: "unet"
-    backbone: "resnet50"
+    backbone: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     in_channels: 4
@@ -13,7 +13,7 @@ experiment:
     weights: null
   datamodule:
     root: "tests/data/chesapeake/cvpr"
-    download: true
+    download: false
     train_splits:
       - "de-test"
     val_splits:
diff --git a/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml b/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml
new file mode 100644
index 00000000000..b6876c5a89f
--- /dev/null
+++ b/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml
@@ -0,0 +1,23 @@
+experiment:
+  task: "chesapeake_cvpr"
+  module:
+    model: "resnet18"
+    in_channels: 4
+    version: 1
+    layers: 2
+    weight_decay: 1e-6
+    max_epochs: 1
+  datamodule:
+    root: "tests/data/chesapeake/cvpr"
+    download: false
+    train_splits:
+      - "de-test"
+    val_splits:
+      - "de-test"
+    test_splits:
+      - "de-test"
+    batch_size: 2
+    patch_size: 64
+    num_workers: 0
+    class_set: 5
+    use_prior_labels: True
diff --git a/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml b/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml
new file mode 100644
index 00000000000..9d82787a772
--- /dev/null
+++ b/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml
@@ -0,0 +1,23 @@
+experiment:
+  task: "chesapeake_cvpr"
+  module:
+    model: "resnet18"
+    in_channels: 4
+    version: 2
+    layers: 3
+    weight_decay: 1e-4
+    max_epochs: 1
+  datamodule:
+    root: "tests/data/chesapeake/cvpr"
+    download: false
+    train_splits:
+      - "de-test"
+    val_splits:
+      - "de-test"
+    test_splits:
+      - "de-test"
+    batch_size: 2
+    patch_size: 64
+    num_workers: 0
+    class_set: 5
+    use_prior_labels: True
diff --git a/tests/conf/seco_1.yaml b/tests/conf/seco_byol_1.yaml
similarity index 100%
rename from tests/conf/seco_1.yaml
rename to tests/conf/seco_byol_1.yaml
diff --git a/tests/conf/seco_2.yaml b/tests/conf/seco_byol_2.yaml
similarity index 100%
rename from tests/conf/seco_2.yaml
rename to tests/conf/seco_byol_2.yaml
diff --git a/tests/conf/seco_simclr_1.yaml b/tests/conf/seco_simclr_1.yaml
new file mode 100644
index 00000000000..8e4563398eb
--- /dev/null
+++ b/tests/conf/seco_simclr_1.yaml
@@ -0,0 +1,14 @@
+experiment:
+  task: "seco"
+  module:
+    model: "resnet18"
+    in_channels: 3
+    version: 1
+    layers: 2
+    weight_decay: 1e-6
+    max_epochs: 1
+  datamodule:
+    root: "tests/data/seco"
+    seasons: 1
+    batch_size: 2
+    num_workers: 0
diff --git a/tests/conf/seco_simclr_2.yaml b/tests/conf/seco_simclr_2.yaml
new file mode 100644
index 00000000000..d72940c3828
--- /dev/null
+++ b/tests/conf/seco_simclr_2.yaml
@@ -0,0 +1,14 @@
+experiment:
+  task: "seco"
+  module:
+    model: "resnet18"
+    in_channels: 3
+    version: 2
+    layers: 3
+    weight_decay: 1e-4
+    max_epochs: 1
+  datamodule:
+    root: "tests/data/seco"
+    seasons: 2
+    batch_size: 2
+    num_workers: 0
diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py
index e60b0d46014..e0f1cb0c69a 100644
--- a/tests/trainers/test_byol.py
+++ b/tests/trainers/test_byol.py
@@ -52,9 +52,9 @@ class TestBYOLTask:
     @pytest.mark.parametrize(
         "name,classname",
         [
-            ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
-            ("seco_1", SeasonalContrastS2DataModule),
-            ("seco_2", SeasonalContrastS2DataModule),
+            ("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule),
+            ("seco_byol_1", SeasonalContrastS2DataModule),
+            ("seco_byol_2", SeasonalContrastS2DataModule),
         ],
     )
     def test_trainer(
diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py
new file mode 100644
index 00000000000..ce6e90791f8
--- /dev/null
+++ b/tests/trainers/test_simclr.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+from typing import Any, Dict, Type, cast
+
+import pytest
+import timm
+from _pytest.monkeypatch import MonkeyPatch
+from lightning.pytorch import LightningDataModule, Trainer
+from omegaconf import OmegaConf
+from torch.nn import Module
+
+from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
+from torchgeo.datasets import SeasonalContrastS2
+from torchgeo.trainers import SimCLRTask
+
+from .test_classification import ClassificationTestModel
+
+
+def create_model(*args: Any, **kwargs: Any) -> Module:
+    return ClassificationTestModel(**kwargs)
+
+
+class TestSimCLRTask:
+    @pytest.mark.parametrize(
+        "name,classname",
+        [
+            ("chesapeake_cvpr_prior_simclr_1", ChesapeakeCVPRDataModule),
+            ("chesapeake_cvpr_prior_simclr_2", ChesapeakeCVPRDataModule),
+            ("seco_simclr_1", SeasonalContrastS2DataModule),
+            ("seco_simclr_2", SeasonalContrastS2DataModule),
+        ],
+    )
+    def test_trainer(
+        self,
+        monkeypatch: MonkeyPatch,
+        name: str,
+        classname: Type[LightningDataModule],
+        fast_dev_run: bool,
+    ) -> None:
+        conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+        conf_dict = OmegaConf.to_object(conf.experiment)
+        conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
+
+        if name.startswith("seco"):
+            monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)
+
+        # Instantiate datamodule
+        datamodule_kwargs = conf_dict["datamodule"]
+        datamodule = classname(**datamodule_kwargs)
+
+        # Instantiate model
+        monkeypatch.setattr(timm, "create_model", create_model)
+        model_kwargs = conf_dict["module"]
+        model = SimCLRTask(**model_kwargs)
+
+        # Instantiate trainer
+        trainer = Trainer(
+            accelerator="cpu",
+            fast_dev_run=fast_dev_run,
+            log_every_n_steps=1,
+            max_epochs=1,
+        )
+        trainer.fit(model=model, datamodule=datamodule)
diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py
index 6240b53f681..f2c5b7b25d7 100644
--- a/torchgeo/trainers/__init__.py
+++ b/torchgeo/trainers/__init__.py
@@ -8,6 +8,7 @@
 from .detection import ObjectDetectionTask
 from .regression import RegressionTask
 from .segmentation import SemanticSegmentationTask
+from .simclr import SimCLRTask
 
 __all__ = (
     "BYOLTask",
@@ -16,4 +17,5 @@
     "ObjectDetectionTask",
     "RegressionTask",
     "SemanticSegmentationTask",
+    "SimCLRTask",
 )
diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py
new file mode 100644
index 00000000000..88ac657af6c
--- /dev/null
+++ b/torchgeo/trainers/simclr.py
@@ -0,0 +1,198 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""SimCLR trainers for self-supervised learning (SSL)."""
+
+from typing import Dict, List, Tuple
+
+import kornia.augmentation as K
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning.pytorch import LightningModule
+from torch import Tensor
+from torch.optim import AdamW, Optimizer
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+try:
+    from torch.optim.lr_scheduler import LRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+
+class SimCLRTask(LightningModule):  # type: ignore[misc]
+    """SimCLR: a simple framework for contrastive learning of visual representations.
+
+    Implementation based on:
+
+    * https://github.com/google-research/simclr
+    * https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/13-contrastive-learning.html
+
+    If you use this trainer in your research, please cite the following papers:
+
+    * v1: https://arxiv.org/abs/2002.05709
+    * v2: https://arxiv.org/abs/2006.10029
+
+    .. versionadded:: 0.5
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        model: str = "resnet50",
+        in_channels: int = 3,
+        version: int = 2,
+        layers: int = 3,
+        hidden_dim: int = 128,
+        lr: float = 4.8,
+        weight_decay: float = 1e-4,
+        max_epochs: int = 100,
+        temperature: float = 0.07,
+    ) -> None:
+        """Initialize a new SimCLRTask instance.
+
+        Args:
+            model: Name of the timm model to use.
+            in_channels: Number of input channels to model.
+            version: Version of SimCLR, 1--2.
+            layers: Number of layers in projection head (2 for v1 or 3+ for v2).
+            hidden_dim: Number of hidden dimensions in projection head.
+            lr: Learning rate (0.3 x batch_size / 256 is recommended).
+            weight_decay: Weight decay coefficient (1e-6 for v1 or 1e-4 for v2).
+            max_epochs: Maximum number of epochs to train for.
+            temperature: Temperature used in InfoNCE loss.
+        """
+        super().__init__()
+
+        assert version in range(1, 3)
+
+        self.save_hyperparameters()
+
+        self.model = timm.create_model(
+            model, in_chans=in_channels, num_classes=hidden_dim
+        )
+
+        # Projection head
+        # https://github.com/google-research/simclr/blob/master/model_util.py#L141
+        for i in range(layers):
+            if i == layers - 1:
+                # For the final layer, skip bias and ReLU
+                self.model.fc = nn.Sequential(
+                    self.model.fc, nn.Linear(hidden_dim, hidden_dim, bias=False)
+                )
+            else:
+                # For the middle layers, use bias and ReLU
+                self.model.fc = nn.Sequential(
+                    self.model.fc,
+                    nn.ReLU(inplace=True),
+                    nn.Linear(hidden_dim, hidden_dim, bias=True),
+                )
+
+        # TODO
+        # v1+: add global batch norm
+        # v2: add selective kernels, channel-wise attention mechanism, memory bank
+
+        # Data augmentation
+        # https://github.com/google-research/simclr/blob/master/data_util.py
+        self.aug = K.AugmentationSequential(
+            K.RandomResizedCrop(size=(224, 224)),
+            K.RandomHorizontalFlip(),
+            K.RandomVerticalFlip(),  # added
+            # Not appropriate for multispectral imagery, seasonal contrast used instead
+            # K.ColorJitter(
+            #     brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
+            # )
+            # K.RandomGrayscale(p=0.2),
+            K.RandomGaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2)),
+            data_keys=["input"],
+        )
+
+    def forward(self, batch: Tensor) -> Tensor:
+        """Forward pass of the model.
+
+        Args:
+            batch: Mini-batch of images.
+
+        Returns:
+            Output from the model.
+        """
+        batch = self.model(batch)
+        return batch
+
+    def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
+        """Compute the training loss and additional metrics.
+
+        Args:
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+
+        Returns:
+            The loss tensor.
+ """ + x = batch["image"] + + in_channels = self.hparams["in_channels"] + assert x.size(1) == in_channels or x.size(1) == 2 * in_channels + + if x.size(1) == in_channels: + x1 = x + x2 = x + else: + x1 = x[:, :in_channels] + x2 = x[:, in_channels:] + + # Apply augmentations independently for each season + x1 = self.aug(x1) + x2 = self.aug(x2) + + x = torch.cat([x1, x2], dim=0) + + # Encode all images + x = self(x) + + # Calculate cosine similarity + cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1) + + # Mask out cosine similarity to itself + self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device) + cos_sim.masked_fill_(self_mask, -9e15) + + # Find positive example -> batch_size // 2 away from the original example + pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0) + + # NT-Xent loss (aka InfoNCE loss) + cos_sim = cos_sim / self.hparams["temperature"] + nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1) + nll = nll.mean() + + self.log("train_loss", nll) + + return nll + + def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + + def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + # TODO + # v2: add distillation step + + def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + + def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + Optimizer and learning rate scheduler. + """ + # Original paper uses LARS optimizer, but this is not defined in PyTorch + optimizer = AdamW( + self.parameters(), + lr=self.hparams["lr"], + weight_decay=self.hparams["weight_decay"], + ) + lr_scheduler = CosineAnnealingLR( + optimizer, T_max=self.hparams["max_epochs"], eta_min=self.hparams["lr"] / 50 + ) + return [optimizer], [lr_scheduler]