Add datamodule for GID-15 dataset #928

Merged: 6 commits, Dec 30, 2022
21 changes: 21 additions & 0 deletions conf/gid15.yaml
@@ -0,0 +1,21 @@
experiment:
  task: "gid15"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: null
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    verbose: false
    in_channels: 3
    num_classes: 16
    num_filters: 1
    ignore_index: null
  datamodule:
    root: "data/gid15"
    num_tiles_per_batch: 16
    num_patches_per_tile: 16
    patch_size: 64
    val_split_pct: 0.5
    num_workers: 0
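
For orientation, here is a minimal sketch of how this config maps onto the new datamodule. Loading the file with OmegaConf and the experiment.datamodule access path are assumptions about how torchgeo's train.py consumes these configs, not something shown in this diff:

from omegaconf import OmegaConf

from torchgeo.datamodules import GID15DataModule

# Hedged sketch: read conf/gid15.yaml and build the datamodule from its
# "datamodule" section; the keyword arguments mirror the YAML above.
conf = OmegaConf.load("conf/gid15.yaml")
datamodule = GID15DataModule(**conf.experiment.datamodule)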
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -49,6 +49,11 @@ FAIR1M

.. autoclass:: FAIR1MDataModule

GID-15
^^^^^^

.. autoclass:: GID15DataModule

Inria Aerial Image Labeling
^^^^^^^^^^^^^^^^^^^^^^^^^^^

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -88,6 +88,8 @@ filterwarnings = [
# https://github.com/lanpa/tensorboardX/issues/653
# https://github.com/lanpa/tensorboardX/pull/654
"ignore:Call to deprecated create function:DeprecationWarning:tensorboardX",
# https://github.com/kornia/kornia/issues/777
"ignore:Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0:UserWarning:torch.nn.functional",
Collaborator:
This is only needed because our tests are trying to crop a 2x2 patch from a 1x1 image, which obviously requires resizing. If someone ever implements a data.py for GID-15 and increases the image size, this can be removed.
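
A minimal sketch of the warning being silenced, assuming the situation described above (a 1x1 test image upsampled to the 2x2 patch size); the exact code path inside kornia is not shown in this diff:

import torch
import torch.nn.functional as F

# Hypothetical reproduction: enlarging a 1x1 tile to a 2x2 patch with bilinear
# interpolation. Older torch versions emit the "Default upsampling behavior
# when mode=bilinear is changed to align_corners=False since 0.4.0" UserWarning
# when align_corners is left unspecified, which is the warning filtered above.
image = torch.rand(1, 3, 1, 1)
patch = F.interpolate(image, size=(2, 2), mode="bilinear")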


# Expected warnings
# pytorch-lightning warns us about using num_workers=0, but it's faster on macOS
22 changes: 22 additions & 0 deletions tests/conf/gid15.yaml
@@ -0,0 +1,22 @@
experiment:
  task: "gid15"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: null
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    verbose: false
    in_channels: 3
    num_classes: 16
    num_filters: 1
    ignore_index: null
  datamodule:
    root: "tests/data/gid15"
    download: true
    num_tiles_per_batch: 1
    num_patches_per_tile: 1
    patch_size: 2
    val_split_pct: 0.5
    num_workers: 0
Binary test data files not shown.
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -15,6 +15,7 @@
ChesapeakeCVPRDataModule,
DeepGlobeLandCoverDataModule,
ETCI2021DataModule,
GID15DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
@@ -41,6 +42,7 @@ class TestSemanticSegmentationTask:
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("deepglobelandcover", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("gid15", GID15DataModule),
("inria_train", InriaAerialImageLabelingDataModule),
("inria_val", InriaAerialImageLabelingDataModule),
("inria_test", InriaAerialImageLabelingDataModule),
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -11,6 +11,7 @@
from .etci2021 import ETCI2021DataModule
from .eurosat import EuroSATDataModule
from .fair1m import FAIR1MDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
@@ -38,6 +39,7 @@
"ETCI2021DataModule",
"EuroSATDataModule",
"FAIR1MDataModule",
"GID15DataModule",
"InriaAerialImageLabelingDataModule",
"LandCoverAIDataModule",
"LoveDADataModule",
177 changes: 177 additions & 0 deletions torchgeo/datamodules/gid15.py
@@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GID-15 datamodule."""

from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from einops import rearrange
from kornia.augmentation import Normalize
from torch import Tensor
from torch.utils.data import DataLoader

from ..datasets import GID15
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split


class GID15DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the GID-15 dataset.

    Uses the train/test splits from the dataset.

    .. versionadded:: 0.4
    """

    def __init__(
        self,
        num_tiles_per_batch: int = 16,
        num_patches_per_tile: int = 16,
        patch_size: Union[Tuple[int, int], int] = 64,
        val_split_pct: float = 0.2,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new LightningDataModule instance.

        The GID-15 dataset contains images that are too large to pass
        directly through a model. Instead, we randomly sample patches from image tiles
        during training and chop up image tiles into patch grids during evaluation.
        During training, the effective batch size is equal to
        ``num_tiles_per_batch`` x ``num_patches_per_tile``.

        Args:
            num_tiles_per_batch: The number of image tiles to sample from during
                training
            num_patches_per_tile: The number of patches to randomly sample from each
                image tile during training
            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
                Should be a multiple of 32 for most segmentation architectures
            val_split_pct: The percentage of the dataset to use as a validation set
            num_workers: The number of workers to use for parallel data loading
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.GID15`
        """
        super().__init__()

        self.num_tiles_per_batch = num_tiles_per_batch
        self.num_patches_per_tile = num_patches_per_tile
        self.patch_size = _to_tuple(patch_size)
        self.val_split_pct = val_split_pct
        self.num_workers = num_workers
        self.kwargs = kwargs

        self.train_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _RandomNCrop(self.patch_size, self.num_patches_per_tile),
            data_keys=["image", "mask"],
        )
        self.val_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _ExtractTensorPatches(self.patch_size),
            data_keys=["image", "mask"],
        )
        self.predict_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _ExtractTensorPatches(self.patch_size),
            data_keys=["image"],
        )

    def prepare_data(self) -> None:
        """Initialize the main Dataset objects for use in :func:`setup`.

        This includes optionally downloading the dataset. This is done once per node,
        while :func:`setup` is done once per GPU.
        """
        if self.kwargs.get("download", False):
            GID15(**self.kwargs)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main Dataset objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        train_dataset = GID15(split="train", **self.kwargs)
        self.train_dataset, self.val_dataset = dataset_split(
            train_dataset, self.val_split_pct
        )

        # Test set masks are not public, use for prediction instead
        self.predict_dataset = GID15(split="test", **self.kwargs)

    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.num_tiles_per_batch,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )

    def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for predicting.

        Returns:
            predicting data loader
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply augmentations to batch after transferring to GPU.

        Args:
            batch: A batch of data that needs to be altered or augmented
            dataloader_idx: The index of the dataloader to which the batch belongs

        Returns:
            A batch of data
        """
        # Kornia requires masks to have a channel dimension
        if "mask" in batch:
            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")

        if self.trainer:
            if self.trainer.training:
                batch = self.train_transform(batch)
            elif self.trainer.validating:
                batch = self.val_transform(batch)
            elif self.trainer.predicting:
                batch = self.predict_transform(batch)

        # Torchmetrics does not support masks with a channel dimension
        if "mask" in batch:
            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")

        return batch

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.GID15.plot`."""
        return self.predict_dataset.plot(*args, **kwargs)
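
A hedged usage sketch of the datamodule, tying it back to the values in conf/gid15.yaml. SemanticSegmentationTask and its keyword arguments are taken from the config's module section and the existing torchgeo trainers, not from this diff, so treat the exact signature as an assumption. With num_tiles_per_batch=16 and num_patches_per_tile=16, each training step sees an effective batch of 16 x 16 = 256 patches of size 64 x 64:

import pytorch_lightning as pl

from torchgeo.datamodules import GID15DataModule
from torchgeo.trainers import SemanticSegmentationTask

# Values mirror conf/gid15.yaml; the effective training batch size is
# num_tiles_per_batch * num_patches_per_tile = 256 patches per step.
datamodule = GID15DataModule(
    root="data/gid15",
    num_tiles_per_batch=16,
    num_patches_per_tile=16,
    patch_size=64,
    val_split_pct=0.5,
    num_workers=0,
)
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    loss="ce",
    in_channels=3,
    num_classes=16,
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)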
2 changes: 1 addition & 1 deletion torchgeo/datasets/gid15.py
@@ -194,7 +194,7 @@ def _load_image(self, path: str) -> Tensor:
        array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
        tensor = torch.from_numpy(array)
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        tensor = tensor.permute((2, 0, 1)).float()
        return tensor

    def _load_target(self, path: str) -> Tensor:
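
The one-line change above casts loaded images to float. A plausible reason, stated as an assumption rather than taken from this diff's discussion, is that the Kornia Normalize(mean=0.0, std=255.0) transform in the new datamodule operates on floating-point tensors; a minimal sketch of that interaction:

import torch
from kornia.augmentation import Normalize

# The GID15DataModule normalizes with Normalize(mean=0.0, std=255.0), so the
# dataset now returns float images; scaling maps pixel values into [0, 1].
norm = Normalize(mean=0.0, std=255.0)
image = torch.randint(0, 256, (1, 3, 64, 64)).float()  # simulated RGB tile
out = norm(image)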
2 changes: 2 additions & 0 deletions train.py
@@ -20,6 +20,7 @@
DeepGlobeLandCoverDataModule,
ETCI2021DataModule,
EuroSATDataModule,
GID15DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
@@ -54,6 +55,7 @@
"deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule),
"eurosat": (ClassificationTask, EuroSATDataModule),
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"gid15": (SemanticSegmentationTask, GID15DataModule),
"inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
"landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
"loveda": (SemanticSegmentationTask, LoveDADataModule),