From 6324e42c9593457b7f2220ffde5619e27b9677f6 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 25 Jan 2023 23:32:51 +0000 Subject: [PATCH] add pretrained weights loading for the segmentation encoder --- torchgeo/trainers/segmentation.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 759b4ffe41c..495d48e9cb6 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -3,6 +3,7 @@ """Segmentation tasks.""" +import os import warnings from typing import Any, cast @@ -15,9 +16,11 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MetricCollection from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex +from torchvision.models._api import WeightsEnum from ..datasets.utils import unbind_samples -from ..models import FCN +from ..models import FCN, get_weight +from . import utils class SemanticSegmentationTask(LightningModule): # type: ignore[misc] @@ -31,17 +34,19 @@ class SemanticSegmentationTask(LightningModule): # type: ignore[misc] def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" + weights = self.hyperparams["weights"] + if self.hyperparams["model"] == "unet": self.model = smp.Unet( encoder_name=self.hyperparams["backbone"], - encoder_weights=self.hyperparams["weights"], + encoder_weights="imagenet" if weights is True else None, in_channels=self.hyperparams["in_channels"], classes=self.hyperparams["num_classes"], ) elif self.hyperparams["model"] == "deeplabv3+": self.model = smp.DeepLabV3Plus( encoder_name=self.hyperparams["backbone"], - encoder_weights=self.hyperparams["weights"], + encoder_weights="imagenet" if weights is True else None, in_channels=self.hyperparams["in_channels"], classes=self.hyperparams["num_classes"], ) @@ -74,6 +79,16 @@ def config_task(self) -> None: f"Currently, supports 'ce', 'jaccard' or 'focal' loss." ) + if self.hyperparams["model"] != "fcn": + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.model.encoder = utils.load_state_dict(self.model, state_dict) + def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function.