-
Notifications
You must be signed in to change notification settings - Fork 385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SimCLR: add new trainer #1195
SimCLR: add new trainer #1195
Changes from 7 commits
10f0bde
38d20fb
5e9592d
853f7ed
26f79eb
ac3d4ae
d482a0e
4c049c0
c404e3e
41a10d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
experiment: | ||
task: "chesapeake_cvpr" | ||
module: | ||
model: "resnet18" | ||
in_channels: 4 | ||
version: 1 | ||
layers: 2 | ||
weight_decay: 1e-6 | ||
max_epochs: 1 | ||
datamodule: | ||
root: "tests/data/chesapeake/cvpr" | ||
download: false | ||
train_splits: | ||
- "de-test" | ||
val_splits: | ||
- "de-test" | ||
test_splits: | ||
- "de-test" | ||
batch_size: 2 | ||
patch_size: 64 | ||
num_workers: 0 | ||
class_set: 5 | ||
use_prior_labels: True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
experiment: | ||
task: "chesapeake_cvpr" | ||
module: | ||
model: "resnet18" | ||
in_channels: 4 | ||
version: 2 | ||
layers: 3 | ||
weight_decay: 1e-4 | ||
max_epochs: 1 | ||
datamodule: | ||
root: "tests/data/chesapeake/cvpr" | ||
download: false | ||
train_splits: | ||
- "de-test" | ||
val_splits: | ||
- "de-test" | ||
test_splits: | ||
- "de-test" | ||
batch_size: 2 | ||
patch_size: 64 | ||
num_workers: 0 | ||
class_set: 5 | ||
use_prior_labels: True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
experiment: | ||
task: "seco" | ||
module: | ||
model: "resnet18" | ||
in_channels: 3 | ||
version: 1 | ||
layers: 2 | ||
weight_decay: 1e-6 | ||
max_epochs: 1 | ||
datamodule: | ||
root: "tests/data/seco" | ||
seasons: 1 | ||
batch_size: 2 | ||
num_workers: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
experiment: | ||
task: "seco" | ||
module: | ||
model: "resnet18" | ||
in_channels: 3 | ||
version: 2 | ||
layers: 3 | ||
weight_decay: 1e-4 | ||
max_epochs: 1 | ||
datamodule: | ||
root: "tests/data/seco" | ||
seasons: 2 | ||
batch_size: 2 | ||
num_workers: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from typing import Any, Dict, Type, cast | ||
|
||
import pytest | ||
import timm | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from lightning.pytorch import LightningDataModule, Trainer | ||
from omegaconf import OmegaConf | ||
from torch.nn import Module | ||
|
||
from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule | ||
from torchgeo.datasets import SeasonalContrastS2 | ||
from torchgeo.trainers import SimCLRTask | ||
|
||
from .test_classification import ClassificationTestModel | ||
|
||
|
||
def create_model(*args: Any, **kwargs: Any) -> Module:
    """Build a ``ClassificationTestModel`` from the keyword arguments only.

    Args:
        *args: Accepted for call compatibility but not forwarded.
        **kwargs: Forwarded unchanged to ``ClassificationTestModel``.

    Returns:
        The newly constructed test model.
    """
    test_model: Module = ClassificationTestModel(**kwargs)
    return test_model
