Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimCLR: add new trainer #1195

Merged
merged 10 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
("py:class", "timm.models.resnet.ResNet"),
("py:class", "timm.models.vision_transformer.VisionTransformer"),
("py:class", "torch.optim.lr_scheduler.LRScheduler"),
("py:class", "torchvision.models._api.WeightsEnum"),
("py:class", "torchvision.models.resnet.ResNet"),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ experiment:
module:
loss: "ce"
model: "unet"
backbone: "resnet50"
backbone: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 4
Expand All @@ -13,7 +13,7 @@ experiment:
weights: null
datamodule:
root: "tests/data/chesapeake/cvpr"
download: true
download: false
train_splits:
- "de-test"
val_splits:
Expand Down
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
experiment:
task: "chesapeake_cvpr"
module:
model: "resnet18"
in_channels: 4
version: 1
layers: 2
weight_decay: 1e-6
max_epochs: 1
datamodule:
root: "tests/data/chesapeake/cvpr"
download: false
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 2
patch_size: 64
num_workers: 0
class_set: 5
use_prior_labels: True
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
experiment:
task: "chesapeake_cvpr"
module:
model: "resnet18"
in_channels: 4
version: 2
layers: 3
weight_decay: 1e-4
max_epochs: 1
datamodule:
root: "tests/data/chesapeake/cvpr"
download: false
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 2
patch_size: 64
num_workers: 0
class_set: 5
use_prior_labels: True
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/conf/seco_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
experiment:
task: "seco"
module:
model: "resnet18"
in_channels: 3
version: 1
layers: 2
weight_decay: 1e-6
max_epochs: 1
datamodule:
root: "tests/data/seco"
seasons: 1
batch_size: 2
num_workers: 0
14 changes: 14 additions & 0 deletions tests/conf/seco_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
experiment:
task: "seco"
module:
model: "resnet18"
in_channels: 3
version: 2
layers: 3
weight_decay: 1e-4
max_epochs: 1
datamodule:
root: "tests/data/seco"
seasons: 2
batch_size: 2
num_workers: 0
6 changes: 3 additions & 3 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ class TestBYOLTask:
@pytest.mark.parametrize(
"name,classname",
[
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
("seco_1", SeasonalContrastS2DataModule),
("seco_2", SeasonalContrastS2DataModule),
("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule),
("seco_byol_1", SeasonalContrastS2DataModule),
("seco_byol_2", SeasonalContrastS2DataModule),
],
)
def test_trainer(
Expand Down
65 changes: 65 additions & 0 deletions tests/trainers/test_simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from typing import Any, Dict, Type, cast

import pytest
import timm
from _pytest.monkeypatch import MonkeyPatch
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn import Module

from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.trainers import SimCLRTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
    """Build a lightweight test model in place of a real timm model.

    Positional arguments are accepted but ignored; only the keyword
    arguments are forwarded to ``ClassificationTestModel``.

    Args:
        *args: Ignored.
        **kwargs: Keyword arguments passed through to the test model.

    Returns:
        A ``ClassificationTestModel`` instance.
    """
    test_model = ClassificationTestModel(**kwargs)
    return test_model


class TestSimCLRTask:
    """Smoke tests that fit a ``SimCLRTask`` with a Lightning ``Trainer`` on CPU."""

    @pytest.mark.parametrize(
        "name,classname",
        [
            ("chesapeake_cvpr_prior_simclr_1", ChesapeakeCVPRDataModule),
            ("chesapeake_cvpr_prior_simclr_2", ChesapeakeCVPRDataModule),
            ("seco_simclr_1", SeasonalContrastS2DataModule),
            ("seco_simclr_2", SeasonalContrastS2DataModule),
        ],
    )
    def test_trainer(
        self,
        monkeypatch: MonkeyPatch,
        name: str,
        classname: Type[LightningDataModule],
        fast_dev_run: bool,
    ) -> None:
        """Fit the task once using the experiment config named ``name``.

        Args:
            monkeypatch: pytest fixture used to patch the dataset and timm.
            name: Stem of the YAML file under ``tests/conf``.
            classname: Datamodule class to instantiate from the config.
            fast_dev_run: Forwarded unchanged to ``Trainer(fast_dev_run=...)``
                (presumably a fixture defined outside this file).
        """
        # Load the "experiment" section of the config as plain Python dicts
        config_path = os.path.join("tests", "conf", f"{name}.yaml")
        experiment = cast(
            Dict[str, Dict[str, Any]],
            OmegaConf.to_object(OmegaConf.load(config_path).experiment),
        )

        # For the seco configs, make the dataset report a length of 2
        if name.startswith("seco"):
            monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

        # Datamodule built from the "datamodule" section of the config
        datamodule = classname(**experiment["datamodule"])

        # Swap timm's factory for the local stub before building the task
        monkeypatch.setattr(timm, "create_model", create_model)
        model = SimCLRTask(**experiment["module"])

        # Single-epoch CPU trainer
        Trainer(
            accelerator="cpu",
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        ).fit(model=model, datamodule=datamodule)
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .detection import ObjectDetectionTask
from .regression import RegressionTask
from .segmentation import SemanticSegmentationTask
from .simclr import SimCLRTask

__all__ = (
"BYOLTask",
Expand All @@ -16,4 +17,5 @@
"ObjectDetectionTask",
"RegressionTask",
"SemanticSegmentationTask",
"SimCLRTask",
)
198 changes: 198 additions & 0 deletions torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SimCLR trainers for self-supervised learning (SSL)."""

from typing import Dict, List, Tuple

import kornia.augmentation as K
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR

try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


class SimCLRTask(LightningModule): # type: ignore[misc]
"""SimCLR: a simple framework for contrastive learning of visual representations.

Implementation based on:

* https://github.com/google-research/simclr
* https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/13-contrastive-learning.html

If you use this trainer in your research, please cite the following papers:

* v1: https://arxiv.org/abs/2002.05709
* v2: https://arxiv.org/abs/2006.10029

.. versionadded:: 0.5
""" # noqa: E501

def __init__(
self,
model: str = "resnet50",
in_channels: int = 3,
version: int = 2,
layers: int = 3,
hidden_dim: int = 128,
lr: float = 4.8,
weight_decay: float = 1e-4,
max_epochs: int = 100,
temperature: float = 0.07,
) -> None:
"""Initialize a new SimCLRTask instance.

Args:
model: Name of the timm model to use.
in_channels: Number of input channels to model.
version: Version of SimCLR, 1--2.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't really used at the moment since layers and weight_decay are also parameters, but it could be used to control other things in the future (see TODOs).

layers: Number of layers in projection head (2 for v1 or 3+ for v2).
hidden_dim: Number of hidden dimensions in projection head.
lr: Learning rate (0.3 x batch_size / 256 is recommended).
weight_decay: Weight decay coefficient (1e-6 for v1 or 1e-4 for v2).
max_epochs: Maximum number of epochs to train for.
temperature: Temperature used in InfoNCE loss.
"""
super().__init__()

assert version in range(1, 3)

self.save_hyperparameters()

self.model = timm.create_model(
model, in_chans=in_channels, num_classes=hidden_dim
)

# Projection head
# https://github.com/google-research/simclr/blob/master/model_util.py#L141
for i in range(layers):
if i == layers - 1:
# For the final layer, skip bias and ReLU
self.model.fc = nn.Sequential(
self.model.fc, nn.Linear(hidden_dim, hidden_dim, bias=False)
)
else:
# For the middle layers, use bias and ReLU
self.model.fc = nn.Sequential(
self.model.fc,
nn.ReLU(inplace=True),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is ReLU the desired/required activation function choice here or should that be more flexible?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just what the original paper used. Depends on how much customization we want to support.

nn.Linear(hidden_dim, hidden_dim, bias=True),
)

# TODO
# v1+: add global batch norm
# v2: add selective kernels, channel-wise attention mechanism, memory bank
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure exactly how to make these changes, and I don't really want to change the architecture too much to ensure that our pre-trained weights can be loaded in a vanilla model. The memory bank only adds +1% performance, so I don't really think it's worth the complexity.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that when most papers use SimCLR, they use v1 without all the tricks that get the ~1% improvement. I think it would be better to keep it simple.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The majority of the performance bump in v2 is thanks to the deeper projection head, which we have, so we should be good.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, in terms of performance bump:

  • Bigger ResNets, SK, channel-wise attention: +29%
  • Deeper projection head: +14%
  • Memory bank: +1%

So memory bank isn't high on my priority list, but adding SK and channel-wise attention may be worth it.


# Data augmentation
# https://github.com/google-research/simclr/blob/master/data_util.py
self.aug = K.AugmentationSequential(
K.RandomResizedCrop(size=(96, 96)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be hardcoding this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hardcoded in BYOL (should actually be 224, not 96, let me fix this). We can make it a parameter if you want, but at the moment I don't know if we need it to be.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with fixing it for now but it's only set to 224 because that's what imagenet experiments use. It's probably better to not restrict to 224 in case we use higher res imagery.

K.RandomHorizontalFlip(),
K.RandomVerticalFlip(), # added
# Not appropriate for multispectral imagery, seasonal contrast used instead
# K.ColorJitter(
# brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
# )
# K.RandomGrayscale(p=0.2),
K.RandomGaussianBlur(kernel_size=9, sigma=(0.1, 2)),
data_keys=["image"],
)

def forward(self, batch: Tensor) -> Tensor:
    """Run a mini-batch through the wrapped model.

    Args:
        batch: Mini-batch of images.

    Returns:
        The result of calling ``self.model`` on ``batch``.
    """
    return self.model(batch)

def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
"""Compute the training loss and additional metrics.

Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.

Returns:
The loss tensor.
"""
x = batch["image"]

in_channels = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels

if x.size(1) == in_channels:
x1 = x
x2 = x
else:
x1 = x[:, :in_channels]
x2 = x[:, in_channels:]

# Apply augmentations independently for each season
x1 = self.aug(x1)
x2 = self.aug(x2)

x = torch.cat([x1, x2], dim=0)

# Encode all images
x = self(x)

# Calculate cosine similarity
cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1)

# Mask out cosine similarity to itself
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make NT-Xent loss its own nn.Module so we can test it and reuse it. Maybe in a future PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found several repos with their own InfoNCE loss implementation but they all implement it differently and I don't know the math well enough to decide which is best. The implementation here assumes that there is exactly 1 positive pair and everything else is a negative pair. A more general implementation, or a faster implementation, is a lot more work to get right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the implementation is fine. I was suggesting that we make it a separate module since other SSL methods use it as well. But until we have another SSL method that uses it, I think it's fine to leave as is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I'm about to add MoCo which also uses it, although their implementation is completely different, and I have no idea what the difference is.

cos_sim.masked_fill_(self_mask, -9e15)

# Find positive example -> batch_size // 2 away from the original example
pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

# NT-Xent loss (aka InfoNCE loss)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both SimCLR and MoCo use InfoNCE loss, but there is no implementation in PyTorch. There are many libraries that implement it, but I'd rather not add yet another dependency.

cos_sim = cos_sim / self.hparams["temperature"]
nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
nll = nll.mean()

self.log("train_loss", nll)

return nll

def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """Intentionally do nothing for validation batches.

    Args:
        batch: The output of your DataLoader (unused).
        batch_idx: Integer displaying index of this batch (unused).
    """
    return None

def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
"""No-op, does nothing."""
# TODO
# v2: add distillation step
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually be very useful to add someday. Both using a large model to better train a small model, and self-distillation, have been found to greatly improve performance. I didn't do this because I'm not super familiar with teacher-student distillation methods.


def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """Intentionally do nothing for prediction batches.

    Args:
        batch: The output of your DataLoader (unused).
        batch_idx: Integer displaying index of this batch (unused).
    """
    return None

def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]:
"""Initialize the optimizer and learning rate scheduler.

Returns:
Optimizer and learning rate scheduler.
"""
# Original paper uses LARS optimizer, but this is not defined in PyTorch
optimizer = AdamW(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the optimizer choice also be user-definable, as different model architectures work better with certain optimizers? Or would you expect/want a user to overwrite the configure_optimizers method in their inherited trainer class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our BYOL trainer supports specifying an optimizer, but none of the other trainers do. For now, I'm just using the optimizer used in the original paper. The only difficulty with making it user configurable is that each optimizer has different arguments. We could add a **kwargs that is used in the optimizer to handle this, but then we can't use it anywhere else (without a bit of hacking like we did in NAIPChesapeakeDataModule).

self.parameters(),
lr=self.hparams["lr"],
weight_decay=self.hparams["weight_decay"],
)
lr_scheduler = CosineAnnealingLR(
optimizer, T_max=self.hparams["max_epochs"], eta_min=self.hparams["lr"] / 50
)
return [optimizer], [lr_scheduler]