Skip to content

Commit

Permalink
Add datamodule for GID-15 dataset (#928)
Browse files Browse the repository at this point in the history
* add datamodule with crop logic

* remove print and fix batch_size

* typo

* Use Kornia augmentations

* Style

* Ignore warning

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
nilsleh and adamjstewart authored Dec 30, 2022
1 parent 449656f commit 2bf1a36
Show file tree
Hide file tree
Showing 19 changed files with 234 additions and 1 deletion.
21 changes: 21 additions & 0 deletions conf/gid15.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Experiment configuration for training a segmentation model on GID-15.
# NOTE(review): nesting indentation is not visible in this view; `module:` and
# `datamodule:` presumably nest under `experiment:` -- confirm against the file.
experiment:
# Matches the "gid15" entry registered in train.py
task: "gid15"
# Settings for the training task (model / loss / optimizer)
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 16
num_filters: 1
ignore_index: null
# Keyword arguments for GID15DataModule
datamodule:
root: "data/gid15"
# Effective training batch size is num_tiles_per_batch x num_patches_per_tile
num_tiles_per_batch: 16
num_patches_per_tile: 16
patch_size: 64
val_split_pct: 0.5
num_workers: 0
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ FAIR1M

.. autoclass:: FAIR1MDataModule

GID-15
^^^^^^

.. autoclass:: GID15DataModule

Inria Aerial Image Labeling
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ filterwarnings = [
# https://github.com/lanpa/tensorboardX/issues/653
# https://github.com/lanpa/tensorboardX/pull/654
"ignore:Call to deprecated create function:DeprecationWarning:tensorboardX",
# https://github.com/kornia/kornia/issues/777
"ignore:Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0:UserWarning:torch.nn.functional",

# Expected warnings
# pytorch-lightning warns us about using num_workers=0, but it's faster on macOS
Expand Down
22 changes: 22 additions & 0 deletions tests/conf/gid15.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Test configuration for the GID-15 segmentation task.
# NOTE(review): nesting indentation is not visible in this view; `module:` and
# `datamodule:` presumably nest under `experiment:` -- confirm against the file.
experiment:
# Matches the "gid15" entry registered in train.py
task: "gid15"
# Settings for the training task (model / loss / optimizer)
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 16
num_filters: 1
ignore_index: null
# Keyword arguments for GID15DataModule; `root` points at the test data directory
datamodule:
root: "tests/data/gid15"
# When true, GID15DataModule.prepare_data instantiates the dataset with download enabled
download: true
# One tile x one patch per batch, with a 2-pixel patch
# (presumably kept tiny so the test runs quickly -- confirm)
num_tiles_per_batch: 1
num_patches_per_tile: 1
patch_size: 2
val_split_pct: 0.5
num_workers: 0
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ChesapeakeCVPRDataModule,
DeepGlobeLandCoverDataModule,
ETCI2021DataModule,
GID15DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
Expand All @@ -41,6 +42,7 @@ class TestSemanticSegmentationTask:
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("deepglobelandcover", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("gid15", GID15DataModule),
("inria_train", InriaAerialImageLabelingDataModule),
("inria_val", InriaAerialImageLabelingDataModule),
("inria_test", InriaAerialImageLabelingDataModule),
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .etci2021 import ETCI2021DataModule
from .eurosat import EuroSATDataModule
from .fair1m import FAIR1MDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
Expand Down Expand Up @@ -38,6 +39,7 @@
"ETCI2021DataModule",
"EuroSATDataModule",
"FAIR1MDataModule",
"GID15DataModule",
"InriaAerialImageLabelingDataModule",
"LandCoverAIDataModule",
"LoveDADataModule",
Expand Down
177 changes: 177 additions & 0 deletions torchgeo/datamodules/gid15.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GID-15 datamodule."""

from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from einops import rearrange
from kornia.augmentation import Normalize
from torch import Tensor
from torch.utils.data import DataLoader

from ..datasets import GID15
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split


class GID15DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the GID-15 dataset.

    Uses the train/test splits from the dataset.

    .. versionadded:: 0.4
    """

    def __init__(
        self,
        num_tiles_per_batch: int = 16,
        num_patches_per_tile: int = 16,
        patch_size: Union[Tuple[int, int], int] = 64,
        val_split_pct: float = 0.2,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new LightningDataModule instance.

        The GID-15 dataset contains images that are too large to pass
        directly through a model. Instead, we randomly sample patches from image tiles
        during training and chop up image tiles into patch grids during evaluation.
        During training, the effective batch size is equal to
        ``num_tiles_per_batch`` x ``num_patches_per_tile``.

        Args:
            num_tiles_per_batch: The number of image tiles to sample from during
                training
            num_patches_per_tile: The number of patches to randomly sample from each
                image tile during training
            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
                Should be a multiple of 32 for most segmentation architectures
            val_split_pct: The fraction of the training split (e.g. ``0.2``) to
                hold out as a validation set
            num_workers: The number of workers to use for parallel data loading
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.GID15`
        """
        super().__init__()

        self.num_tiles_per_batch = num_tiles_per_batch
        self.num_patches_per_tile = num_patches_per_tile
        self.patch_size = _to_tuple(patch_size)
        self.val_split_pct = val_split_pct
        self.num_workers = num_workers
        self.kwargs = kwargs

        # Every pipeline first divides pixel values by 255 (mean 0, std 255).
        # Training then draws random crops from each tile, while validation and
        # prediction cut each tile into a grid of patches.
        self.train_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _RandomNCrop(self.patch_size, self.num_patches_per_tile),
            data_keys=["image", "mask"],
        )
        self.val_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _ExtractTensorPatches(self.patch_size),
            data_keys=["image", "mask"],
        )
        # No masks are available at prediction time, so only the image is transformed
        self.predict_transform = AugmentationSequential(
            Normalize(mean=0.0, std=255.0),
            _ExtractTensorPatches(self.patch_size),
            data_keys=["image"],
        )

    def prepare_data(self) -> None:
        """Initialize the main Dataset objects for use in :func:`setup`.

        This includes optionally downloading the dataset. This is done once per node,
        while :func:`setup` is done once per GPU.
        """
        # Only instantiate the dataset here when a download was requested
        if self.kwargs.get("download", False):
            GID15(**self.kwargs)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main Dataset objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (currently unused; all datasets are always
                created)
        """
        # The validation set is carved out of the official training split
        train_dataset = GID15(split="train", **self.kwargs)
        self.train_dataset, self.val_dataset = dataset_split(
            train_dataset, self.val_split_pct
        )

        # Test set masks are not public, use for prediction instead
        self.predict_dataset = GID15(split="test", **self.kwargs)

    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        # Batches hold whole tiles; patches are cropped later in
        # on_after_batch_transfer
        return DataLoader(
            self.train_dataset,
            batch_size=self.num_tiles_per_batch,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        # One tile per batch; the tile is later split into a grid of patches
        return DataLoader(
            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )

    def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Return a DataLoader for predicting.

        Returns:
            predicting data loader
        """
        # One tile per batch; the tile is later split into a grid of patches
        return DataLoader(
            self.predict_dataset,
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply augmentations to batch after transferring to GPU.

        The transform is chosen from the current trainer state. Batches are
        returned untransformed (apart from the mask round-trip) when no trainer
        is attached or the trainer is in none of the handled states.

        Args:
            batch: A batch of data that needs to be altered or augmented
            dataloader_idx: The index of the dataloader to which the batch belongs

        Returns:
            A batch of data
        """
        # Kornia requires masks to have a channel dimension
        if "mask" in batch:
            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")

        if self.trainer:
            if self.trainer.training:
                batch = self.train_transform(batch)
            elif self.trainer.validating or self.trainer.sanity_checking:
                # Lightning's pre-fit sanity check draws from the validation
                # dataloader but reports ``sanity_checking`` instead of
                # ``validating``; without this branch those batches would reach
                # the model as un-normalized, un-patched full tiles.
                batch = self.val_transform(batch)
            elif self.trainer.predicting:
                batch = self.predict_transform(batch)

        # Torchmetrics does not support masks with a channel dimension
        if "mask" in batch:
            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")

        return batch

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.GID15.plot`.

        Delegates to the prediction dataset, so :func:`setup` must have been
        called first.
        """
        return self.predict_dataset.plot(*args, **kwargs)
2 changes: 1 addition & 1 deletion torchgeo/datasets/gid15.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _load_image(self, path: str) -> Tensor:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
tensor = tensor.permute((2, 0, 1)).float()
return tensor

def _load_target(self, path: str) -> Tensor:
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DeepGlobeLandCoverDataModule,
ETCI2021DataModule,
EuroSATDataModule,
GID15DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
Expand Down Expand Up @@ -54,6 +55,7 @@
"deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule),
"eurosat": (ClassificationTask, EuroSATDataModule),
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"gid15": (SemanticSegmentationTask, GID15DataModule),
"inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
"landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
"loveda": (SemanticSegmentationTask, LoveDADataModule),
Expand Down

0 comments on commit 2bf1a36

Please sign in to comment.