Add random crop logic to DeepGlobeLandCover Datamodule (microsoft#876)
* crop logic

* typo

* change train_batch_size logic

* fix failing test

* typos and naming

* return argument train dataloader

* typo

* fix failing test

* suggestions except about test file

* remove test_deepglobe and add test to trainer

* forgot new conf file

* rename collate function

* move cropping logic to transform and utils

* remove comment

* simplify

* move pad_segmentation to transforms

* another one

* naming and versionadded

* another transforms approach

* typo

* fix read the docs

* some checks for Ncrop

* add unit tests new transforms

* Remove cruft

* More simplification

* Add config file

* Implemented ExtractTensorPatches

* Remove tests

* Remove unnecessary attrs

* Apply to both input and mask

* Implement RandomNCrop

* Fix dimensions

* mypy fixes

* Fix docs

* Ensure that image and mask get the same transformation

* Bump min kornia version

* ignore still needed?

* Remove unneeded hacks

* Fix pydocstyle

* Fix dimensions

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
nilsleh and adamjstewart authored Dec 30, 2022
1 parent 80ef10c commit ef31509
Showing 13 changed files with 256 additions and 90 deletions.
4 changes: 2 additions & 2 deletions docs/api/datasets.rst
@@ -169,8 +169,8 @@ Kenya Crop Type

.. autoclass:: CV4AKenyaCropType

Deep Globe Land Cover
^^^^^^^^^^^^^^^^^^^^^
DeepGlobe Land Cover
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DeepGlobeLandCover

4 changes: 2 additions & 2 deletions docs/api/non_geo_datasets.csv
@@ -5,7 +5,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Cloud Cover Detection`_,S,Sentinel-2,"22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
`Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
`Deep Globe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"3,981",15,"2,000x2,000",0.5,RGB
`ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI
@@ -34,4 +34,4 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Vaihingen`_,S,Aerial,33,6,"1,281--3,816",0.09,RGB
`NWPU VHR-10`_,I,"Google Earth, Vaihingen",800,10,"358--1,728",0.08--2,RGB
`xView2`_,CD,Maxar,"3,732",4,"1,024x1,024",0.8,RGB
`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI
`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI
2 changes: 1 addition & 1 deletion environment.yml
@@ -22,7 +22,7 @@ dependencies:
- flake8>=3.8
- ipywidgets>=7
- isort[colors]>=5.8
- kornia>=0.6.4
- kornia>=0.6.5
- laspy>=2
- mypy>=0.900
- nbmake>=0.1
2 changes: 1 addition & 1 deletion requirements/min.old
@@ -4,7 +4,7 @@ setuptools==42.0.0
# install
einops==0.3.0
fiona==1.8.0
kornia==0.6.4
kornia==0.6.5
matplotlib==3.3.0
numpy==1.17.2
omegaconf==2.1.0
4 changes: 2 additions & 2 deletions setup.cfg
@@ -29,8 +29,8 @@ install_requires =
einops>=0.3,<0.7
# fiona 1.8+ required for reading empty files
fiona>=1.8,<2
# kornia 0.6.4+ required for kornia.contrib.compute_padding
kornia>=0.6.4,<0.7
# kornia 0.6.5+ required due to change in kornia.augmentation API
kornia>=0.6.5,<0.7
# matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow
matplotlib>=3.3,<4
# numpy 1.17.2+ required by pytorch-lightning
tests/conf/deepglobelandcover.yaml
@@ -14,6 +14,8 @@ experiment:
ignore_index: null
datamodule:
root: "tests/data/deepglobelandcover"
num_tiles_per_batch: 1
num_patches_per_tile: 1
patch_size: 2
val_split_pct: 0.5
batch_size: 1
num_workers: 0
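
With these test-scale values, the effective batch arithmetic described in the datamodule docstring (see the torchgeo/datamodules/deepglobelandcover.py diff below) works out to an intentionally tiny workload:

# effective training batch = num_tiles_per_batch * num_patches_per_tile
#                          = 1 * 1 = 1 patch per step, each patch 2x2 px
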
19 changes: 0 additions & 19 deletions tests/conf/deepglobelandcover_0.yaml

This file was deleted.

3 changes: 1 addition & 2 deletions tests/trainers/test_segmentation.py
@@ -36,8 +36,7 @@ class TestSemanticSegmentationTask:
"name,classname",
[
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
("deepglobelandcover", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("inria_train", InriaAerialImageLabelingDataModule),
("inria_val", InriaAerialImageLabelingDataModule),
131 changes: 78 additions & 53 deletions torchgeo/datamodules/deepglobelandcover.py
@@ -3,87 +3,92 @@

"""DeepGlobe Land Cover Classification Challenge datamodule."""

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from kornia.augmentation import Normalize
from torch.utils.data import DataLoader

from ..datasets import DeepGlobeLandCover
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split


class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the DeepGlobe Land Cover dataset.
Uses the train/test splits from the dataset.
"""

def __init__(
self,
batch_size: int = 64,
num_workers: int = 0,
num_tiles_per_batch: int = 16,
num_patches_per_tile: int = 16,
patch_size: Union[Tuple[int, int], int] = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders.
"""Initialize a new LightningDataModule instance.
The DeepGlobe Land Cover dataset contains images that are too large to pass
directly through a model. Instead, we randomly sample patches from image tiles
during training and chop up image tiles into patch grids during evaluation.
During training, the effective batch size is equal to
``num_tiles_per_batch`` x ``num_patches_per_tile``.
Args:
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
num_tiles_per_batch: The number of image tiles to sample from during
training
num_patches_per_tile: The number of patches to randomly sample from each
image tile during training
patch_size: The size of each patch, either ``size`` or ``(height, width)``.
Should be a multiple of 32 for most segmentation architectures
val_split_pct: The percentage of the dataset to use as a validation set
num_workers: The number of workers to use for parallel data loading
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.DeepGlobeLandCover`
.. versionchanged:: 0.4
*batch_size* was replaced by *num_tiles_per_batch*, *num_patches_per_tile*,
and *patch_size*.
"""
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers

self.num_tiles_per_batch = num_tiles_per_batch
self.num_patches_per_tile = num_patches_per_tile
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
self.num_workers = num_workers
self.kwargs = kwargs

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
self.train_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_RandomNCrop(self.patch_size, self.num_patches_per_tile),
data_keys=["image", "mask"],
)
self.test_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_ExtractTensorPatches(self.patch_size),
data_keys=["image", "mask"],
)

def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
"""Initialize the main Dataset objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])

dataset = DeepGlobeLandCover(
split="train", transforms=transforms, **self.kwargs
)

self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]

if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset

self.test_dataset = DeepGlobeLandCover(
split="test", transforms=transforms, **self.kwargs
train_dataset = DeepGlobeLandCover(split="train", **self.kwargs)
self.train_dataset, self.val_dataset = dataset_split(
train_dataset, self.val_split_pct
)
self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)

def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""Return a DataLoader for training.
@@ -93,7 +98,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
batch_size=self.num_tiles_per_batch,
num_workers=self.num_workers,
shuffle=True,
)
@@ -105,10 +110,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)

def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
@@ -118,12 +120,35 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)

def on_after_batch_transfer(
self, batch: Dict[str, Any], dataloader_idx: int
) -> Dict[str, Any]:
"""Apply augmentations to batch after transferring to GPU.
Args:
batch: A batch of data that needs to be altered or augmented
dataloader_idx: The index of the dataloader to which the batch belongs
Returns:
A batch of data
"""
# Kornia requires masks to have a channel dimension
batch["mask"] = batch["mask"].unsqueeze(1)

if self.trainer:
if self.trainer.training:
batch = self.train_transform(batch)
elif self.trainer.validating or self.trainer.testing:
batch = self.test_transform(batch)

# Torchmetrics does not support masks with a channel dimension
batch["mask"] = batch["mask"].squeeze(1)

return batch

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`.
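
To illustrate the new sampling behavior end to end, here is a minimal usage sketch. It assumes the dataset has already been downloaded to a hypothetical local path; keyword arguments other than the ones named in ``__init__`` are forwarded to torchgeo.datasets.DeepGlobeLandCover:

from torchgeo.datamodules import DeepGlobeLandCoverDataModule

# "data/deepglobelandcover" is a placeholder path, not part of this commit.
dm = DeepGlobeLandCoverDataModule(
    num_tiles_per_batch=4,    # full tiles drawn per training batch
    num_patches_per_tile=16,  # random crops taken from each tile
    patch_size=256,           # multiple of 32, as the docstring recommends
    val_split_pct=0.2,
    num_workers=4,
    root="data/deepglobelandcover",
)
# Effective training batch size: 4 tiles x 16 patches = 64 patches.
# At validation/test time, each tile is instead chopped into a full patch
# grid (_ExtractTensorPatches) and served one tile per batch; the routing
# between the two happens on the GPU in on_after_batch_transfer.
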
4 changes: 2 additions & 2 deletions torchgeo/datamodules/inria.py
@@ -3,7 +3,7 @@

"""InriaAerialImageLabeling datamodule."""

from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union

import kornia.augmentation as K
import matplotlib.pyplot as plt
@@ -69,7 +69,7 @@ def __init__(
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
self.patch_size = _to_tuple(patch_size)
self.num_patches_per_tile = num_patches_per_tile
self.kwargs = kwargs

2 changes: 1 addition & 1 deletion torchgeo/datasets/deepglobelandcover.py
@@ -173,7 +173,7 @@ def _load_image(self, index: int) -> Tensor:
array: "np.typing.NDArray[np.int_]" = np.array(img)
tensor = torch.from_numpy(array)
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
tensor = tensor.permute((2, 0, 1)).to(torch.float32)
return tensor

def _load_target(self, index: int) -> Tensor:
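
The float32 conversion at load time pairs with the Normalize(mean=0.0, std=255.0) step now applied in the datamodule: the old per-sample division by 255 in ``preprocess`` becomes a batched Kornia op. A small sketch of the equivalence, on a made-up tensor:

import torch
from kornia.augmentation import Normalize

image = torch.randint(0, 256, (1, 3, 4, 4)).to(torch.float32)
normalized = Normalize(mean=0.0, std=255.0)(image)
assert torch.allclose(normalized, image / 255.0)
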
12 changes: 11 additions & 1 deletion torchgeo/samplers/utils.py
@@ -4,13 +4,23 @@
"""Common sampler utilities."""

import math
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, overload

import torch

from ..datasets import BoundingBox


@overload
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]:
...


@overload
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
...


def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
"""Convert value to a tuple if it is not already a tuple.
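
The typing.overload stubs let mypy infer Tuple[int, int] from an int argument and Tuple[float, float] from a float, which is what made the cast() removed from torchgeo/datamodules/inria.py above unnecessary. A runnable sketch of the pattern — the implementation body is collapsed in this hunk, so the one below is an assumption consistent with the docstring:

from typing import Tuple, Union, overload

@overload
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]:
    ...

@overload
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    ...

def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    # Scalars become (value, value); tuples pass through unchanged.
    if isinstance(value, (float, int)):
        return (value, value)
    return value

patch_size: Tuple[int, int] = _to_tuple(64)  # mypy infers Tuple[int, int]
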