Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SimCLR trainer #1252

Merged
merged 67 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
0d5279d
add simclr and tests
isaaccorley Apr 15, 2023
490b863
add lightly to reqs
isaaccorley Apr 15, 2023
a7d6e71
Merge branch 'main' into trainers/simclr
adamjstewart Apr 16, 2023
1dabbcd
pyupgrade
adamjstewart Apr 16, 2023
42f3be0
Copy things from prior implementation
adamjstewart Apr 16, 2023
554aa92
Add SimCLR v2 projection head
adamjstewart Apr 16, 2023
900d378
Remove kwargs
adamjstewart Apr 16, 2023
b580f19
Call __init__ explicitly
adamjstewart Apr 16, 2023
5a4408f
Fix mypy and docs
adamjstewart Apr 16, 2023
cda57df
Can't test newer setuptools
adamjstewart Apr 16, 2023
5bec352
Default to output dim of model
adamjstewart Apr 16, 2023
b1a51f8
Add memory bank
adamjstewart Apr 16, 2023
4b1a15d
Ignore erroneous warning
adamjstewart Apr 16, 2023
845519d
Fix configs, test SSL4EO
adamjstewart Apr 16, 2023
776071e
Fix a few layer bugs
adamjstewart Apr 16, 2023
7b86686
mypy fixes
adamjstewart Apr 16, 2023
3fba197
kernel_size must be an integer
adamjstewart Apr 16, 2023
3b48ed5
Fix SeCo in_channels
adamjstewart Apr 16, 2023
19c5052
Get more coverage
adamjstewart Apr 16, 2023
395fe37
Bump min lightly
adamjstewart Apr 16, 2023
1edc96a
Default logging
adamjstewart Apr 16, 2023
eeb5af9
Test weights
adamjstewart Apr 16, 2023
45b3dc6
mypy fix
adamjstewart Apr 16, 2023
be13c5e
Grab max_epochs from the trainer
adamjstewart Apr 17, 2023
1c32651
max_epochs param removed
adamjstewart Apr 17, 2023
0a65af4
Use num_features
adamjstewart Apr 18, 2023
8a4a2c2
Remove classification head
adamjstewart Apr 18, 2023
077f0ed
SimCLR uses LARS, with Adam as a backup
adamjstewart Apr 18, 2023
8dc3463
Add warnings
adamjstewart Apr 18, 2023
4082ee3
Grab num features directly from model
adamjstewart Apr 18, 2023
9e2b54d
Check if identity
adamjstewart Apr 18, 2023
fb78924
Match timm model design
adamjstewart Apr 19, 2023
e07a2c9
Capture warnings
adamjstewart Apr 19, 2023
f100aed
Fix tests
adamjstewart Apr 19, 2023
0421728
Increase coverage
adamjstewart Apr 19, 2023
bf36410
Fix method name
adamjstewart Apr 19, 2023
335d9af
More typos
adamjstewart Apr 19, 2023
1e90095
Escape regex
adamjstewart Apr 19, 2023
6a04435
Newer setuptools now supported
adamjstewart Apr 20, 2023
c37c510
New batch norm for every layer
adamjstewart Apr 20, 2023
93325e5
Merge branch 'main' into trainers/simclr
adamjstewart Apr 20, 2023
2549d51
Merge branch 'main' into trainers/simclr
adamjstewart Apr 21, 2023
d874b17
Merge branch 'main' into trainers/simclr
isaaccorley Apr 22, 2023
dbaeb3f
Rename forward arg
adamjstewart Apr 23, 2023
ce2de2a
Clarify usage of weights parameter
adamjstewart Apr 23, 2023
ab677f3
Fix flake8
adamjstewart Apr 23, 2023
8491a19
Merge branch 'main' into trainers/simclr
adamjstewart Apr 23, 2023
41ba0ff
Check it
calebrob6 Apr 23, 2023
7da418e
Use hydra
adamjstewart Apr 24, 2023
e0040e0
Track average L2 normed stdev over features
calebrob6 Apr 24, 2023
a76fda5
Merge branch 'trainers/simclr' of github.com:isaaccorley/torchgeo int…
calebrob6 Apr 24, 2023
2f838da
SimCLR decays lr to 0
adamjstewart Apr 24, 2023
3e9fb28
Add lr warmup
adamjstewart Apr 24, 2023
ba44bc9
Merge branch 'main' into trainers/simclr
adamjstewart Apr 24, 2023
04520cb
Fix version access
adamjstewart Apr 25, 2023
f0e18c7
Fix LinearLR
adamjstewart Apr 25, 2023
594babb
isinstance supports tuples
adamjstewart Apr 25, 2023
8aa26fa
Comment capitalization
adamjstewart Apr 25, 2023
5bc9c6d
Require lightly 1.4.3+
adamjstewart Apr 25, 2023
e74f55e
Require lightly 1.4.3+
adamjstewart Apr 25, 2023
a93cc23
Bump lightly version
adamjstewart May 3, 2023
a42902d
Merge branch 'main' into trainers/simclr
adamjstewart May 3, 2023
27b823d
Add RandomGrayscale
adamjstewart May 3, 2023
eb7b912
Flake8 fixes
adamjstewart May 3, 2023
19f6935
Placate pydocstyle
adamjstewart May 3, 2023
3b68504
Clarify docs
adamjstewart May 3, 2023
c3349c1
Pass correct weights
adamjstewart May 3, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
("py:class", "timm.models.resnet.ResNet"),
("py:class", "timm.models.vision_transformer.VisionTransformer"),
("py:class", "torch.optim.lr_scheduler.LRScheduler"),
("py:class", "torchvision.models._api.WeightsEnum"),
("py:class", "torchvision.models.resnet.ResNet"),
]
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- isort[colors]>=5.8
- kornia>=0.6.5
- laspy>=2
- lightly>=1.4.4
- lightning>=1.8
- mypy>=0.900
- nbmake>=1.3.3
Expand Down
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ setuptools==42.0.0
einops==0.3.0
fiona==1.8.19
kornia==0.6.5
lightly==1.4.4
lightning==1.8.0
matplotlib==3.3.3
numpy==1.19.3
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ einops==0.6.1
fiona==1.9.3
kornia==0.6.12
lightning==2.0.2
lightly==1.4.4
matplotlib==3.7.1
numpy==1.24.3
pillow==9.5.0
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ install_requires =
fiona>=1.8.19,<2
# kornia 0.6.5+ required due to change in kornia.augmentation API
kornia>=0.6.5,<0.7
# lightly 1.4.4+ required for MoCo v3 support
lightly>=1.4.4
# lightning 1.8+ is first release
lightning>=1.8,<3
# matplotlib 3.3.3+ required for Python 3.9 wheels
Expand Down
23 changes: 23 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_simclr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 4
version: 1
layers: 2
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
root: "tests/data/chesapeake/cvpr"
download: false
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 2
patch_size: 64
num_workers: 0
class_set: 5
use_prior_labels: True
17 changes: 17 additions & 0 deletions tests/conf/seco_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 3
version: 1
layers: 2
hidden_dim: 8
output_dim: 8
weight_decay: 1e-6
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 1
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/seco_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 3
version: 2
layers: 4
hidden_dim: 8
output_dim: 8
weight_decay: 1e-4
memory_bank_size: 10

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 2
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/ssl4eo_s12_simclr_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 13
version: 1
layers: 2
hidden_dim: 8
output_dim: 8
weight_decay: 1e-6
memory_bank_size: 0

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 1
batch_size: 2
num_workers: 0
17 changes: 17 additions & 0 deletions tests/conf/ssl4eo_s12_simclr_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module:
_target_: torchgeo.trainers.SimCLRTask
model: "resnet18"
in_channels: 13
version: 2
layers: 3
hidden_dim: 8
output_dim: 8
weight_decay: 1e-4
memory_bank_size: 10

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 2
batch_size: 2
num_workers: 0
3 changes: 2 additions & 1 deletion tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, in_chans: int = 3, num_classes: int = 10, **kwargs: Any) -> N
super().__init__()
self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(1, num_classes)
self.fc = nn.Linear(1, num_classes) if num_classes else nn.Identity()
self.num_features = 1

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
Expand Down
154 changes: 154 additions & 0 deletions tests/trainers/test_simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import SimCLRTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestSimCLRTask:
@pytest.mark.parametrize(
"name",
[
"chesapeake_cvpr_prior_simclr",
"seco_simclr_1",
"seco_simclr_2",
"ssl4eo_s12_simclr_1",
"ssl4eo_s12_simclr_2",
],
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

# Instantiate datamodule
datamodule = instantiate(conf.datamodule)

# Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
model = instantiate(conf.module)

# Instantiate trainer
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)

def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
SimCLRTask(version=1, layers=3)
with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
SimCLRTask(version=1, memory_bank_size=10)
with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
SimCLRTask(version=2, layers=2)
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
SimCLRTask(version=2, memory_bank_size=0)

@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, "url", str(path))
except AttributeError:
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": mocked_weights,
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": str(mocked_weights),
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": weights,
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": str(weights),
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
SimCLRTask(**model_kwargs)
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .detection import ObjectDetectionTask
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask
from .simclr import SimCLRTask

__all__ = (
"BYOLTask",
Expand All @@ -17,4 +18,5 @@
"PixelwiseRegressionTask",
"RegressionTask",
"SemanticSegmentationTask",
"SimCLRTask",
)
Loading