From 10f0bde9cbb3555fab59538520555023d083667e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Mar 2023 12:13:12 -0500 Subject: [PATCH 01/10] SimCLR: add new trainer --- torchgeo/trainers/simclr.py | 164 ++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 torchgeo/trainers/simclr.py diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py new file mode 100644 index 00000000000..946213604fc --- /dev/null +++ b/torchgeo/trainers/simclr.py @@ -0,0 +1,164 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SimCLR trainers for self-supervised learning (SSL).""" + +from typing import Dict, List, Tuple + +import kornia.augmentation as K +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning.pytorch import LightningModule +from torch import Tensor +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler + +from ..transforms import AugmentationSequential + + +class SimCLRTask(LightningModule): # type: ignore[misc] + """SimCLR: a simple framework for contrastive learning of visual representations. + + Implementation based on: + + * https://github.com/google-research/simclr + * https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/13-contrastive-learning.html + + If you use this trainer in your research, please cite the following papers: + + * v1: https://arxiv.org/abs/2002.05709 + * v2: https://arxiv.org/abs/2006.10029 + + .. versionadded:: 0.5 + """ # noqa: E501 + + def __init__( + self, + model: str = "resnet50", + in_channels: int = 3, + version: int = 2, + hidden_dim: int = 128, + lr: float = 4.8, + weight_decay: float = 1e-6, + max_epochs: int = 100, + temperature: float = 0.07, + ) -> None: + """Initialize a new SimCLRTask instance. + + Args: + model: Name of the timm model to use. + in_channels: Number of input channels to model. + version: Version of SimCLR, 1--2. + hidden_dim: Number of hidden dimensions. + lr: Learning rate (0.3 x batch_size / 256 is recommended). + weight_decay: Weight decay coefficient. + max_epochs: Maximum number of epochs to train for. + temperature: Temperature used in InfoNCE loss. + """ + super().__init__() + + assert version in range(2) + + self.save_hyperparameters() + + self.model = timm.create_model(model, in_chans=in_channels) + + # Add projection head + self.model.fc = nn.Sequential( + self.model.fc, nn.ReLU(inplace=True), nn.Linear(4 * hidden_dim, hidden_dim) + ) + + # Data augmentation + self.aug = AugmentationSequential( + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + K.RandomResizedCrop(size=96), + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter( + # brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8 + # ) + # K.RandomGrayscale(p=0.2), + K.RandomGaussianBlur(kernel_size=9), + data_keys=["image"], + ) + + def forward(self, batch: Tensor) -> Tensor: + """Forward pass of the model. + + Args: + batch: Mini-batch of images. + + Returns: + Output from the model. + """ + batch = self.model(batch) + return batch + + def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor: + """Compute the training loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + + Returns: + The loss tensor. + """ + x = batch["image"] + + in_channels = self.hparams["in_channels"] + assert x.size(1) == in_channels or x.size(1) == 2 * in_channels + + if x.size(1) == in_channels: + x1 = x + x2 = x + else: + x1 = x[:, :in_channels] + x2 = x[:, in_channels:] + + # Apply augmentations independently for each season + x1 = self.aug(x1) + x2 = self.aug(x2) + + x = torch.cat([x1, x2], dim=0) + + # Encode all images + x = self(x) + + # Calculate cosine similarity + cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1) + + # Mask out cosine similarity to itself + self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device) + cos_sim.masked_fill_(self_mask, -9e15) + + # Find positive example -> batch_size // 2 away from the original example + pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0) + + # NT-Xent loss (aka InfoNCE loss) + cos_sim = cos_sim / self.hparams["temperature"] + nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1) + nll = nll.mean() + + self.log("train_loss", nll) + + return nll + + def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + Optimizer and learning rate scheduler. + """ + # Original paper uses LARS optimizer, but this is not defined in PyTorch + optimizer = AdamW( + self.parameters(), + lr=self.hparams["lr"], + weight_decay=self.hparams["weight_decay"], + ) + lr_scheduler = CosineAnnealingLR( + optimizer, T_max=self.hparams["max_epochs"], eta_min=self.hparams["lr"] / 50 + ) + return [optimizer], [lr_scheduler] From 38d20fb3567bd377bebe6fa3654cc41ca1f02347 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Mar 2023 13:41:01 -0500 Subject: [PATCH 02/10] Add tests --- tests/trainers/test_simclr.py | 58 +++++++++++++++++++++++++++++++++++ torchgeo/trainers/__init__.py | 2 ++ torchgeo/trainers/simclr.py | 10 ++++++ 3 files changed, 70 insertions(+) create mode 100644 tests/trainers/test_simclr.py diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py new file mode 100644 index 00000000000..54b46f75c1f --- /dev/null +++ b/tests/trainers/test_simclr.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Type, cast + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from lightning.pytorch import LightningDataModule, Trainer +from omegaconf import OmegaConf + +from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule +from torchgeo.datasets import SeasonalContrastS2 +from torchgeo.trainers import SimCLRTask + +from .test_classification import ClassificationTestModel + + +class TestSimCLRTask: + @pytest.mark.parametrize( + "name,classname", + [ + ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), + ("seco_1", SeasonalContrastS2DataModule), + ("seco_2", SeasonalContrastS2DataModule), + ], + ) + def test_trainer( + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, + ) -> None: + conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) + + if name.startswith("seco"): + monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = SimCLRTask(**model_kwargs) + model.model = ClassificationTestModel() + + # Instantiate trainer + trainer = Trainer( + accelerator="cpu", + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.fit(model=model, datamodule=datamodule) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 6240b53f681..f2c5b7b25d7 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -8,6 +8,7 @@ from .detection import ObjectDetectionTask from .regression import RegressionTask from .segmentation import SemanticSegmentationTask +from .simclr import SimCLRTask __all__ = ( "BYOLTask", @@ -16,4 +17,5 @@ "ObjectDetectionTask", "RegressionTask", "SemanticSegmentationTask", + "SimCLRTask", ) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 946213604fc..a922e4ba15f 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -146,6 +146,16 @@ def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor: return nll + def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + + def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + # TODO: add distillation step + + def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: """Initialize the optimizer and learning rate scheduler. From 5e9592d751ded9a14d0c861b29065cb6904b6581 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 24 Mar 2023 13:36:01 -0500 Subject: [PATCH 03/10] Support custom number of MLP layers --- torchgeo/trainers/simclr.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index a922e4ba15f..6316b45f237 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -39,6 +39,7 @@ def __init__( model: str = "resnet50", in_channels: int = 3, version: int = 2, + layers: int = 3, hidden_dim: int = 128, lr: float = 4.8, weight_decay: float = 1e-6, @@ -51,7 +52,8 @@ def __init__( model: Name of the timm model to use. in_channels: Number of input channels to model. version: Version of SimCLR, 1--2. - hidden_dim: Number of hidden dimensions. + layers: Number of layers in projection head. + hidden_dim: Number of hidden dimensions in projection head. lr: Learning rate (0.3 x batch_size / 256 is recommended). weight_decay: Weight decay coefficient. max_epochs: Maximum number of epochs to train for. @@ -66,9 +68,20 @@ def __init__( self.model = timm.create_model(model, in_chans=in_channels) # Add projection head - self.model.fc = nn.Sequential( - self.model.fc, nn.ReLU(inplace=True), nn.Linear(4 * hidden_dim, hidden_dim) - ) + # https://github.com/google-research/simclr/blob/2fc637bdd6a723130db91b377ac15151e01e4fc2/model_util.py#L141 # noqa: E501 + for i in range(layers): + if i == layers - 1: + # For the final layer, skip bias and ReLU + self.model.fc = nn.Sequential( + self.model.fc, nn.Linear(hidden_dim, hidden_dim, bias=False) + ) + else: + # For the middle layers, use bias and ReLU + self.model.fc = nn.Sequential( + self.model.fc, + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim, bias=True), + ) # Data augmentation self.aug = AugmentationSequential( From 853f7edc65ceb6ee29b10eb69f6e6f86fc81fd85 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 12:30:04 -0500 Subject: [PATCH 04/10] Change default params, add TODOs --- torchgeo/trainers/simclr.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 6316b45f237..baebbd85f3a 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -42,7 +42,7 @@ def __init__( layers: int = 3, hidden_dim: int = 128, lr: float = 4.8, - weight_decay: float = 1e-6, + weight_decay: float = 1e-4, max_epochs: int = 100, temperature: float = 0.07, ) -> None: @@ -52,23 +52,23 @@ def __init__( model: Name of the timm model to use. in_channels: Number of input channels to model. version: Version of SimCLR, 1--2. - layers: Number of layers in projection head. + layers: Number of layers in projection head (2 for v1 or 3+ for v2). hidden_dim: Number of hidden dimensions in projection head. lr: Learning rate (0.3 x batch_size / 256 is recommended). - weight_decay: Weight decay coefficient. + weight_decay: Weight decay coefficient (1e-6 for v1 or 1e-4 for v2). max_epochs: Maximum number of epochs to train for. temperature: Temperature used in InfoNCE loss. """ super().__init__() - assert version in range(2) + assert version in range(1, 3) self.save_hyperparameters() self.model = timm.create_model(model, in_chans=in_channels) - # Add projection head - # https://github.com/google-research/simclr/blob/2fc637bdd6a723130db91b377ac15151e01e4fc2/model_util.py#L141 # noqa: E501 + # Projection head + # https://github.com/google-research/simclr/blob/master/model_util.py#L141 for i in range(layers): if i == layers - 1: # For the final layer, skip bias and ReLU @@ -83,11 +83,16 @@ def __init__( nn.Linear(hidden_dim, hidden_dim, bias=True), ) + # TODO + # v1+: add global batch norm + # v2: add selective kernels, channel-wise attention mechanism, memory bank + # Data augmentation + # https://github.com/google-research/simclr/blob/master/data_util.py self.aug = AugmentationSequential( + K.RandomResizedCrop(size=96), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added - K.RandomResizedCrop(size=96), # Not appropriate for multispectral imagery, seasonal contrast used instead # K.ColorJitter( # brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8 @@ -164,7 +169,8 @@ def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: """No-op, does nothing.""" - # TODO: add distillation step + # TODO + # v2: add distillation step def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: """No-op, does nothing.""" From 26f79eb78e2500e1050043b1a9323f77b1c8f0c1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 12:45:57 -0500 Subject: [PATCH 05/10] Fix mypy --- torchgeo/trainers/simclr.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index baebbd85f3a..aa4fa417464 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -13,7 +13,11 @@ from lightning.pytorch import LightningModule from torch import Tensor from torch.optim import AdamW, Optimizer -from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler +from torch.optim.lr_scheduler import CosineAnnealingLR +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from ..transforms import AugmentationSequential @@ -175,7 +179,7 @@ def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None: """No-op, does nothing.""" - def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: + def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]: """Initialize the optimizer and learning rate scheduler. Returns: From ac3d4ae2992e827209d32e91faff2e9bf8c529c1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 13:21:28 -0500 Subject: [PATCH 06/10] Fix docs and most of tests --- docs/conf.py | 1 + ...r.yaml => chesapeake_cvpr_prior_byol.yaml} | 4 ++-- .../conf/chesapeake_cvpr_prior_simclr_1.yaml | 23 +++++++++++++++++++ .../conf/chesapeake_cvpr_prior_simclr_2.yaml | 23 +++++++++++++++++++ tests/conf/{seco_1.yaml => seco_byol_1.yaml} | 0 tests/conf/{seco_2.yaml => seco_byol_2.yaml} | 0 tests/conf/seco_simclr_1.yaml | 14 +++++++++++ tests/conf/seco_simclr_2.yaml | 14 +++++++++++ tests/trainers/test_byol.py | 6 ++--- tests/trainers/test_simclr.py | 7 +++--- torchgeo/trainers/simclr.py | 5 ++-- 11 files changed, 87 insertions(+), 10 deletions(-) rename tests/conf/{chesapeake_cvpr_prior.yaml => chesapeake_cvpr_prior_byol.yaml} (92%) create mode 100644 tests/conf/chesapeake_cvpr_prior_simclr_1.yaml create mode 100644 tests/conf/chesapeake_cvpr_prior_simclr_2.yaml rename tests/conf/{seco_1.yaml => seco_byol_1.yaml} (100%) rename tests/conf/{seco_2.yaml => seco_byol_2.yaml} (100%) create mode 100644 tests/conf/seco_simclr_1.yaml create mode 100644 tests/conf/seco_simclr_2.yaml diff --git a/docs/conf.py b/docs/conf.py index c0c13b5f5bb..46e4365195e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,6 +62,7 @@ ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), ("py:class", "timm.models.resnet.ResNet"), ("py:class", "timm.models.vision_transformer.VisionTransformer"), + ("py:class", "torch.optim.lr_scheduler.LRScheduler"), ("py:class", "torchvision.models._api.WeightsEnum"), ("py:class", "torchvision.models.resnet.ResNet"), ] diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml similarity index 92% rename from tests/conf/chesapeake_cvpr_prior.yaml rename to tests/conf/chesapeake_cvpr_prior_byol.yaml index 3e9713fbb59..a5947035137 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml @@ -3,7 +3,7 @@ experiment: module: loss: "ce" model: "unet" - backbone: "resnet50" + backbone: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 4 @@ -13,7 +13,7 @@ experiment: weights: null datamodule: root: "tests/data/chesapeake/cvpr" - download: true + download: false train_splits: - "de-test" val_splits: diff --git a/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml b/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml new file mode 100644 index 00000000000..b6876c5a89f --- /dev/null +++ b/tests/conf/chesapeake_cvpr_prior_simclr_1.yaml @@ -0,0 +1,23 @@ +experiment: + task: "chesapeake_cvpr" + module: + model: "resnet18" + in_channels: 4 + version: 1 + layers: 2 + weight_decay: 1e-6 + max_epochs: 1 + datamodule: + root: "tests/data/chesapeake/cvpr" + download: false + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: 5 + use_prior_labels: True diff --git a/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml b/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml new file mode 100644 index 00000000000..9d82787a772 --- /dev/null +++ b/tests/conf/chesapeake_cvpr_prior_simclr_2.yaml @@ -0,0 +1,23 @@ +experiment: + task: "chesapeake_cvpr" + module: + model: "resnet18" + in_channels: 4 + version: 2 + layers: 3 + weight_decay: 1e-4 + max_epochs: 1 + datamodule: + root: "tests/data/chesapeake/cvpr" + download: false + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: 5 + use_prior_labels: True diff --git a/tests/conf/seco_1.yaml b/tests/conf/seco_byol_1.yaml similarity index 100% rename from tests/conf/seco_1.yaml rename to tests/conf/seco_byol_1.yaml diff --git a/tests/conf/seco_2.yaml b/tests/conf/seco_byol_2.yaml similarity index 100% rename from tests/conf/seco_2.yaml rename to tests/conf/seco_byol_2.yaml diff --git a/tests/conf/seco_simclr_1.yaml b/tests/conf/seco_simclr_1.yaml new file mode 100644 index 00000000000..8e4563398eb --- /dev/null +++ b/tests/conf/seco_simclr_1.yaml @@ -0,0 +1,14 @@ +experiment: + task: "seco" + module: + model: "resnet18" + in_channels: 3 + version: 1 + layers: 2 + weight_decay: 1e-6 + max_epochs: 1 + datamodule: + root: "tests/data/seco" + seasons: 1 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/seco_simclr_2.yaml b/tests/conf/seco_simclr_2.yaml new file mode 100644 index 00000000000..d72940c3828 --- /dev/null +++ b/tests/conf/seco_simclr_2.yaml @@ -0,0 +1,14 @@ +experiment: + task: "seco" + module: + model: "resnet18" + in_channels: 3 + version: 2 + layers: 3 + weight_decay: 1e-4 + max_epochs: 1 + datamodule: + root: "tests/data/seco" + seasons: 2 + batch_size: 2 + num_workers: 0 diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index e60b0d46014..e0f1cb0c69a 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -52,9 +52,9 @@ class TestBYOLTask: @pytest.mark.parametrize( "name,classname", [ - ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), - ("seco_1", SeasonalContrastS2DataModule), - ("seco_2", SeasonalContrastS2DataModule), + ("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule), + ("seco_byol_1", SeasonalContrastS2DataModule), + ("seco_byol_2", SeasonalContrastS2DataModule), ], ) def test_trainer( diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index 54b46f75c1f..f77a871f7fa 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -20,9 +20,10 @@ class TestSimCLRTask: @pytest.mark.parametrize( "name,classname", [ - ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), - ("seco_1", SeasonalContrastS2DataModule), - ("seco_2", SeasonalContrastS2DataModule), + ("chesapeake_cvpr_prior_simclr_1", ChesapeakeCVPRDataModule), + ("chesapeake_cvpr_prior_simclr_2", ChesapeakeCVPRDataModule), + ("seco_simclr_1", SeasonalContrastS2DataModule), + ("seco_simclr_2", SeasonalContrastS2DataModule), ], ) def test_trainer( diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index aa4fa417464..37c5f66f052 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -14,6 +14,7 @@ from torch import Tensor from torch.optim import AdamW, Optimizer from torch.optim.lr_scheduler import CosineAnnealingLR + try: from torch.optim.lr_scheduler import LRScheduler except ImportError: @@ -94,7 +95,7 @@ def __init__( # Data augmentation # https://github.com/google-research/simclr/blob/master/data_util.py self.aug = AugmentationSequential( - K.RandomResizedCrop(size=96), + K.RandomResizedCrop(size=(96, 96)), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added # Not appropriate for multispectral imagery, seasonal contrast used instead @@ -102,7 +103,7 @@ def __init__( # brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8 # ) # K.RandomGrayscale(p=0.2), - K.RandomGaussianBlur(kernel_size=9), + K.RandomGaussianBlur(kernel_size=9, sigma=(0.1, 2)), data_keys=["image"], ) From d482a0ea74cc5e8f7080b853705b0fea7f79ade7 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 16:05:07 -0500 Subject: [PATCH 07/10] Fix all tests --- tests/trainers/test_simclr.py | 8 +++++++- torchgeo/trainers/simclr.py | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index f77a871f7fa..ce6e90791f8 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -5,9 +5,11 @@ from typing import Any, Dict, Type, cast import pytest +import timm from _pytest.monkeypatch import MonkeyPatch from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf +from torch.nn import Module from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule from torchgeo.datasets import SeasonalContrastS2 @@ -16,6 +18,10 @@ from .test_classification import ClassificationTestModel +def create_model(*args: Any, **kwargs: Any) -> Module: + return ClassificationTestModel(**kwargs) + + class TestSimCLRTask: @pytest.mark.parametrize( "name,classname", @@ -45,9 +51,9 @@ def test_trainer( datamodule = classname(**datamodule_kwargs) # Instantiate model + monkeypatch.setattr(timm, "create_model", create_model) model_kwargs = conf_dict["module"] model = SimCLRTask(**model_kwargs) - model.model = ClassificationTestModel() # Instantiate trainer trainer = Trainer( diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 37c5f66f052..a03b4590d5e 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -20,8 +20,6 @@ except ImportError: from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from ..transforms import AugmentationSequential - class SimCLRTask(LightningModule): # type: ignore[misc] """SimCLR: a simple framework for contrastive learning of visual representations. @@ -70,7 +68,9 @@ def __init__( self.save_hyperparameters() - self.model = timm.create_model(model, in_chans=in_channels) + self.model = timm.create_model( + model, in_chans=in_channels, num_classes=hidden_dim + ) # Projection head # https://github.com/google-research/simclr/blob/master/model_util.py#L141 @@ -94,7 +94,7 @@ def __init__( # Data augmentation # https://github.com/google-research/simclr/blob/master/data_util.py - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.RandomResizedCrop(size=(96, 96)), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added From 4c049c080a8b49c566d331a390c1ee55518d4a76 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 16:13:34 -0500 Subject: [PATCH 08/10] Fix support for older Kornia versions --- torchgeo/trainers/simclr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index a03b4590d5e..7972d83c362 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -104,7 +104,7 @@ def __init__( # ) # K.RandomGrayscale(p=0.2), K.RandomGaussianBlur(kernel_size=9, sigma=(0.1, 2)), - data_keys=["image"], + data_keys=["input"], ) def forward(self, batch: Tensor) -> Tensor: From c404e3e5f1506add0dffdf710a1977f67ad19d11 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Mar 2023 20:29:52 -0500 Subject: [PATCH 09/10] Fix support for older Kornia versions --- torchgeo/trainers/simclr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 7972d83c362..575695134e3 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -103,7 +103,7 @@ def __init__( # brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8 # ) # K.RandomGrayscale(p=0.2), - K.RandomGaussianBlur(kernel_size=9, sigma=(0.1, 2)), + K.RandomGaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2)), data_keys=["input"], ) From 41a10d2efa7fb572f90a40b905a66c96f2415474 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 27 Mar 2023 10:24:08 -0500 Subject: [PATCH 10/10] Crop should be 224, not 96 --- torchgeo/trainers/simclr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 575695134e3..88ac657af6c 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -95,7 +95,7 @@ def __init__( # Data augmentation # https://github.com/google-research/simclr/blob/master/data_util.py self.aug = K.AugmentationSequential( - K.RandomResizedCrop(size=(96, 96)), + K.RandomResizedCrop(size=(224, 224)), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added # Not appropriate for multispectral imagery, seasonal contrast used instead