Skip to content

Commit

Permalink
Change regression task to timm support (microsoft#854)
Browse files Browse the repository at this point in the history
* change regression task to timm support

* add docstring about available models

* typo again

* failing test

* change name

* change name

* expose all available models

* docstring list_models

* Update torchgeo/trainers/regression.py

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
3 people authored Dec 7, 2022
1 parent 0d46c1c commit 2907f20
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 38 deletions.
3 changes: 3 additions & 0 deletions tests/conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ experiment:
task: cowc_counting
module:
model: resnet18
weights: "random"
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: True
Expand Down
3 changes: 3 additions & 0 deletions tests/conf/cyclone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ experiment:
task: "cyclone"
module:
model: "resnet18"
weights: "random"
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: False
Expand Down
48 changes: 31 additions & 17 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, Dict, Type, cast

import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

Expand Down Expand Up @@ -63,22 +62,37 @@ def test_no_logger(self) -> None:
)
trainer.fit(model=model, datamodule=datamodule)

def test_invalid_model(self) -> None:
match = "module 'torchvision.models' has no attribute 'invalid_model'"
with pytest.raises(AttributeError, match=match):
RegressionTask(model="invalid_model", pretrained=False)

@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {"model": "resnet18", "pretrained": False}

def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
return {
"model": "resnet18",
"weights": "random",
"num_outputs": 1,
"in_channels": 3,
}

def test_invalid_pretrained(
self, model_kwargs: Dict[Any, Any], checkpoint: str
) -> None:
monkeypatch.delattr(COWCCountingDataModule, "plot")
datamodule = COWCCountingDataModule(
root="tests/data/cowc_counting", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.validate(model=model, datamodule=datamodule)
model_kwargs["weights"] = checkpoint
model_kwargs["model"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)

def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs)

def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)

def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)
80 changes: 59 additions & 21 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

"""Regression tasks."""

import os
from typing import Any, Dict, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from packaging.version import parse
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

from ..datasets.utils import unbind_samples
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Expand All @@ -26,36 +26,74 @@


class RegressionTask(pl.LightningModule):
"""LightningModule for training models on regression datasets."""
"""LightningModule for training models on regression datasets.
Supports any available `Timm model
<https://rwightman.github.io/pytorch-image-models/>`_
as an architecture choice. To see a list of available
models, you can do:
.. code-block:: python
import timm
print(timm.list_models())
"""

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
in_channels = self.hyperparams["in_channels"]
model = self.hyperparams["model"]
pretrained = self.hyperparams["pretrained"]

if parse(torchvision.__version__) >= parse("0.13"):
if pretrained:
kwargs = {
"weights": getattr(
torchvision.models, f"ResNet{model[6:]}_Weights"
).DEFAULT
}

imagenet_pretrained = False
custom_pretrained = False
if self.hyperparams["weights"] and not os.path.exists(
self.hyperparams["weights"]
):
if self.hyperparams["weights"] not in ["imagenet", "random"]:
raise ValueError(
f"Weight type '{self.hyperparams['weights']}' is not valid."
)
else:
kwargs = {"weights": None}
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
custom_pretrained = False
else:
custom_pretrained = True

# Create the model
valid_models = timm.list_models(pretrained=imagenet_pretrained)
if model in valid_models:
self.model = timm.create_model(
model,
num_classes=self.hyperparams["num_outputs"],
in_chans=in_channels,
pretrained=imagenet_pretrained,
)
else:
kwargs = {"pretrained": pretrained}
raise ValueError(f"Model type '{model}' is not a valid timm model.")

self.model = getattr(torchvision.models, model)(**kwargs)
in_features = self.model.fc.in_features
self.model.fc = nn.Linear(in_features, out_features=1)
if custom_pretrained:
name, state_dict = utils.extract_encoder(self.hyperparams["weights"])

if self.hyperparams["model"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hyperparams['model']}"
)
self.model = utils.load_state_dict(self.model, state_dict)

def __init__(self, **kwargs: Any) -> None:
"""Initialize a new LightningModule for training simple regression models.
Keyword Args:
model: Name of the model to use
learning_rate: Initial learning rate to use in the optimizer
learning_rate_schedule_patience: Patience parameter for the LR scheduler
model: Name of the timm model to use
weights: Either "random" or "imagenet"
num_outputs: Number of prediction outputs
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
.. versionchanged:: 0.4
Change regression model support from torchvision.models to timm
"""
super().__init__()

Expand Down

0 comments on commit 2907f20

Please sign in to comment.