From ada76e27a5189dde808a48da327e81805c03f6e2 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sat, 3 Dec 2022 17:08:39 +0100 Subject: [PATCH 1/6] add datamodule with crop logic --- docs/api/datamodules.rst | 5 + tests/datamodules/test_gid15.py | 57 ++++++++ torchgeo/datamodules/__init__.py | 2 + torchgeo/datamodules/gid15.py | 218 +++++++++++++++++++++++++++++++ torchgeo/datasets/utils.py | 19 +++ 5 files changed, 301 insertions(+) create mode 100644 tests/datamodules/test_gid15.py create mode 100644 torchgeo/datamodules/gid15.py diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index ab06618eee2..0bd7701a79a 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -49,6 +49,11 @@ FAIR1M .. autoclass:: FAIR1MDataModule +GID-15 +^^^^^^ + +.. autoclass:: GIDDataModule + Inria Aerial Image Labeling ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/datamodules/test_gid15.py b/tests/datamodules/test_gid15.py new file mode 100644 index 00000000000..bf3764af840 --- /dev/null +++ b/tests/datamodules/test_gid15.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import matplotlib.pyplot as plt +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import GID15DataModule +from torchgeo.datasets import unbind_samples + + +class TestGID15DataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> GID15DataModule: + root = os.path.join("tests", "data", "gid15") + batch_size = 2 + num_workers = 0 + val_split_size = request.param + dm = GID15DataModule( + root=root, + train_batch_size=batch_size, + num_workers=num_workers, + val_split_pct=val_split_size, + num_tiles_per_batch=1, + download=True, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_batch_size_warning(self, datamodule: GID15DataModule) -> None: + match = "The effective batch size will differ" + with pytest.warns(UserWarning, match=match): + GID15DataModule( + root=datamodule.test_dataset.root, + train_batch_size=3, + num_tiles_per_batch=2, + num_workers=datamodule.num_workers, + val_split_pct=datamodule.val_split_pct, + ) + + def test_train_dataloader(self, datamodule: GID15DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: GID15DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: GID15DataModule) -> None: + next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: GID15DataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 1f0ecf82284..14e50ad7d45 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -11,6 +11,7 @@ from .etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule +from .gid15 import GID15DataModule from .inria import InriaAerialImageLabelingDataModule from .landcoverai import LandCoverAIDataModule from .loveda import LoveDADataModule @@ -38,6 +39,7 @@ "ETCI2021DataModule", "EuroSATDataModule", "FAIR1MDataModule", + "GID15DataModule", "InriaAerialImageLabelingDataModule", "LandCoverAIDataModule", "LoveDADataModule", diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py new file mode 100644 index 00000000000..dc9ab5c1f64 --- /dev/null +++ b/torchgeo/datamodules/gid15.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""GID-15 datamodule.""" + +import warnings +from typing import Any, Dict, Optional, Tuple, Union + +import kornia.augmentation as K +import matplotlib.pyplot as plt +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from torchgeo.datasets.utils import collate_patches_per_tile +from torchgeo.samplers.utils import _to_tuple + +from ..datasets import GID15 +from .utils import dataset_split + + +class GID15DataModule(pl.LightningDataModule): + """GID15 LightningDataModule implementation for the GID-15 dataset. + + Uses the train/test splits from the dataset. + + """ + + def __init__( + self, + batch_size: int = 32, + num_workers: int = 0, + patch_size: Union[Tuple[int, int], int] = (64, 64), + num_tiles_per_batch: int = 16, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for GID-15 based DataLoaders. + + Args: + batch_size: The batch size used in the train DataLoader + (val_batch_size == test_batch_size == 1). + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + patch_size: Size of random patch from image and mask (height, width), should + be a multiple of 32 for most segmentation architectures + num_tiles_per_batch: number of random tiles to consider sampling patches + from per sample, should evenly divide batch_size and be less than + or equal to batch_size + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.GID15` + + .. versionadded:: 0.4 + """ + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + self.patch_size = _to_tuple(patch_size) + self.val_split_pct = val_split_pct + self.kwargs = kwargs + + assert ( + self.batch_size >= num_tiles_per_batch + ), "num_tiles_per_bacth should be less than or equal to batch_size." + + self.num_patches_per_tile = self.batch_size // num_tiles_per_batch + + if (self.num_patches_per_tile % 2) != 0 and ( + self.num_patches_per_tile != num_tiles_per_batch + ): + warnings.warn( + "The effective batch size" + f" will differ from the specified {batch_size}" + f" and be {self.num_patches_per_tile * num_tiles_per_batch} instead." + " To match the batch_size exactly, ensure that" + " num_tiles_per_batch evenly divides batch_size" + ) + + self.rcrop = K.AugmentationSequential( + K.RandomCrop(self.patch_size), data_keys=["input", "mask"] + ) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + + def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: + """Construct 'num_patches_per_tile' random patches of input tile. + + Args: + sample: contains image and mask tile from dataset + + Returns: + stacked randomly cropped patches from input tile + """ + images, masks = [], [] + for i in range(self.num_patches_per_tile): + image, mask = self.rcrop(sample["image"], sample["mask"].float()) + images.append(image.squeeze(0)) + masks.append(mask.squeeze().long()) + + sample["image"] = torch.stack(images) + sample["mask"] = torch.stack(masks) + return sample + + def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: + """Pad image and mask to next multiple of 32. + + Args: + sample: contains image and mask sample from dataset + + Returns: + padded image and mask + """ + h, w = sample["image"].shape[1], sample["image"].shape[2] + new_h = int(32 * ((h // 32) + 1)) + new_w = int(32 * ((w // 32) + 1)) + + padto = K.PadTo((new_h, new_w)) + + sample["image"] = padto(sample["image"])[0] + return sample + + train_transforms = Compose([self.preprocess, n_random_crop]) + # for testing and validation we pad all inputs to next larger multiple of 32 + # to avoid issues with upsampling paths in encoder-decoder architectures + test_transforms = Compose([self.preprocess, pad_to]) + + train_dataset = GID15(split="train", transforms=train_transforms, **self.kwargs) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + val_dataset = GID15( + split="train", transforms=test_transforms, **self.kwargs + ) + self.train_dataset, self.val_dataset, _ = dataset_split( + train_dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + self.val_dataset.dataset = val_dataset + else: + self.train_dataset = train_dataset + self.val_dataset = train_dataset + + self.test_dataset = GID15( + split="test", transforms=test_transforms, **self.kwargs + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=collate_patches_per_tile, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Dict[str, Any]]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + if self.val_split_pct > 0.0: + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + ) + else: + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_patches_per_tile, + ) + + def test_dataloader(self) -> DataLoader[Dict[str, Any]]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.GID15.plot`.""" + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 3270356f022..570c48ca692 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -30,7 +30,9 @@ import numpy as np import rasterio import torch +from einops import rearrange from torch import Tensor +from torch.utils.data._utils.collate import default_collate from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks @@ -216,6 +218,23 @@ def download_radiant_mlhub_collection( collection.download(output_dir=download_root, api_key=api_key) +def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Define collate function to combine patches per tile and batch size. + + Args: + batch: sample batch from dataloader containing image and mask + + Returns: + sample batch where the batch dimension is + 'train_batch_size' * 'num_patches_per_tile' + """ + r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] + print(r_batch["image"].shape) + r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") + r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w") + return r_batch + + @dataclass(frozen=True) class BoundingBox: """Data class for indexing spatiotemporal data.""" From ebfc50273e7f622430f50e225d1522843ed70874 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sat, 3 Dec 2022 17:22:13 +0100 Subject: [PATCH 2/6] remove print and fix batch_size --- torchgeo/datamodules/gid15.py | 5 +++-- torchgeo/datasets/utils.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index dc9ab5c1f64..c6218e9f0c3 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -63,9 +63,10 @@ def __init__( assert ( self.batch_size >= num_tiles_per_batch - ), "num_tiles_per_bacth should be less than or equal to batch_size." + ), "num_tiles_per_batch should be less than or equal to batch_size." self.num_patches_per_tile = self.batch_size // num_tiles_per_batch + self.num_tiles_per_batch = num_tiles_per_batch if (self.num_patches_per_tile % 2) != 0 and ( self.num_patches_per_tile != num_tiles_per_batch @@ -175,7 +176,7 @@ def train_dataloader(self) -> DataLoader[Any]: """ return DataLoader( self.train_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, collate_fn=collate_patches_per_tile, shuffle=True, diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 570c48ca692..ebefe3e7d59 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -229,7 +229,6 @@ def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]: 'train_batch_size' * 'num_patches_per_tile' """ r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] - print(r_batch["image"].shape) r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w") return r_batch From 97e6f28627b9d0fcc067c3d98ec93e8fa6c0865b Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sat, 3 Dec 2022 18:03:28 +0100 Subject: [PATCH 3/6] typo --- tests/datamodules/test_gid15.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/test_gid15.py b/tests/datamodules/test_gid15.py index bf3764af840..06494247bea 100644 --- a/tests/datamodules/test_gid15.py +++ b/tests/datamodules/test_gid15.py @@ -20,7 +20,7 @@ def datamodule(self, request: SubRequest) -> GID15DataModule: val_split_size = request.param dm = GID15DataModule( root=root, - train_batch_size=batch_size, + batch_size=batch_size, num_workers=num_workers, val_split_pct=val_split_size, num_tiles_per_batch=1, @@ -35,7 +35,7 @@ def test_batch_size_warning(self, datamodule: GID15DataModule) -> None: with pytest.warns(UserWarning, match=match): GID15DataModule( root=datamodule.test_dataset.root, - train_batch_size=3, + batch_size=3, num_tiles_per_batch=2, num_workers=datamodule.num_workers, val_split_pct=datamodule.val_split_pct, From 538f00d30f87d42b43917a54aef64bbc3906af70 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 10:55:00 -0600 Subject: [PATCH 4/6] Use Kornia augmentations --- conf/gid15.yaml | 21 ++ docs/api/datamodules.rst | 2 +- tests/conf/gid15.yaml | 22 ++ .../GF2_PMS1__L1A0000564539-MSS1_15label.png | Bin 0 -> 67 bytes .../GF2_PMS1__L1A0000575925-MSS1_15label.png | Bin 0 -> 67 bytes .../GF2_PMS1__L1A0001064454-MSS1_15label.png | Bin 0 -> 67 bytes .../GF2_PMS1__L1A0001118839-MSS1_15label.png | Bin 0 -> 67 bytes .../test/GF2_PMS1__L1A0000708367-MSS1.tif | Bin 0 -> 143 bytes .../test/GF2_PMS1__L1A0001344822-MSS1.tif | Bin 0 -> 143 bytes .../train/GF2_PMS1__L1A0000564539-MSS1.tif | Bin 0 -> 143 bytes .../train/GF2_PMS1__L1A0000575925-MSS1.tif | Bin 0 -> 143 bytes .../val/GF2_PMS1__L1A0001064454-MSS1.tif | Bin 0 -> 143 bytes .../val/GF2_PMS1__L1A0001118839-MSS1.tif | Bin 0 -> 143 bytes tests/datamodules/test_gid15.py | 57 ----- tests/trainers/test_segmentation.py | 2 + torchgeo/datamodules/gid15.py | 241 +++++++----------- torchgeo/datasets/gid15.py | 2 +- torchgeo/datasets/utils.py | 18 -- train.py | 2 + 19 files changed, 147 insertions(+), 220 deletions(-) create mode 100644 conf/gid15.yaml create mode 100644 tests/conf/gid15.yaml create mode 100644 tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000564539-MSS1_15label.png create mode 100644 tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000575925-MSS1_15label.png create mode 100644 tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001064454-MSS1_15label.png create mode 100644 tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001118839-MSS1_15label.png create mode 100644 tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0000708367-MSS1.tif create mode 100644 tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0001344822-MSS1.tif create mode 100644 tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000564539-MSS1.tif create mode 100644 tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000575925-MSS1.tif create mode 100644 tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001064454-MSS1.tif create mode 100644 tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001118839-MSS1.tif delete mode 100644 tests/datamodules/test_gid15.py diff --git a/conf/gid15.yaml b/conf/gid15.yaml new file mode 100644 index 00000000000..dd69143e574 --- /dev/null +++ b/conf/gid15.yaml @@ -0,0 +1,21 @@ +experiment: + task: "gid15" + module: + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 16 + num_filters: 1 + ignore_index: null + datamodule: + root: "data/gid15" + num_tiles_per_batch: 16 + num_patches_per_tile: 16 + patch_size: 64 + val_split_pct: 0.5 + num_workers: 0 diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 0bd7701a79a..4833ff815e4 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -52,7 +52,7 @@ FAIR1M GID-15 ^^^^^^ -.. autoclass:: GIDDataModule +.. autoclass:: GID15DataModule Inria Aerial Image Labeling ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml new file mode 100644 index 00000000000..56e25c7261a --- /dev/null +++ b/tests/conf/gid15.yaml @@ -0,0 +1,22 @@ +experiment: + task: "gid15" + module: + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 16 + num_filters: 1 + ignore_index: null + datamodule: + root: "tests/data/gid15" + download: true + num_tiles_per_batch: 1 + num_patches_per_tile: 1 + patch_size: 2 + val_split_pct: 0.5 + num_workers: 0 diff --git a/tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000564539-MSS1_15label.png b/tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000564539-MSS1_15label.png new file mode 100644 index 0000000000000000000000000000000000000000..aea7f5ff8ad3e3c7c084080bda8ac3b95e9a71cd GIT binary patch literal 67 zcmeAS@N?(olHy`uVBq!ia0vp^j3CSbBp9sfW`_bPE>9Q7kcv6U2|zXz1Ea_KC51p1 NgQu&X%Q~loCIDk>43q!> literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000575925-MSS1_15label.png b/tests/data/gid15/GID/ann_dir/train/GF2_PMS1__L1A0000575925-MSS1_15label.png new file mode 100644 index 0000000000000000000000000000000000000000..aea7f5ff8ad3e3c7c084080bda8ac3b95e9a71cd GIT binary patch literal 67 zcmeAS@N?(olHy`uVBq!ia0vp^j3CSbBp9sfW`_bPE>9Q7kcv6U2|zXz1Ea_KC51p1 NgQu&X%Q~loCIDk>43q!> literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001064454-MSS1_15label.png b/tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001064454-MSS1_15label.png new file mode 100644 index 0000000000000000000000000000000000000000..aea7f5ff8ad3e3c7c084080bda8ac3b95e9a71cd GIT binary patch literal 67 zcmeAS@N?(olHy`uVBq!ia0vp^j3CSbBp9sfW`_bPE>9Q7kcv6U2|zXz1Ea_KC51p1 NgQu&X%Q~loCIDk>43q!> literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001118839-MSS1_15label.png b/tests/data/gid15/GID/ann_dir/val/GF2_PMS1__L1A0001118839-MSS1_15label.png new file mode 100644 index 0000000000000000000000000000000000000000..aea7f5ff8ad3e3c7c084080bda8ac3b95e9a71cd GIT binary patch literal 67 zcmeAS@N?(olHy`uVBq!ia0vp^j3CSbBp9sfW`_bPE>9Q7kcv6U2|zXz1Ea_KC51p1 NgQu&X%Q~loCIDk>43q!> literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0000708367-MSS1.tif b/tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0000708367-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0001344822-MSS1.tif b/tests/data/gid15/GID/img_dir/test/GF2_PMS1__L1A0001344822-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000564539-MSS1.tif b/tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000564539-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000575925-MSS1.tif b/tests/data/gid15/GID/img_dir/train/GF2_PMS1__L1A0000575925-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001064454-MSS1.tif b/tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001064454-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001118839-MSS1.tif b/tests/data/gid15/GID/img_dir/val/GF2_PMS1__L1A0001118839-MSS1.tif new file mode 100644 index 0000000000000000000000000000000000000000..26fd8ad05f8e6a7443e523a599842ff612902ea3 GIT binary patch literal 143 zcmebD)MDUZU|`^4U|?inU<9(j7>Uiq$jrbD6mJ7!W*{4;h7HPQ0*VVl)%O6|qEK;= X9x)_y#G!1ExD1jxKrYY)AP58i_^$%G literal 0 HcmV?d00001 diff --git a/tests/datamodules/test_gid15.py b/tests/datamodules/test_gid15.py deleted file mode 100644 index 06494247bea..00000000000 --- a/tests/datamodules/test_gid15.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import matplotlib.pyplot as plt -import pytest -from _pytest.fixtures import SubRequest - -from torchgeo.datamodules import GID15DataModule -from torchgeo.datasets import unbind_samples - - -class TestGID15DataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> GID15DataModule: - root = os.path.join("tests", "data", "gid15") - batch_size = 2 - num_workers = 0 - val_split_size = request.param - dm = GID15DataModule( - root=root, - batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_size, - num_tiles_per_batch=1, - download=True, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_batch_size_warning(self, datamodule: GID15DataModule) -> None: - match = "The effective batch size will differ" - with pytest.warns(UserWarning, match=match): - GID15DataModule( - root=datamodule.test_dataset.root, - batch_size=3, - num_tiles_per_batch=2, - num_workers=datamodule.num_workers, - val_split_pct=datamodule.val_split_pct, - ) - - def test_train_dataloader(self, datamodule: GID15DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: GID15DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: GID15DataModule) -> None: - next(iter(datamodule.test_dataloader())) - - def test_plot(self, datamodule: GID15DataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) - sample = unbind_samples(batch)[0] - datamodule.plot(sample) - plt.close() diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d9ae82bfafc..35467c179d9 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -15,6 +15,7 @@ ChesapeakeCVPRDataModule, DeepGlobeLandCoverDataModule, ETCI2021DataModule, + GID15DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, @@ -41,6 +42,7 @@ class TestSemanticSegmentationTask: ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), ("deepglobelandcover", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), + ("gid15", GID15DataModule), ("inria_train", InriaAerialImageLabelingDataModule), ("inria_val", InriaAerialImageLabelingDataModule), ("inria_test", InriaAerialImageLabelingDataModule), diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index c6218e9f0c3..4256c631e10 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -3,172 +3,110 @@ """GID-15 datamodule.""" -import warnings from typing import Any, Dict, Optional, Tuple, Union -import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose - -from torchgeo.datasets.utils import collate_patches_per_tile -from torchgeo.samplers.utils import _to_tuple +from einops import rearrange +from kornia.augmentation import Normalize +from torch import Tensor +from torch.utils.data import DataLoader from ..datasets import GID15 +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop from .utils import dataset_split class GID15DataModule(pl.LightningDataModule): - """GID15 LightningDataModule implementation for the GID-15 dataset. + """LightningDataModule implementation for the GID-15 dataset. Uses the train/test splits from the dataset. + .. versionadded:: 0.4 """ def __init__( self, - batch_size: int = 32, - num_workers: int = 0, - patch_size: Union[Tuple[int, int], int] = (64, 64), num_tiles_per_batch: int = 16, + num_patches_per_tile: int = 16, + patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, + num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for GID-15 based DataLoaders. + """Initialize a new LightningDataModule instance. + + The GID-15 dataset contains images that are too large to pass + directly through a model. Instead, we randomly sample patches from image tiles + during training and chop up image tiles into patch grids during evaluation. + During training, the effective batch size is equal to + ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1). - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - patch_size: Size of random patch from image and mask (height, width), should - be a multiple of 32 for most segmentation architectures - num_tiles_per_batch: number of random tiles to consider sampling patches - from per sample, should evenly divide batch_size and be less than - or equal to batch_size + num_tiles_per_batch: The number of image tiles to sample from during + training + num_patches_per_tile: The number of patches to randomly sample from each + image tile during training + patch_size: The size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures + val_split_pct: The percentage of the dataset to use as a validation set + num_workers: The number of workers to use for parallel data loading **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.GID15` - - .. versionadded:: 0.4 """ super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + self.num_tiles_per_batch = num_tiles_per_batch + self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct + self.num_workers = num_workers self.kwargs = kwargs - assert ( - self.batch_size >= num_tiles_per_batch - ), "num_tiles_per_batch should be less than or equal to batch_size." - - self.num_patches_per_tile = self.batch_size // num_tiles_per_batch - self.num_tiles_per_batch = num_tiles_per_batch - - if (self.num_patches_per_tile % 2) != 0 and ( - self.num_patches_per_tile != num_tiles_per_batch - ): - warnings.warn( - "The effective batch size" - f" will differ from the specified {batch_size}" - f" and be {self.num_patches_per_tile * num_tiles_per_batch} instead." - " To match the batch_size exactly, ensure that" - " num_tiles_per_batch evenly divides batch_size" - ) - - self.rcrop = K.AugmentationSequential( - K.RandomCrop(self.patch_size), data_keys=["input", "mask"] + self.train_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), + data_keys=["image", "mask"], + ) + self.val_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _ExtractTensorPatches(self.patch_size), + data_keys=["image", "mask"], + ) + self.predict_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _ExtractTensorPatches(self.patch_size), + data_keys=["image"], ) - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary + def prepare_data(self) -> None: + """Initialize the main Dataset objects for use in :func:`setup`. - Returns: - preprocessed sample + This includes optionally downloading the dataset. This is done once per node, + while :func:`setup` is done once per GPU. """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + if self.kwargs.get("download", False): + GID15(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. Args: stage: stage to set up """ - - def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: - """Construct 'num_patches_per_tile' random patches of input tile. - - Args: - sample: contains image and mask tile from dataset - - Returns: - stacked randomly cropped patches from input tile - """ - images, masks = [], [] - for i in range(self.num_patches_per_tile): - image, mask = self.rcrop(sample["image"], sample["mask"].float()) - images.append(image.squeeze(0)) - masks.append(mask.squeeze().long()) - - sample["image"] = torch.stack(images) - sample["mask"] = torch.stack(masks) - return sample - - def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: - """Pad image and mask to next multiple of 32. - - Args: - sample: contains image and mask sample from dataset - - Returns: - padded image and mask - """ - h, w = sample["image"].shape[1], sample["image"].shape[2] - new_h = int(32 * ((h // 32) + 1)) - new_w = int(32 * ((w // 32) + 1)) - - padto = K.PadTo((new_h, new_w)) - - sample["image"] = padto(sample["image"])[0] - return sample - - train_transforms = Compose([self.preprocess, n_random_crop]) - # for testing and validation we pad all inputs to next larger multiple of 32 - # to avoid issues with upsampling paths in encoder-decoder architectures - test_transforms = Compose([self.preprocess, pad_to]) - - train_dataset = GID15(split="train", transforms=train_transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - val_dataset = GID15( - split="train", transforms=test_transforms, **self.kwargs - ) - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - self.val_dataset.dataset = val_dataset - else: - self.train_dataset = train_dataset - self.val_dataset = train_dataset - - self.test_dataset = GID15( - split="test", transforms=test_transforms, **self.kwargs + train_dataset = GID15(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + train_dataset, self.val_split_pct ) - def train_dataloader(self) -> DataLoader[Any]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = GID15(split="test", **self.kwargs) + + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -178,42 +116,59 @@ def train_dataloader(self) -> DataLoader[Any]: self.train_dataset, batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, - collate_fn=collate_patches_per_tile, shuffle=True, ) - def val_dataloader(self) -> DataLoader[Dict[str, Any]]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. Returns: validation data loader """ - if self.val_split_pct > 0.0: - return DataLoader( - self.val_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - else: - return DataLoader( - self.val_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_patches_per_tile, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Any]]: - """Return a DataLoader for testing. + return DataLoader( + self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) + + def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Return a DataLoader for predicting. Returns: - testing data loader + predicting data loader """ return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + self.predict_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + # Kornia requires masks to have a channel dimension + if "mask" in batch: + batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") + + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating: + batch = self.val_transform(batch) + elif self.trainer.predicting: + batch = self.predict_transform(batch) + + # Torchmetrics does not support masks with a channel dimension + if "mask" in batch: + batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.GID15.plot`.""" - return self.test_dataset.plot(*args, **kwargs) + return self.predict_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 44df74ddcc4..ec295d484f3 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -194,7 +194,7 @@ def _load_image(self, path: str) -> Tensor: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW - tensor = tensor.permute((2, 0, 1)) + tensor = tensor.permute((2, 0, 1)).float() return tensor def _load_target(self, path: str) -> Tensor: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index ebefe3e7d59..3270356f022 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -30,9 +30,7 @@ import numpy as np import rasterio import torch -from einops import rearrange from torch import Tensor -from torch.utils.data._utils.collate import default_collate from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks @@ -218,22 +216,6 @@ def download_radiant_mlhub_collection( collection.download(output_dir=download_root, api_key=api_key) -def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - """Define collate function to combine patches per tile and batch size. - - Args: - batch: sample batch from dataloader containing image and mask - - Returns: - sample batch where the batch dimension is - 'train_batch_size' * 'num_patches_per_tile' - """ - r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] - r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") - r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w") - return r_batch - - @dataclass(frozen=True) class BoundingBox: """Data class for indexing spatiotemporal data.""" diff --git a/train.py b/train.py index f10cc172b9e..01dcc2f80ee 100755 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ DeepGlobeLandCoverDataModule, ETCI2021DataModule, EuroSATDataModule, + GID15DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, @@ -54,6 +55,7 @@ "deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule), "eurosat": (ClassificationTask, EuroSATDataModule), "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), + "gid15": (SemanticSegmentationTask, GID15DataModule), "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule), "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule), "loveda": (SemanticSegmentationTask, LoveDADataModule), From a08305a8d03d890f80c428cbd7094994f9af138f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 10:58:50 -0600 Subject: [PATCH 5/6] Style --- torchgeo/datamodules/gid15.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 4256c631e10..045509809e2 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -136,7 +136,10 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: predicting data loader """ return DataLoader( - self.predict_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + self.predict_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, ) def on_after_batch_transfer( From e54098744ebc158774fe1e719146d1732fa76cff Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 11:16:50 -0600 Subject: [PATCH 6/6] Ignore warning --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f13fa8e5ca2..7c95b30c1ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,8 @@ filterwarnings = [ # https://github.com/lanpa/tensorboardX/issues/653 # https://github.com/lanpa/tensorboardX/pull/654 "ignore:Call to deprecated create function:DeprecationWarning:tensorboardX", + # https://github.com/kornia/kornia/issues/777 + "ignore:Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0:UserWarning:torch.nn.functional", # Expected warnings # pytorch-lightning warns us about using num_workers=0, but it's faster on macOS