Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added More Object Detectors #984

Merged
merged 15 commits on Jan 23, 2023
14 changes: 13 additions & 1 deletion tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@ class TestObjectDetectionTask:
@pytest.mark.parametrize(
"name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)]
)
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
@pytest.mark.parametrize(
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"model_name,backbone",
[("faster-rcnn", "resnet18"), ("fcos", "resnet18"), ("retinanet", "resnet18")],
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
)
def test_trainer(
self,
model_name: str,
backbone: str,
name: str,
classname: Type[LightningDataModule],
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
Expand All @@ -28,6 +38,8 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:

# Instantiate model
model_kwargs = conf_dict["module"]
model_kwargs["model"] = model_name
model_kwargs["backbone"] = backbone
model = ObjectDetectionTask(**model_kwargs)

# Instantiate trainer
Expand Down
115 changes: 93 additions & 22 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Detection tasks."""

from functools import partial
from typing import Any, Dict, List, cast

import matplotlib.pyplot as plt
Expand All @@ -12,13 +13,26 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models import resnet as R
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection import FCOS, FasterRCNN, RetinaNet
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.retinanet import RetinaNetHead
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc

from ..datasets.utils import unbind_samples

BACKBONE_LAT_DIM_MAP = {
"resnet18": 512,
"resnet34": 512,
"resnet50": 2048,
"resnet101": 2048,
"resnet152": 2048,
"resnext50_32x4d": 2048,
"resnext101_32x8d": 2048,
"wide_resnet50_2": 2048,
"wide_resnet101_2": 2048,
}

BACKBONE_WEIGHT_MAP = {
"resnet18": R.ResNet18_Weights.DEFAULT,
"resnet34": R.ResNet34_Weights.DEFAULT,
Expand All @@ -31,13 +45,26 @@
"wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
}

BACKBONE_LAT_DIM_MAP = {
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"resnet18": 512,
"resnet34": 512,
"resnet50": 2048,
"resnet101": 2048,
"resnet152": 2048,
"resnext50_32x4d": 2048,
"resnext101_32x8d": 2048,
"wide_resnet50_2": 2048,
"wide_resnet101_2": 2048,
}


class ObjectDetectionTask(pl.LightningModule):
"""LightningModule for object detection of images.

Currently, supports a Faster R-CNN model from
Currently, supports Faster R-CNN, FCOS, and RetinaNet models from
`torchvision
<https://pytorch.org/vision/stable/models/faster_rcnn.html>`_ with
<https://pytorch.org/vision/stable/models.html
#object-detection-instance-segmentation-and-person-keypoint-detection>`_ with
one of the following *backbone* arguments:

.. code-block:: python
Expand All @@ -52,40 +79,84 @@ class ObjectDetectionTask(pl.LightningModule):
def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
backbone_pretrained = self.hyperparams.get("pretrained", True)
if self.hyperparams["model"] == "faster-rcnn":
if "resnet" in self.hyperparams["backbone"]:
kwargs = {
"backbone_name": self.hyperparams["backbone"],
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
}
if backbone_pretrained:
kwargs["weights"] = BACKBONE_WEIGHT_MAP[
self.hyperparams["backbone"]
]
else:
kwargs["weights"] = None

backbone = resnet_fpn_backbone(**kwargs)

if (
"resnet" in self.hyperparams["backbone"]
or "resnext" in self.hyperparams["backbone"]
):
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
kwargs = {
"backbone_name": self.hyperparams["backbone"],
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
}
if backbone_pretrained:
kwargs["weights"] = BACKBONE_WEIGHT_MAP[self.hyperparams["backbone"]]
else:
raise ValueError(
f"Backbone type '{self.hyperparams['backbone']}' is not valid."
)
kwargs["weights"] = None

latent_dim = BACKBONE_LAT_DIM_MAP[self.hyperparams["backbone"]]
else:
raise ValueError(
f"Backbone type '{self.hyperparams['backbone']}' is not valid."
)

num_classes = self.hyperparams["num_classes"]

if self.hyperparams["model"] == "faster-rcnn":
backbone = resnet_fpn_backbone(**kwargs)
anchor_generator = AnchorGenerator(
sizes=((32), (64), (128), (256), (512)), aspect_ratios=((0.5, 1.0, 2.0))
)

roi_pooler = MultiScaleRoIAlign(
featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
)
num_classes = self.hyperparams["num_classes"]
self.model = FasterRCNN(
backbone,
num_classes,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler,
)
elif self.hyperparams["model"] == "fcos":
kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(256, 256)
kwargs["norm_layer"] = (
misc.FrozenBatchNorm2d if kwargs["weights"] else torch.nn.BatchNorm2d
)

backbone = resnet_fpn_backbone(**kwargs)
anchor_generator = AnchorGenerator(
sizes=((8,), (16,), (32,), (64,), (128,), (256,)),
aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)),
)

self.model = FCOS(backbone, num_classes, anchor_generator=anchor_generator)

elif self.hyperparams["model"] == "retinanet":
kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(
latent_dim, 256
)
backbone = resnet_fpn_backbone(**kwargs)

anchor_sizes = (
(16, 20, 25),
(32, 40, 50),
(64, 80, 101),
(128, 161, 203),
(256, 322, 406),
(512, 645, 812),
)
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

head = RetinaNetHead(
backbone.out_channels,
anchor_generator.num_anchors_per_location()[0],
num_classes,
norm_layer=partial(torch.nn.GroupNorm, 32),
)

self.model = RetinaNet(
backbone, num_classes, anchor_generator=anchor_generator, head=head
)
else:
raise ValueError(f"Model type '{self.hyperparams['model']}' is not valid.")

Expand Down