Refactor BYOLTask tests
adamjstewart committed Dec 28, 2021
1 parent 3db16f9 commit 8f72642
Showing 5 changed files with 65 additions and 91 deletions.
7 changes: 4 additions & 3 deletions conf/task_defaults/chesapeake_cvpr_5.yaml
@@ -3,7 +3,7 @@ experiment:
   module:
     loss: "ce"
     segmentation_model: "unet"
-    encoder_name: "resnet18"
+    encoder_name: "resnet50"
     encoder_weights: null
     encoder_output_stride: 16
     learning_rate: 1e-3
@@ -12,6 +12,7 @@ experiment:
     num_classes: 5
     num_filters: 1
     ignore_zeros: False
+    imagenet_pretraining: False
   datamodule:
     root_dir: "tests/data/chesapeake/cvpr"
     train_splits:
@@ -20,8 +21,8 @@ experiment:
       - "de-test"
     test_splits:
       - "de-test"
-    patches_per_tile: 1
+    patches_per_tile: 2
     patch_size: 64
-    batch_size: 1
+    batch_size: 2
     num_workers: 0
     class_set: ${experiment.module.num_classes}
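Note: class_set: ${experiment.module.num_classes} is an OmegaConf interpolation, so the datamodule's class count always tracks the module's. A minimal sketch of how it resolves (assuming the file above on disk; OmegaConf resolves interpolations on access):

import os

from omegaconf import OmegaConf

# class_set mirrors num_classes (5 in this config) via interpolation.
conf = OmegaConf.load(os.path.join("conf", "task_defaults", "chesapeake_cvpr_5.yaml"))
assert conf.experiment.datamodule.class_set == 5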
5 changes: 3 additions & 2 deletions conf/task_defaults/chesapeake_cvpr_7.yaml
@@ -12,6 +12,7 @@ experiment:
     num_classes: 7
     num_filters: 1
     ignore_zeros: False
+    imagenet_pretraining: False
   datamodule:
     root_dir: "tests/data/chesapeake/cvpr"
     train_splits:
@@ -20,8 +21,8 @@ experiment:
      - "de-test"
     test_splits:
       - "de-test"
-    patches_per_tile: 1
+    patches_per_tile: 2
     patch_size: 64
-    batch_size: 1
+    batch_size: 2
     num_workers: 0
     class_set: ${experiment.module.num_classes}
115 changes: 44 additions & 71 deletions tests/trainers/test_byol.py
@@ -2,96 +2,69 @@
 # Licensed under the MIT License.

 import os
-from typing import Any, Dict, Generator, cast
+from typing import Any, Dict, Type, cast

 import pytest
 import torch.nn as nn
-from _pytest.fixtures import SubRequest
-from _pytest.monkeypatch import MonkeyPatch
 from omegaconf import OmegaConf
-from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning import LightningDataModule, Trainer
 from torchvision.models import resnet18

 from torchgeo.datamodules import ChesapeakeCVPRDataModule
 from torchgeo.trainers import BYOLTask
 from torchgeo.trainers.byol import BYOL, SimCLRAugmentation

-from .test_utils import mocked_log
-

 class TestBYOL:
-    def test_custom_augment_fn(self) -> None:
-        encoder = resnet18()
-        layer = encoder.conv1
-        new_layer = nn.Conv2d(  # type: ignore[attr-defined]
-            in_channels=4,
-            out_channels=layer.out_channels,
-            kernel_size=layer.kernel_size,
-            stride=layer.stride,
-            padding=layer.padding,
-            bias=layer.bias,
-        ).requires_grad_()
-        encoder.conv1 = new_layer
-        augment_fn = SimCLRAugmentation((2, 2))
-        BYOL(encoder, augment_fn=augment_fn)
+    def test_custom_augment_fn(self) -> None:
+        encoder = resnet18()
+        layer = encoder.conv1
+        new_layer = nn.Conv2d(  # type: ignore[attr-defined]
+            in_channels=4,
+            out_channels=layer.out_channels,
+            kernel_size=layer.kernel_size,
+            stride=layer.stride,
+            padding=layer.padding,
+            bias=layer.bias,
+        ).requires_grad_()
+        encoder.conv1 = new_layer
+        augment_fn = SimCLRAugmentation((2, 2))
+        BYOL(encoder, augment_fn=augment_fn)


 class TestBYOLTask:
-    @pytest.fixture(scope="class")
-    def datamodule(self) -> ChesapeakeCVPRDataModule:
-        dm = ChesapeakeCVPRDataModule(
-            os.path.join("tests", "data", "chesapeake", "cvpr"),
-            ["de-test"],
-            ["de-test"],
-            ["de-test"],
-            patch_size=4,
-            patches_per_tile=2,
-            batch_size=2,
-            num_workers=0,
-        )
-        dm.prepare_data()
-        dm.setup()
-        return dm
-
-    @pytest.fixture(params=["resnet18", "resnet50"])
-    def config(self, request: SubRequest) -> Dict[str, Any]:
-        task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "byol.yaml"))
-        task_args = OmegaConf.to_object(task_conf.experiment.module)
-        task_args = cast(Dict[str, Any], task_args)
-        task_args["encoder"] = request.param
-        return task_args
-
-    @pytest.fixture
-    def task(
-        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
-    ) -> LightningModule:
-        task = BYOLTask(**config)
-        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
-        return task
-
-    def test_configure_optimizers(self, task: BYOLTask) -> None:
-        out = task.configure_optimizers()
-        assert "optimizer" in out
-        assert "lr_scheduler" in out
+    @pytest.mark.parametrize(
+        "name,classname",
+        [
+            ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
+            ("chesapeake_cvpr_7", ChesapeakeCVPRDataModule),
+        ],
+    )
+    def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
+        conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
+        conf_dict = OmegaConf.to_object(conf.experiment)
+        conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)

-    def test_training(
-        self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask
-    ) -> None:
-        batch = next(iter(datamodule.train_dataloader()))
-        task.training_step(batch, 0)
+        # Instantiate datamodule
+        datamodule_kwargs = conf_dict["datamodule"]
+        datamodule = classname(**datamodule_kwargs)

-    def test_validation(
-        self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask
-    ) -> None:
-        batch = next(iter(datamodule.val_dataloader()))
-        task.validation_step(batch, 0)
+        # Instantiate model
+        model_kwargs = conf_dict["module"]
+        model = BYOLTask(**model_kwargs)

-    def test_test(self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask) -> None:
-        batch = next(iter(datamodule.test_dataloader()))
-        task.test_step(batch, 0)
+        # Instantiate trainer
+        trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
+        trainer.fit(model=model, datamodule=datamodule)
+        trainer.test(model=model, datamodule=datamodule)

-    def test_invalid_encoder(self, config: Dict[str, Any]) -> None:
-        config["encoder"] = "invalid_encoder"
+    def test_invalid_encoder(self) -> None:
+        kwargs = {
+            "in_channels": 1,
+            "imagenet_pretraining": False,
+            "encoder_name": "invalid_encoder",
+        }
         error_message = "Encoder type 'invalid_encoder' is not valid."
         with pytest.raises(ValueError, match=error_message):
-            BYOLTask(**config)
+            BYOLTask(**kwargs)
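The refactored test_trainer replaces the old per-step fixtures with one end-to-end smoke test driven by the task_defaults configs. For reference, a rough standalone equivalent of what the test does (a sketch, assuming the repository root as working directory; fast_dev_run limits fit/test to a single batch each):

import os
from typing import Any, Dict, cast

from omegaconf import OmegaConf
from pytorch_lightning import Trainer

from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers import BYOLTask

# Load a task_defaults config and split it into constructor kwargs.
conf = OmegaConf.load(os.path.join("conf", "task_defaults", "chesapeake_cvpr_5.yaml"))
conf_dict = cast(Dict[Any, Dict[Any, Any]], OmegaConf.to_object(conf.experiment))

datamodule = ChesapeakeCVPRDataModule(**conf_dict["datamodule"])
model = BYOLTask(**conf_dict["module"])

# One train/val/test batch each: a quick end-to-end smoke test.
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)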
1 change: 0 additions & 1 deletion tests/trainers/test_segmentation.py
@@ -23,7 +23,6 @@ class TestSemanticSegmentationTask:
     @pytest.mark.parametrize(
         "name,classname",
         [
-            ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
             ("chesapeake_cvpr_7", ChesapeakeCVPRDataModule),
             ("etci2021", ETCI2021DataModule),
             ("landcoverai", LandCoverAIDataModule),
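With chesapeake_cvpr_5 now exercising BYOLTask, the segmentation suite keeps only the 7-class variant. To run both affected suites locally, something like the following works (a sketch using pytest's Python entry point, equivalent to invoking pytest on the two files):

import pytest

# Runs the BYOL and segmentation trainer tests touched by this commit.
pytest.main(["tests/trainers/test_byol.py", "tests/trainers/test_segmentation.py"])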
28 changes: 14 additions & 14 deletions torchgeo/trainers/byol.py
@@ -247,7 +247,7 @@ def __init__(
         model: Module,
         image_size: Tuple[int, int] = (256, 256),
         hidden_layer: Union[str, int] = -2,
-        input_channels: int = 4,
+        in_channels: int = 4,
         projection_size: int = 256,
         hidden_size: int = 4096,
         augment_fn: Optional[Module] = None,
@@ -261,7 +261,7 @@ def __init__(
             image_size: the size of the training images
             hidden_layer: the hidden layer in ``model`` to attach the projection
                 head to, can be the name of the layer or index of the layer
-            input_channels: number of input channels to the model
+            in_channels: number of input channels to the model
             projection_size: size of first layer of the projection MLP
             hidden_size: size of the hidden layer of the projection MLP
             augment_fn: an instance of a module that performs data augmentation
@@ -277,7 +277,7 @@ def __init__(
         self.augment = augment_fn

         self.beta = beta
-        self.input_channels = input_channels
+        self.in_channels = in_channels
         self.encoder = EncoderWrapper(
             model, projection_size, hidden_size, layer=hidden_layer
         )
@@ -288,9 +288,7 @@ def __init__(

         # Perform a single forward pass to initialize the wrapper correctly
         self.encoder(
-            torch.zeros(  # type: ignore[attr-defined]
-                2, self.input_channels, *image_size
-            )
+            torch.zeros(2, self.in_channels, *image_size)  # type: ignore[attr-defined]
         )

     def forward(self, x: Tensor) -> Tensor:
@@ -315,21 +313,23 @@ class BYOLTask(LightningModule):

     def config_task(self) -> None:
         """Configures the task based on kwargs parameters passed to the constructor."""
-        input_channels = self.hparams["input_channels"]
+        in_channels = self.hparams["in_channels"]
         pretrained = self.hparams["imagenet_pretraining"]
         encoder = None

-        if self.hparams["encoder"] == "resnet18":
+        if self.hparams["encoder_name"] == "resnet18":
             encoder = resnet18(pretrained=pretrained)
-        elif self.hparams["encoder"] == "resnet50":
+        elif self.hparams["encoder_name"] == "resnet50":
             encoder = resnet50(pretrained=pretrained)
         else:
-            raise ValueError(f"Encoder type '{self.hparams['encoder']}' is not valid.")
+            raise ValueError(
+                f"Encoder type '{self.hparams['encoder_name']}' is not valid."
+            )

         layer = encoder.conv1
         # Creating new Conv2d layer
         new_layer = Conv2d(
-            in_channels=input_channels,
+            in_channels=in_channels,
             out_channels=layer.out_channels,
             kernel_size=layer.kernel_size,
             stride=layer.stride,
@@ -343,7 +343,7 @@ def config_task(self) -> None:
             ...  # type: ignore[index]
         ] = Variable(layer.weight.clone(), requires_grad=True)
         # Copying the weights of the old layer to the extra channels
-        for i in range(input_channels - layer.in_channels):
+        for i in range(in_channels - layer.in_channels):
             channel = layer.in_channels + i
             new_layer.weight[:, channel : channel + 1, :, :].data[
                 ...  # type: ignore[index]
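For intuition, the loop above tiles weights from the old stem across any extra input channels, so pretrained filters are reused rather than re-initialized. A self-contained sketch of the same idea, using a hypothetical 3-to-4-channel extension (not the exact torchgeo code, part of which is elided above):

import torch
import torch.nn as nn

old = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

with torch.no_grad():
    # Copy the original three channels of weights verbatim...
    new.weight[:, :3] = old.weight
    # ...then recycle existing channels to fill the extras (here channel 3).
    for i in range(4 - 3):
        channel = 3 + i
        new.weight[:, channel : channel + 1] = old.weight[:, i : i + 1]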
@@ -359,8 +359,8 @@ def __init__(self, **kwargs: Any) -> None:
         """Initialize a LightningModule for pre-training a model with BYOL.

         Keyword Args:
-            input_channels: number of channels on the input imagery
-            encoder: either "resnet18" or "resnet50"
+            in_channels: number of channels on the input imagery
+            encoder_name: either "resnet18" or "resnet50"
             imagenet_pretraining: bool indicating whether to use imagenet pretrained
                 weights
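In summary, input_channels becomes in_channels and encoder becomes encoder_name, matching the keyword names the YAML configs above use. A minimal construction sketch under the new names (only the three kwargs read by config_task are shown; a real run would also pass optimizer settings such as the learning rate from conf/task_defaults/byol.yaml):

from torchgeo.trainers import BYOLTask

task = BYOLTask(
    in_channels=4,
    encoder_name="resnet18",
    imagenet_pretraining=False,
)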
