Add crop logic to Potsdam2D datamodule #929

Merged 3 commits on Dec 30, 2022
21 changes: 21 additions & 0 deletions conf/potsdam2d.yaml
@@ -0,0 +1,21 @@
experiment:
  task: "potsdam2d"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: null
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    verbose: false
    in_channels: 4
    num_classes: 6
    num_filters: 1
    ignore_index: null
  datamodule:
    root: "data/potsdam"
    num_tiles_per_batch: 16
    num_patches_per_tile: 16
    patch_size: 64
    val_split_pct: 0.5
    num_workers: 0
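
With these defaults, each training step samples 16 patches from each of 16 tiles, so the effective batch size is 16 x 16 = 256 patches of 64 x 64 pixels; see the datamodule changes below for how this is implemented.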
21 changes: 21 additions & 0 deletions tests/conf/potsdam2d.yaml
@@ -0,0 +1,21 @@
experiment:
  task: "potsdam2d"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: null
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    verbose: false
    in_channels: 4
    num_classes: 6
    num_filters: 1
    ignore_index: null
  datamodule:
    root: "tests/data/potsdam"
    num_tiles_per_batch: 1
    num_patches_per_tile: 1
    patch_size: 2
    val_split_pct: 0.5
    num_workers: 0
44 changes: 0 additions & 44 deletions tests/datamodules/test_potsdam.py

This file was deleted.

2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -19,6 +19,7 @@
     LandCoverAIDataModule,
     LoveDADataModule,
     NAIPChesapeakeDataModule,
+    Potsdam2DDataModule,
     SEN12MSDataModule,
     SpaceNet1DataModule,
     Vaihingen2DDataModule,
@@ -46,6 +47,7 @@ class TestSemanticSegmentationTask:
         ("landcoverai", LandCoverAIDataModule),
         ("loveda", LoveDADataModule),
         ("naipchesapeake", NAIPChesapeakeDataModule),
+        ("potsdam2d", Potsdam2DDataModule),
         ("sen12ms_all", SEN12MSDataModule),
         ("sen12ms_s1", SEN12MSDataModule),
         ("sen12ms_s2_all", SEN12MSDataModule),
136 changes: 83 additions & 53 deletions torchgeo/datamodules/potsdam.py
@@ -3,14 +3,19 @@
 
 """Potsdam datamodule."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import Compose
+from einops import rearrange
+from kornia.augmentation import Normalize
+from torch import Tensor
+from torch.utils.data import DataLoader
 
 from ..datasets import Potsdam2D
+from ..samplers.utils import _to_tuple
+from ..transforms import AugmentationSequential
+from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
 from .utils import dataset_split
 

@@ -24,105 +29,130 @@ class Potsdam2DDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        batch_size: int = 64,
-        num_workers: int = 0,
+        num_tiles_per_batch: int = 16,
+        num_patches_per_tile: int = 16,
+        patch_size: Union[Tuple[int, int], int] = 64,
         val_split_pct: float = 0.2,
+        num_workers: int = 0,
         **kwargs: Any,
     ) -> None:
-        """Initialize a LightningDataModule for Potsdam2D based DataLoaders.
+        """Initialize a new LightningDataModule instance.
+
+        The Potsdam2D dataset contains images that are too large to pass
+        directly through a model. Instead, we randomly sample patches from
+        image tiles during training and chop up image tiles into patch grids
+        during evaluation. During training, the effective batch size is equal
+        to ``num_tiles_per_batch`` x ``num_patches_per_tile``.
 
         Args:
-            batch_size: The batch size to use in all created DataLoaders
-            num_workers: The number of workers to use in all created DataLoaders
-            val_split_pct: What percentage of the dataset to use as a validation set
+            num_tiles_per_batch: The number of image tiles to sample from during
+                training
+            num_patches_per_tile: The number of patches to randomly sample from each
+                image tile during training
+            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
+                Should be a multiple of 32 for most segmentation architectures
+            val_split_pct: The percentage of the dataset to use as a validation set
+            num_workers: The number of workers to use for parallel data loading
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.Potsdam2D`
+
+        .. versionchanged:: 0.4
+            *batch_size* was replaced by *num_tiles_per_batch*,
+            *num_patches_per_tile*, and *patch_size*.
         """
         super().__init__()
-        self.batch_size = batch_size
-        self.num_workers = num_workers
+
+        self.num_tiles_per_batch = num_tiles_per_batch
+        self.num_patches_per_tile = num_patches_per_tile
+        self.patch_size = _to_tuple(patch_size)
         self.val_split_pct = val_split_pct
+        self.num_workers = num_workers
         self.kwargs = kwargs
 
-    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        """Transform a single sample from the Dataset.
-
-        Args:
-            sample: input image dictionary
-
-        Returns:
-            preprocessed sample
-        """
-        sample["image"] = sample["image"].float()
-        sample["image"] /= 255.0
-        return sample
+        self.train_transform = AugmentationSequential(
+            Normalize(mean=0.0, std=255.0),
+            _RandomNCrop(self.patch_size, self.num_patches_per_tile),
+            data_keys=["image", "mask"],
+        )
+        self.test_transform = AugmentationSequential(
+            Normalize(mean=0.0, std=255.0),
+            _ExtractTensorPatches(self.patch_size),
+            data_keys=["image", "mask"],
+        )
 
     def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
+        """Initialize the main Dataset objects.
 
         This method is called once per GPU per run.
 
         Args:
             stage: stage to set up
         """
-        transforms = Compose([self.preprocess])
-
-        dataset = Potsdam2D(split="train", transforms=transforms, **self.kwargs)
-
-        self.train_dataset: Dataset[Any]
-        self.val_dataset: Dataset[Any]
-
-        if self.val_split_pct > 0.0:
-            self.train_dataset, self.val_dataset, _ = dataset_split(
-                dataset, val_pct=self.val_split_pct, test_pct=0.0
-            )
-        else:
-            self.train_dataset = dataset
-            self.val_dataset = dataset
-
-        self.test_dataset = Potsdam2D(
-            split="test", transforms=transforms, **self.kwargs
+        train_dataset = Potsdam2D(split="train", **self.kwargs)
+        self.train_dataset, self.val_dataset = dataset_split(
+            train_dataset, self.val_split_pct
         )
+        self.test_dataset = Potsdam2D(split="test", **self.kwargs)
 
-    def train_dataloader(self) -> DataLoader[Any]:
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for training.
 
         Returns:
             training data loader
         """
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.num_tiles_per_batch,
             num_workers=self.num_workers,
             shuffle=True,
         )
 
-    def val_dataloader(self) -> DataLoader[Any]:
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for validation.
 
         Returns:
             validation data loader
         """
         return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )
 
-    def test_dataloader(self) -> DataLoader[Any]:
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Return a DataLoader for testing.
 
         Returns:
             testing data loader
         """
         return DataLoader(
-            self.test_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
+            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )
 
+    def on_after_batch_transfer(
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
+        """Apply augmentations to batch after transferring to GPU.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented
+            dataloader_idx: The index of the dataloader to which the batch belongs
+
+        Returns:
+            A batch of data
+        """
+        # Kornia requires masks to have a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
+
+        if self.trainer:
+            if self.trainer.training:
+                batch = self.train_transform(batch)
+            elif self.trainer.validating or self.trainer.testing:
+                batch = self.test_transform(batch)
+
+        # Torchmetrics does not support masks with a channel dimension
+        batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
+
+        return batch
+
     def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
         """Run :meth:`torchgeo.datasets.Potsdam2D.plot`.
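
To illustrate how the new pieces fit together, here is a minimal sketch, not part of this PR, of the mask reshaping that on_after_batch_transfer performs and the effective-batch-size arithmetic. Shapes are illustrative and much smaller than the real 6000 x 6000 Potsdam tiles:

import torch
from einops import rearrange

num_tiles_per_batch = 16   # the DataLoader batch_size during training
num_patches_per_tile = 16  # patches _RandomNCrop draws from each tile

# A raw training batch as the DataLoader delivers it: one mask per tile.
mask = torch.randint(0, 6, (num_tiles_per_batch, 512, 512))

# Kornia augmentations require masks to carry a channel dimension ...
mask = rearrange(mask, "b h w -> b () h w")
assert mask.shape == (16, 1, 512, 512)

# ... and torchmetrics requires that dimension removed again afterwards.
mask = rearrange(mask, "b () h w -> b h w")
assert mask.shape == (16, 512, 512)

# After _RandomNCrop, one optimizer step sees 16 x 16 = 256 patches.
effective_batch_size = num_tiles_per_batch * num_patches_per_tile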
2 changes: 1 addition & 1 deletion torchgeo/datasets/potsdam.py
@@ -187,7 +187,7 @@ def _load_image(self, index: int) -> Tensor:
         path = self.files[index]["image"]
         with rasterio.open(path) as f:
             array = f.read()
-            tensor = torch.from_numpy(array)
+            tensor = torch.from_numpy(array).float()
             return tensor
 
     def _load_target(self, index: int) -> Tensor:
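
The cast matters because the Kornia Normalize added above operates on floating-point tensors, while rasterio reads Potsdam's imagery as integer arrays; converting once at load time lets Normalize(mean=0.0, std=255.0) rescale to [0, 1] on the device. A small sketch of the behavior, assuming uint8 input:

import numpy as np
import torch

array = np.zeros((4, 64, 64), dtype=np.uint8)  # stand-in for f.read()
tensor = torch.from_numpy(array).float()       # uint8 -> float32
assert tensor.dtype == torch.float32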
2 changes: 2 additions & 0 deletions train.py
@@ -25,6 +25,7 @@
     LoveDADataModule,
     NAIPChesapeakeDataModule,
     NASAMarineDebrisDataModule,
+    Potsdam2DDataModule,
     RESISC45DataModule,
     SEN12MSDataModule,
     So2SatDataModule,
@@ -58,6 +59,7 @@
     "loveda": (SemanticSegmentationTask, LoveDADataModule),
     "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
     "nasa_marine_debris": (ObjectDetectionTask, NASAMarineDebrisDataModule),
+    "potsdam2d": (SemanticSegmentationTask, Potsdam2DDataModule),
     "resisc45": (ClassificationTask, RESISC45DataModule),
     "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
     "so2sat": (ClassificationTask, So2SatDataModule),
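
With the "potsdam2d" task registered and the config added in conf/potsdam2d.yaml, the experiment should be launchable through TorchGeo's usual OmegaConf entry point, presumably along the lines of: python train.py config_file=conf/potsdam2d.yaml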