|
||
|
||
class TestSimCLRTask:
    """Smoke tests that fit ``SimCLRTask`` on the bundled test configurations."""

    @pytest.mark.parametrize(
        "name,classname",
        [
            ("chesapeake_cvpr_prior_simclr_1", ChesapeakeCVPRDataModule),
            ("chesapeake_cvpr_prior_simclr_2", ChesapeakeCVPRDataModule),
            ("seco_simclr_1", SeasonalContrastS2DataModule),
            ("seco_simclr_2", SeasonalContrastS2DataModule),
        ],
    )
    def test_trainer(
        self,
        monkeypatch: MonkeyPatch,
        name: str,
        classname: Type[LightningDataModule],
        fast_dev_run: bool,
    ) -> None:
        """Fit a ``SimCLRTask`` using the YAML config called ``name``.

        Args:
            monkeypatch: pytest fixture used to patch the dataset and timm.
            name: Stem of the config file under ``tests/conf``.
            classname: Datamodule class instantiated from the config.
            fast_dev_run: Forwarded to ``Trainer(fast_dev_run=...)``.
        """
        config_path = os.path.join("tests", "conf", name + ".yaml")
        experiment = OmegaConf.to_object(OmegaConf.load(config_path).experiment)
        settings = cast(Dict[str, Dict[str, Any]], experiment)

        # SeCo runs: make the dataset report a length of 2.
        if name.startswith("seco"):
            monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

        # Datamodule is built from the "datamodule" section of the config.
        datamodule = classname(**settings["datamodule"])

        # Swap timm's factory for the local stub before the task is created.
        monkeypatch.setattr(timm, "create_model", create_model)
        task = SimCLRTask(**settings["module"])

        runner = Trainer(
            accelerator="cpu",
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        )
        runner.fit(model=task, datamodule=datamodule)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""SimCLR trainers for self-supervised learning (SSL).""" | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
import kornia.augmentation as K | ||
import timm | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from lightning.pytorch import LightningModule | ||
from torch import Tensor | ||
from torch.optim import AdamW, Optimizer | ||
from torch.optim.lr_scheduler import CosineAnnealingLR | ||
|
||
try: | ||
from torch.optim.lr_scheduler import LRScheduler | ||
except ImportError: | ||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler | ||
|
||
|
||
class SimCLRTask(LightningModule): # type: ignore[misc] | ||
"""SimCLR: a simple framework for contrastive learning of visual representations. | ||
|
||
Implementation based on: | ||
|
||
* https://github.com/google-research/simclr | ||
* https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/13-contrastive-learning.html | ||
|
||
If you use this trainer in your research, please cite the following papers: | ||
|
||
* v1: https://arxiv.org/abs/2002.05709 | ||
* v2: https://arxiv.org/abs/2006.10029 | ||
|
||
.. versionadded:: 0.5 | ||
""" # noqa: E501 | ||
|
||
def __init__( | ||
self, | ||
model: str = "resnet50", | ||
in_channels: int = 3, | ||
version: int = 2, | ||
layers: int = 3, | ||
hidden_dim: int = 128, | ||
lr: float = 4.8, | ||
weight_decay: float = 1e-4, | ||
max_epochs: int = 100, | ||
temperature: float = 0.07, | ||
) -> None: | ||
"""Initialize a new SimCLRTask instance. | ||
|
||
Args: | ||
model: Name of the timm model to use. | ||
in_channels: Number of input channels to model. | ||
version: Version of SimCLR, 1--2. | ||
layers: Number of layers in projection head (2 for v1 or 3+ for v2). | ||
hidden_dim: Number of hidden dimensions in projection head. | ||
lr: Learning rate (0.3 x batch_size / 256 is recommended). | ||
weight_decay: Weight decay coefficient (1e-6 for v1 or 1e-4 for v2). | ||
max_epochs: Maximum number of epochs to train for. | ||
temperature: Temperature used in InfoNCE loss. | ||
""" | ||
super().__init__() | ||
|
||
assert version in range(1, 3) | ||
|
||
self.save_hyperparameters() | ||
|
||
self.model = timm.create_model( | ||
model, in_chans=in_channels, num_classes=hidden_dim | ||
) | ||
|
||
# Projection head | ||
# https://github.com/google-research/simclr/blob/master/model_util.py#L141 | ||
for i in range(layers): | ||
if i == layers - 1: | ||
# For the final layer, skip bias and ReLU | ||
self.model.fc = nn.Sequential( | ||
self.model.fc, nn.Linear(hidden_dim, hidden_dim, bias=False) | ||
) | ||
else: | ||
# For the middle layers, use bias and ReLU | ||
self.model.fc = nn.Sequential( | ||
self.model.fc, | ||
nn.ReLU(inplace=True), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is ReLU the desired/required activation function choice here or should that be more flexible? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's just what the original paper used. Depends on how much customization we want to support. |
||
nn.Linear(hidden_dim, hidden_dim, bias=True), | ||
) | ||
|
||
# TODO | ||
# v1+: add global batch norm | ||
# v2: add selective kernels, channel-wise attention mechanism, memory bank | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure exactly how to make these changes, and I don't really want to change the architecture too much to ensure that our pre-trained weights can be loaded in a vanilla model. The memory bank only adds +1% performance, so I don't really think it's worth the complexity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say that when most papers use SimCLR, they use v1 without all the tricks that get the ~1% improvement. I think it would be better to keep it simple. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The majority of the performance bump in v2 is thanks to the deeper projection head, which we have, so we should be good. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, in terms of performance bump:
So memory bank isn't high on my priority list, but adding SK and channel-wise attention may be worth it. |
||
|
||
# Data augmentation | ||
# https://github.com/google-research/simclr/blob/master/data_util.py | ||
self.aug = K.AugmentationSequential( | ||
K.RandomResizedCrop(size=(96, 96)), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we be hardcoding this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's hardcoded in BYOL (should actually be 224, not 96, let me fix this). We can make it a parameter if you want, but at the moment I don't know if we need it to be. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm okay with fixing it for now, but it's only set to 224 because that's what ImageNet experiments use. It's probably better not to restrict it to 224 in case we use higher-resolution imagery. |
||
K.RandomHorizontalFlip(), | ||
K.RandomVerticalFlip(), # added | ||
# Not appropriate for multispectral imagery, seasonal contrast used instead | ||
# K.ColorJitter( | ||
# brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8 | ||
# ) | ||
# K.RandomGrayscale(p=0.2), | ||
K.RandomGaussianBlur(kernel_size=9, sigma=(0.1, 2)), | ||
data_keys=["image"], | ||
) | ||
|
||
def forward(self, batch: Tensor) -> Tensor:
    """Run a mini-batch through the wrapped model.

    Args:
        batch: Mini-batch of images.

    Returns:
        Whatever ``self.model`` produces for ``batch``.
    """
    output: Tensor = self.model(batch)
    return output
|
||
def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor: | ||
"""Compute the training loss and additional metrics. | ||
|
||
Args: | ||
batch: The output of your DataLoader. | ||
batch_idx: Integer displaying index of this batch. | ||
|
||
Returns: | ||
The loss tensor. | ||
""" | ||
x = batch["image"] | ||
|
||
in_channels = self.hparams["in_channels"] | ||
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels | ||
|
||
if x.size(1) == in_channels: | ||
x1 = x | ||
x2 = x | ||
else: | ||
x1 = x[:, :in_channels] | ||
x2 = x[:, in_channels:] | ||
|
||
# Apply augmentations independently for each season | ||
x1 = self.aug(x1) | ||
x2 = self.aug(x2) | ||
|
||
x = torch.cat([x1, x2], dim=0) | ||
|
||
# Encode all images | ||
x = self(x) | ||
|
||
# Calculate cosine similarity | ||
cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1) | ||
|
||
# Mask out cosine similarity to itself | ||
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could make NT-Xent loss its own nn.Module so we can test it and reuse it. Maybe in a future PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I found several repos with their own InfoNCE loss implementation, but they all implement it differently and I don't know the math well enough to decide which is best. The implementation here assumes that there is exactly 1 positive pair and everything else is a negative pair. A more general implementation, or a faster implementation, is a lot more work to get right. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the implementation is fine. I was suggesting that we make it a separate module since other SSL methods use it as well. But until we have another SSL method that uses it, I think it's fine to leave as is. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well, I'm about to add MoCo, which also uses it, although their implementation is completely different, and I have no idea what the difference is. |
||
cos_sim.masked_fill_(self_mask, -9e15) | ||
|
||
# Find positive example -> batch_size // 2 away from the original example | ||
pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0) | ||
|
||
# NT-Xent loss (aka InfoNCE loss) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both SimCLR and MoCo use InfoNCE loss, but there is no implementation in PyTorch. There are many libraries that implement it, but I'd rather not add yet another dependency. |
||
cos_sim = cos_sim / self.hparams["temperature"] | ||
nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1) | ||
nll = nll.mean() | ||
|
||
self.log("train_loss", nll) | ||
|
||
return nll | ||
|
||
def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """Do nothing; this hook is intentionally a no-op.

    Args:
        batch: The output of your DataLoader (unused).
        batch_idx: Integer displaying index of this batch (unused).
    """
|
||
def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """Do nothing; this hook is currently a no-op.

    Args:
        batch: The output of your DataLoader (unused).
        batch_idx: Integer displaying index of this batch (unused).
    """
    # TODO
    # v2: add distillation step
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would actually be very useful to add someday. Both using a large model to better train a small model, and self-distillation, have been found to greatly improve performance. I didn't do this because I'm not super familiar with teacher-student distillation methods. |
||
|
||
def predict_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """Do nothing; this hook is intentionally a no-op.

    Args:
        batch: The output of your DataLoader (unused).
        batch_idx: Integer displaying index of this batch (unused).
    """
|
||
def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]: | ||
"""Initialize the optimizer and learning rate scheduler. | ||
|
||
Returns: | ||
Optimizer and learning rate scheduler. | ||
""" | ||
# Original paper uses LARS optimizer, but this is not defined in PyTorch | ||
optimizer = AdamW( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should the optimizer choice also be user-definable, as different model architectures work better with certain optimizers? Or would you expect/want a user to override the `configure_optimizers` method? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Our BYOL trainer supports specifying an optimizer, but none of the other trainers do. For now, I'm just using the optimizer used in the original paper. The only difficulty with making it user-configurable is that each optimizer has different arguments. We could add a … |
||
self.parameters(), | ||
lr=self.hparams["lr"], | ||
weight_decay=self.hparams["weight_decay"], | ||
) | ||
lr_scheduler = CosineAnnealingLR( | ||
optimizer, T_max=self.hparams["max_epochs"], eta_min=self.hparams["lr"] / 50 | ||
) | ||
return [optimizer], [lr_scheduler] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't really used at the moment since layers and weight_decay are also parameters, but it could be used to control other things in the future (see TODOs).