SimCLR: add new trainer (#1195)
* SimCLR: add new trainer

* Add tests

* Support custom number of MLP layers

* Change default params, add TODOs

* Fix mypy

* Fix docs and most of tests

* Fix all tests

* Fix support for older Kornia versions

* Fix support for older Kornia versions

* Crop should be 224, not 96
adamjstewart authored Mar 29, 2023
1 parent 0dd3d06 commit 39d6941
Showing 12 changed files with 345 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
@@ -62,6 +62,7 @@
     ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
     ("py:class", "timm.models.resnet.ResNet"),
     ("py:class", "timm.models.vision_transformer.VisionTransformer"),
+    ("py:class", "torch.optim.lr_scheduler.LRScheduler"),
     ("py:class", "torchvision.models._api.WeightsEnum"),
     ("py:class", "torchvision.models.resnet.ResNet"),
 ]
@@ -3,7 +3,7 @@ experiment:
   module:
     loss: "ce"
     model: "unet"
-    backbone: "resnet50"
+    backbone: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     in_channels: 4
@@ -13,7 +13,7 @@ experiment:
     weights: null
   datamodule:
     root: "tests/data/chesapeake/cvpr"
-    download: true
+    download: false
     train_splits:
       - "de-test"
     val_splits:
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr_1.yaml
@@ -0,0 +1,23 @@
experiment:
  task: "chesapeake_cvpr"
  module:
    model: "resnet18"
    in_channels: 4
    version: 1
    layers: 2
    weight_decay: 1e-6
    max_epochs: 1
  datamodule:
    root: "tests/data/chesapeake/cvpr"
    download: false
    train_splits:
      - "de-test"
    val_splits:
      - "de-test"
    test_splits:
      - "de-test"
    batch_size: 2
    patch_size: 64
    num_workers: 0
    class_set: 5
    use_prior_labels: True
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr_2.yaml
@@ -0,0 +1,23 @@
experiment:
  task: "chesapeake_cvpr"
  module:
    model: "resnet18"
    in_channels: 4
    version: 2
    layers: 3
    weight_decay: 1e-4
    max_epochs: 1
  datamodule:
    root: "tests/data/chesapeake/cvpr"
    download: false
    train_splits:
      - "de-test"
    val_splits:
      - "de-test"
    test_splits:
      - "de-test"
    batch_size: 2
    patch_size: 64
    num_workers: 0
    class_set: 5
    use_prior_labels: True
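For reference, a minimal sketch (not part of this commit) of how these experiment configs map onto the new trainer, mirroring the test added below: the module block becomes SimCLRTask keyword arguments, while the datamodule block configures the Lightning datamodule.

from omegaconf import OmegaConf

from torchgeo.trainers import SimCLRTask

conf = OmegaConf.load("tests/conf/chesapeake_cvpr_prior_simclr_1.yaml")
conf_dict = OmegaConf.to_object(conf.experiment)
task = SimCLRTask(**conf_dict["module"])  # model="resnet18", version=1, layers=2, ...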
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/conf/seco_simclr_1.yaml
@@ -0,0 +1,14 @@
experiment:
  task: "seco"
  module:
    model: "resnet18"
    in_channels: 3
    version: 1
    layers: 2
    weight_decay: 1e-6
    max_epochs: 1
  datamodule:
    root: "tests/data/seco"
    seasons: 1
    batch_size: 2
    num_workers: 0
14 changes: 14 additions & 0 deletions tests/conf/seco_simclr_2.yaml
@@ -0,0 +1,14 @@
experiment:
  task: "seco"
  module:
    model: "resnet18"
    in_channels: 3
    version: 2
    layers: 3
    weight_decay: 1e-4
    max_epochs: 1
  datamodule:
    root: "tests/data/seco"
    seasons: 2
    batch_size: 2
    num_workers: 0
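A note on the SeCo configs: with seasons: 2, each sample appears to stack two seasonal patches along the channel dimension, and the trainer splits them into the two contrastive views (see training_step in torchgeo/trainers/simclr.py below). A shape sketch, assuming in_channels=3:

x = batch["image"]             # (B, 6, H, W): two 3-channel seasonal patches stacked
x1, x2 = x[:, :3], x[:, 3:]    # the two views of the same location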
6 changes: 3 additions & 3 deletions tests/trainers/test_byol.py
@@ -52,9 +52,9 @@ class TestBYOLTask:
     @pytest.mark.parametrize(
         "name,classname",
         [
-            ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
-            ("seco_1", SeasonalContrastS2DataModule),
-            ("seco_2", SeasonalContrastS2DataModule),
+            ("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule),
+            ("seco_byol_1", SeasonalContrastS2DataModule),
+            ("seco_byol_2", SeasonalContrastS2DataModule),
         ],
     )
     def test_trainer(
65 changes: 65 additions & 0 deletions tests/trainers/test_simclr.py
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from typing import Any, Dict, Type, cast

import pytest
import timm
from _pytest.monkeypatch import MonkeyPatch
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn import Module

from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.trainers import SimCLRTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
    return ClassificationTestModel(**kwargs)


class TestSimCLRTask:
    @pytest.mark.parametrize(
        "name,classname",
        [
            ("chesapeake_cvpr_prior_simclr_1", ChesapeakeCVPRDataModule),
            ("chesapeake_cvpr_prior_simclr_2", ChesapeakeCVPRDataModule),
            ("seco_simclr_1", SeasonalContrastS2DataModule),
            ("seco_simclr_2", SeasonalContrastS2DataModule),
        ],
    )
    def test_trainer(
        self,
        monkeypatch: MonkeyPatch,
        name: str,
        classname: Type[LightningDataModule],
        fast_dev_run: bool,
    ) -> None:
        conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
        conf_dict = OmegaConf.to_object(conf.experiment)
        conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)

        if name.startswith("seco"):
            monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

        # Instantiate datamodule
        datamodule_kwargs = conf_dict["datamodule"]
        datamodule = classname(**datamodule_kwargs)

        # Instantiate model
        monkeypatch.setattr(timm, "create_model", create_model)
        model_kwargs = conf_dict["module"]
        model = SimCLRTask(**model_kwargs)

        # Instantiate trainer
        trainer = Trainer(
            accelerator="cpu",
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        )
        trainer.fit(model=model, datamodule=datamodule)
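Assuming the repository's standard pytest setup (the fast_dev_run fixture comes from the suite's conftest), the new tests can be run with:

pytest tests/trainers/test_simclr.py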
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
@@ -8,6 +8,7 @@
 from .detection import ObjectDetectionTask
 from .regression import RegressionTask
 from .segmentation import SemanticSegmentationTask
+from .simclr import SimCLRTask

 __all__ = (
     "BYOLTask",
@@ -16,4 +17,5 @@
     "ObjectDetectionTask",
     "RegressionTask",
     "SemanticSegmentationTask",
+    "SimCLRTask",
 )
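With this export in place, the new trainer is importable from the package root, exactly as the test above does:

from torchgeo.trainers import SimCLRTask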
198 changes: 198 additions & 0 deletions torchgeo/trainers/simclr.py
@@ -0,0 +1,198 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SimCLR trainers for self-supervised learning (SSL)."""

from typing import Dict, List, Tuple

import kornia.augmentation as K
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR

try:
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


class SimCLRTask(LightningModule):  # type: ignore[misc]
    """SimCLR: a simple framework for contrastive learning of visual representations.

    Implementation based on:

    * https://github.com/google-research/simclr
    * https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/13-contrastive-learning.html

    If you use this trainer in your research, please cite the following papers:

    * v1: https://arxiv.org/abs/2002.05709
    * v2: https://arxiv.org/abs/2006.10029

    .. versionadded:: 0.5
    """  # noqa: E501

    def __init__(
        self,
        model: str = "resnet50",
        in_channels: int = 3,
        version: int = 2,
        layers: int = 3,
        hidden_dim: int = 128,
        lr: float = 4.8,
        weight_decay: float = 1e-4,
        max_epochs: int = 100,
        temperature: float = 0.07,
    ) -> None:
        """Initialize a new SimCLRTask instance.

        Args:
            model: Name of the timm model to use.
            in_channels: Number of input channels to model.
            version: Version of SimCLR, 1--2.
            layers: Number of layers in projection head (2 for v1 or 3+ for v2).
            hidden_dim: Number of hidden dimensions in projection head.
            lr: Learning rate (0.3 x batch_size / 256 is recommended).
            weight_decay: Weight decay coefficient (1e-6 for v1 or 1e-4 for v2).
            max_epochs: Maximum number of epochs to train for.
            temperature: Temperature used in InfoNCE loss.
        """
        super().__init__()

        assert version in range(1, 3)

        self.save_hyperparameters()

        self.model = timm.create_model(
            model, in_chans=in_channels, num_classes=hidden_dim
        )

        # Projection head
        # https://github.com/google-research/simclr/blob/master/model_util.py#L141
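        # Each iteration wraps the current head in nn.Sequential, so the MLP
        # grows on top of the backbone's final fc layer one Linear at a time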
        for i in range(layers):
            if i == layers - 1:
                # For the final layer, skip bias and ReLU
                self.model.fc = nn.Sequential(
                    self.model.fc, nn.Linear(hidden_dim, hidden_dim, bias=False)
                )
            else:
                # For the middle layers, use bias and ReLU
                self.model.fc = nn.Sequential(
                    self.model.fc,
                    nn.ReLU(inplace=True),
                    nn.Linear(hidden_dim, hidden_dim, bias=True),
                )

        # TODO
        # v1+: add global batch norm
        # v2: add selective kernels, channel-wise attention mechanism, memory bank

        # Data augmentation
        # https://github.com/google-research/simclr/blob/master/data_util.py
        self.aug = K.AugmentationSequential(
            K.RandomResizedCrop(size=(224, 224)),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
            # )
            # K.RandomGrayscale(p=0.2),
            K.RandomGaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2)),
            data_keys=["input"],
        )

    def forward(self, batch: Tensor) -> Tensor:
        """Forward pass of the model.

        Args:
            batch: Mini-batch of images.

        Returns:
            Output from the model.
        """
        batch = self.model(batch)
        return batch

    def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.

        Returns:
            The loss tensor.
        """
        x = batch["image"]

        in_channels = self.hparams["in_channels"]
        assert x.size(1) == in_channels or x.size(1) == 2 * in_channels

        if x.size(1) == in_channels:
            x1 = x
            x2 = x
        else:
            x1 = x[:, :in_channels]
            x2 = x[:, in_channels:]

        # Apply augmentations independently for each season
        x1 = self.aug(x1)
        x2 = self.aug(x2)

        x = torch.cat([x1, x2], dim=0)
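        # After concatenation, the two views of sample i sit at rows i and
        # i + batch_size, i.e. half the combined batch apart (see mask roll below)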

        # Encode all images
        x = self(x)

        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1)

        # Mask out cosine similarity to itself
        self_mask = torch.eye(
            cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device
        )
        cos_sim.masked_fill_(self_mask, -9e15)

        # Find positive example -> batch_size // 2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

        # NT-Xent loss (aka InfoNCE loss)
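        # For each anchor: -log(exp(s_pos / t) / sum_k exp(s_k / t))
        # = -s_pos / t + logsumexp(s / t), computed directly below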
        cos_sim = cos_sim / self.hparams["temperature"]
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        self.log("train_loss", nll)

        return nll

    def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""

    def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""
        # TODO
        # v2: add distillation step

    def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        # Original paper uses LARS optimizer, but this is not defined in PyTorch
        optimizer = AdamW(
            self.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=self.hparams["max_epochs"], eta_min=self.hparams["lr"] / 50
        )
        return [optimizer], [lr_scheduler]
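For illustration, a minimal end-to-end training sketch using the new task (the data root and hyperparameter values are hypothetical placeholders echoing the test configs above, not anything prescribed by this commit):

from lightning.pytorch import Trainer

from torchgeo.datamodules import SeasonalContrastS2DataModule
from torchgeo.trainers import SimCLRTask

datamodule = SeasonalContrastS2DataModule(
    root="data/seco", seasons=2, batch_size=2, num_workers=0
)
task = SimCLRTask(model="resnet18", in_channels=3, version=2, layers=3)
trainer = Trainer(accelerator="cpu", max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)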
