Skip to content

Commit

Permalink
Segmentation Pretrained Weights (#1046)
Browse files Browse the repository at this point in the history
* add pretrained weights loading for the segmentation encoder

* Updating config files to use new pretrained arg style

* fix loading weights enum to encoder

* I have no idea what I'm doing with these tests

* tests passing

* update docstring

* add the tests back in dummy

* add tests

---------

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
isaaccorley and calebrob6 authored May 3, 2023
1 parent f8f05cf commit 1973d77
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_w
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights="imagenet",
weights=True,
in_channels=3,
num_classes=2,
loss="ce",
Expand Down
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 6
Expand Down
2 changes: 1 addition & 1 deletion conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "deeplabv3+"
backbone: "resnet34"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 2
in_channels: 4
Expand Down
2 changes: 1 addition & 1 deletion conf/spacenet1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
76 changes: 76 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
from torchgeo.datasets import LandCoverAI
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import SemanticSegmentationTask


Expand All @@ -34,6 +40,11 @@ def create_model(**kwargs: Any) -> Module:
return SegmentationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Read a state dict from a local path instead of downloading it.

    The test suite patches ``load_state_dict_from_url`` with this function so
    that a weight's ``url`` can point at a file on disk.

    Args:
        url: Location handed directly to ``torch.load``.
        *args: Accepted for call compatibility and ignored.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        The object ``torch.load`` deserializes from ``url``.
    """
    return cast(dict[str, Any], torch.load(url))


def plot(*args: Any, **kwargs: Any) -> None:
    """Unconditionally raise ``ValueError``, ignoring every argument.

    Raises:
        ValueError: Always.
    """
    raise ValueError()

Expand Down Expand Up @@ -111,6 +122,71 @@ def model_kwargs(self) -> dict[Any, Any]:
"ignore_index": 0,
}

@pytest.fixture(
    params=[
        candidate
        for name in list_models()
        for candidate in get_model_weights(name)
        if "resnet" in candidate.meta["model"]
    ]
)
def weights(self, request: SubRequest) -> WeightsEnum:
    """Provide one weight enum per parametrization.

    The parameter list holds every weight returned by ``get_model_weights``
    for each name in ``list_models()`` whose ``meta["model"]`` contains
    ``"resnet"``.
    """
    selected: WeightsEnum = request.param
    return selected

@pytest.fixture
def mocked_weights(
    self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
    """Point ``weights`` at a locally saved checkpoint so nothing is downloaded.

    A timm model matching the weight's metadata is created, its state dict is
    written under ``tmp_path``, and the weight's ``url`` is patched to that
    file. ``load_state_dict_from_url`` in ``torchvision.models._api`` is
    replaced with the module-level ``load`` helper.

    Returns:
        The same ``weights`` object, now patched.
    """
    checkpoint_path = tmp_path / f"{weights}.pth"
    meta = weights.meta
    backbone = timm.create_model(meta["model"], in_chans=meta["in_chans"])
    torch.save(backbone.state_dict(), checkpoint_path)

    local_url = str(checkpoint_path)
    try:
        # Patch the underlying value object first.
        monkeypatch.setattr(weights.value, "url", local_url)
    except AttributeError:
        # Fall back to patching the attribute on the enum member itself.
        monkeypatch.setattr(weights, "url", local_url)

    monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
    return weights

def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
    """Build the task with ``weights`` set to the ``checkpoint`` fixture value."""
    model_kwargs.update(weights=checkpoint)
    SemanticSegmentationTask(**model_kwargs)

def test_weight_enum(
    self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Build the task from a patched weight enum passed as the enum object."""
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=mocked_weights,
    )
    SemanticSegmentationTask(**model_kwargs)

def test_weight_str(
    self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Build the task from a patched weight enum passed as its ``str()`` form."""
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(mocked_weights),
    )
    SemanticSegmentationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
    self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
    """Build the task from an unpatched weight enum (marked slow)."""
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=weights,
    )
    SemanticSegmentationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
    self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
    """Build the task from the ``str()`` form of an unpatched weight enum (slow)."""
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(weights),
    )
    SemanticSegmentationTask(**model_kwargs)

def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not valid."
Expand Down
30 changes: 25 additions & 5 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Segmentation tasks."""

import os
import warnings
from typing import Any, cast

Expand All @@ -15,9 +16,11 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum

from ..datasets.utils import unbind_samples
from ..models import FCN
from ..models import FCN, get_weight
from . import utils


class SemanticSegmentationTask(LightningModule): # type: ignore[misc]
Expand All @@ -31,17 +34,19 @@ class SemanticSegmentationTask(LightningModule): # type: ignore[misc]

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
weights = self.hyperparams["weights"]

if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
Expand Down Expand Up @@ -80,6 +85,16 @@ def config_task(self) -> None:
f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
)

if self.hyperparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
Expand All @@ -100,8 +115,10 @@ def __init__(self, **kwargs: Any) -> None:
Keyword Args:
model: Name of the segmentation model type to use
backbone: Name of the timm backbone to use
weights: None or "imagenet" to use imagenet pretrained weights in
the backbone
weights: Either a weight enum, the string representation of a weight enum,
True for ImageNet weights, False or None for random weights,
or the path to a saved model state dict. FCN model does not support
pretrained weights. Pretrained ViT weight enums are not supported yet.
in_channels: Number of channels in input image
num_classes: Number of semantic classes to predict
loss: Name of the loss function, currently supports
Expand All @@ -127,6 +144,9 @@ class and used with 'ce' loss
The *class_weights*, *freeze_backbone*,
and *freeze_decoder* parameters.
.. versionchanged:: 0.5
The *weights* parameter now supports WeightEnums and checkpoint paths.
"""
super().__init__()

Expand Down

0 comments on commit 1973d77

Please sign in to comment.