Change regression task to timm support #854

Merged (13 commits) on Dec 7, 2022
3 changes: 3 additions & 0 deletions tests/conf/cowc_counting.yaml
@@ -2,6 +2,9 @@ experiment:
  task: cowc_counting
  module:
    model: resnet18
+    weights: "random"
+    num_outputs: 1
+    in_channels: 3
    learning_rate: 1e-3
    learning_rate_schedule_patience: 2
    pretrained: True
3 changes: 3 additions & 0 deletions tests/conf/cyclone.yaml
@@ -2,6 +2,9 @@ experiment:
task: "cyclone"
module:
model: "resnet18"
weights: "random"
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: False
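The two config diffs above add the new weights, num_outputs, and in_channels keys under experiment.module. As a rough sketch of how such a config can reach the trainer (illustrative only; the loading code below is not part of this PR), the module section can be read with OmegaConf and unpacked into the task:

# Illustrative sketch: build the task from one of the test configs above.
# The exact training entry point torchgeo uses may differ; OmegaConf and
# RegressionTask are the only pieces taken from this PR.
from typing import Any, Dict, cast

from omegaconf import OmegaConf

from torchgeo.trainers import RegressionTask

conf = OmegaConf.load("tests/conf/cyclone.yaml")
module_kwargs = cast(Dict[str, Any], OmegaConf.to_object(conf.experiment.module))

# model="resnet18", weights="random", num_outputs=1, in_channels=3, ...
task = RegressionTask(**module_kwargs)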
48 changes: 31 additions & 17 deletions tests/trainers/test_regression.py
@@ -5,7 +5,6 @@
from typing import Any, Dict, Type, cast

import pytest
-from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

@@ -63,22 +62,37 @@ def test_no_logger(self) -> None:
        )
        trainer.fit(model=model, datamodule=datamodule)

-    def test_invalid_model(self) -> None:
-        match = "module 'torchvision.models' has no attribute 'invalid_model'"
-        with pytest.raises(AttributeError, match=match):
-            RegressionTask(model="invalid_model", pretrained=False)
-
    @pytest.fixture
    def model_kwargs(self) -> Dict[Any, Any]:
-        return {"model": "resnet18", "pretrained": False}
-
-    def test_missing_attributes(
-        self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
+        return {
+            "model": "resnet18",
+            "weights": "random",
+            "num_outputs": 1,
+            "in_channels": 3,
+        }
+
+    def test_invalid_pretrained(
+        self, model_kwargs: Dict[Any, Any], checkpoint: str
    ) -> None:
-        monkeypatch.delattr(COWCCountingDataModule, "plot")
-        datamodule = COWCCountingDataModule(
-            root="tests/data/cowc_counting", batch_size=1, num_workers=0
-        )
-        model = RegressionTask(**model_kwargs)
-        trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
-        trainer.validate(model=model, datamodule=datamodule)
+        model_kwargs["weights"] = checkpoint
+        model_kwargs["model"] = "resnet50"
+        match = "Trying to load resnet18 weights into a resnet50"
+        with pytest.raises(ValueError, match=match):
+            RegressionTask(**model_kwargs)
+
+    def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
+        model_kwargs["weights"] = checkpoint
+        with pytest.warns(UserWarning):
+            RegressionTask(**model_kwargs)
+
+    def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
+        model_kwargs["model"] = "invalid_model"
+        match = "Model type 'invalid_model' is not a valid timm model."
+        with pytest.raises(ValueError, match=match):
+            RegressionTask(**model_kwargs)
+
+    def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
+        model_kwargs["weights"] = "invalid_weights"
+        match = "Weight type 'invalid_weights' is not valid."
+        with pytest.raises(ValueError, match=match):
+            RegressionTask(**model_kwargs)
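For orientation before the trainer changes below: the new fixture drives RegressionTask through model, weights, num_outputs, and in_channels, and the reworked config_task forwards these to timm roughly as follows (a sketch of the underlying call, not the full method):

# Roughly the timm call that the updated config_task (next file) wraps.
# Comments map arguments to the hyperparameters used in the tests above;
# pretrained=True would correspond to weights="imagenet".
import timm

backbone = timm.create_model(
    "resnet18",        # hyperparams["model"]
    num_classes=1,     # hyperparams["num_outputs"] -> single regression output
    in_chans=3,        # hyperparams["in_channels"]
    pretrained=False,  # weights="random"
)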
80 changes: 59 additions & 21 deletions torchgeo/trainers/regression.py
@@ -3,21 +3,21 @@

"""Regression tasks."""

import os
from typing import Any, Dict, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from packaging.version import parse
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

from ..datasets.utils import unbind_samples
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
@@ -26,36 +26,74 @@


class RegressionTask(pl.LightningModule):
-    """LightningModule for training models on regression datasets."""
+    """LightningModule for training models on regression datasets.
+
+    Supports any available `Timm model
+    <https://rwightman.github.io/pytorch-image-models/>`_
+    as an architecture choice. To see a list of available
+    models, you can do:
+
+    .. code-block:: python
+
+        import timm
+        print(timm.list_models())
+    """

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters."""
+        in_channels = self.hyperparams["in_channels"]
        model = self.hyperparams["model"]
-        pretrained = self.hyperparams["pretrained"]
-
-        if parse(torchvision.__version__) >= parse("0.13"):
-            if pretrained:
-                kwargs = {
-                    "weights": getattr(
-                        torchvision.models, f"ResNet{model[6:]}_Weights"
-                    ).DEFAULT
-                }
+
+        imagenet_pretrained = False
+        custom_pretrained = False
+        if self.hyperparams["weights"] and not os.path.exists(
+            self.hyperparams["weights"]
+        ):
+            if self.hyperparams["weights"] not in ["imagenet", "random"]:
+                raise ValueError(
+                    f"Weight type '{self.hyperparams['weights']}' is not valid."
+                )
            else:
-                kwargs = {"weights": None}
+                imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
+            custom_pretrained = False
+        else:
+            custom_pretrained = True
+
+        # Create the model
+        valid_models = timm.list_models(pretrained=imagenet_pretrained)
+        if model in valid_models:
+            self.model = timm.create_model(
+                model,
+                num_classes=self.hyperparams["num_outputs"],
+                in_chans=in_channels,
+                pretrained=imagenet_pretrained,
+            )
        else:
-            kwargs = {"pretrained": pretrained}
+            raise ValueError(f"Model type '{model}' is not a valid timm model.")

-        self.model = getattr(torchvision.models, model)(**kwargs)
-        in_features = self.model.fc.in_features
-        self.model.fc = nn.Linear(in_features, out_features=1)
+        if custom_pretrained:
+            name, state_dict = utils.extract_encoder(self.hyperparams["weights"])
+
+            if self.hyperparams["model"] != name:
+                raise ValueError(
+                    f"Trying to load {name} weights into a "
+                    f"{self.hyperparams['model']}"
+                )
+            self.model = utils.load_state_dict(self.model, state_dict)

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new LightningModule for training simple regression models.

        Keyword Args:
-            model: Name of the model to use
-            learning_rate: Initial learning rate to use in the optimizer
-            learning_rate_schedule_patience: Patience parameter for the LR scheduler
+            model: Name of the timm model to use
+            weights: Either "random" or "imagenet"
+            num_outputs: Number of prediction outputs
+            in_channels: Number of input channels to model
+            learning_rate: Learning rate for optimizer
+            learning_rate_schedule_patience: Patience for learning rate scheduler
+
+        .. versionchanged:: 0.4
+            Change regression model support from torchvision.models to timm
        """
        super().__init__()

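Taken together, the new keyword arguments give three ways to initialize weights: random, ImageNet, or a checkpoint produced elsewhere. A minimal usage sketch (the checkpoint path is a placeholder, and the quoted error messages come from the diff above):

# Illustrative usage of the reworked RegressionTask.
from torchgeo.trainers import RegressionTask

common = {"model": "resnet18", "num_outputs": 1, "in_channels": 3}

task_random = RegressionTask(weights="random", **common)      # untrained backbone
task_imagenet = RegressionTask(weights="imagenet", **common)  # timm ImageNet weights

# A filesystem path is treated as a custom checkpoint (placeholder path shown):
# task_custom = RegressionTask(weights="path/to/checkpoint.ckpt", **common)
# which raises "Trying to load <encoder> weights into a <model>" on a mismatch.

# Anything else is rejected:
# RegressionTask(weights="oops", **common)
#   -> ValueError: "Weight type 'oops' is not valid."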