diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py new file mode 100644 index 00000000000..8e67152346a --- /dev/null +++ b/tests/datamodules/test_levircd.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import pytest +from lightning.pytorch import Trainer +from pytest import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datamodules import LEVIRCDPlusDataModule +from torchgeo.datasets import LEVIRCDPlus + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestLEVIRCDPlusDataModule: + @pytest.fixture + def datamodule( + self, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> LEVIRCDPlusDataModule: + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + md5 = "1adf156f628aa32fb2e8fe6cada16c04" + monkeypatch.setattr(LEVIRCDPlus, "md5", md5) + url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") + monkeypatch.setattr(LEVIRCDPlus, "url", url) + + root = str(tmp_path) + dm = LEVIRCDPlusDataModule( + root=root, download=True, num_workers=0, checksum=True, val_split_pct=0.5 + ) + dm.prepare_data() + dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + return dm + + def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("fit") + if datamodule.trainer: + datamodule.trainer.training = True + batch = next(iter(datamodule.train_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("validate") + if datamodule.trainer: + datamodule.trainer.validating = True + batch = next(iter(datamodule.val_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + if datamodule.val_split_pct > 0.0: + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("test") + if datamodule.trainer: + datamodule.trainer.testing = True + batch = next(iter(datamodule.test_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 8761e850e18..66555c7b978 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -18,6 +18,7 @@ from .l7irish import L7IrishDataModule from .l8biome import L8BiomeDataModule from .landcoverai import LandCoverAIDataModule +from .levircd import LEVIRCDPlusDataModule from .loveda import LoveDADataModule from .naip import NAIPChesapeakeDataModule from .nasa_marine_debris import NASAMarineDebrisDataModule @@ -56,6 +57,7 @@ "GID15DataModule", "InriaAerialImageLabelingDataModule", "LandCoverAIDataModule", + "LEVIRCDPlusDataModule", "LoveDADataModule", "NASAMarineDebrisDataModule", "OSCDDataModule", diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py new file mode 100644 index 00000000000..b021d8c860b --- /dev/null +++ b/torchgeo/datamodules/levircd.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LEVIR-CD+ datamodule.""" + +from typing import Any, Union + +import kornia.augmentation as K + +from torchgeo.datamodules.utils import dataset_split +from torchgeo.samplers.utils import _to_tuple + +from ..datasets import LEVIRCDPlus +from ..transforms import AugmentationSequential +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule + + +class LEVIRCDPlusDataModule(NonGeoDataModule): + """LightningDataModule implementation for the LEVIR-CD+ dataset. + + Uses the train/test splits from the dataset and further splits + the train split into train/val splits. + + .. versionadded:: 0.6 + """ + + def __init__( + self, + batch_size: int = 8, + patch_size: Union[tuple[int, int], int] = 256, + val_split_pct: float = 0.2, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new LEVIRCDPlusDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.LEVIRCDPlus`. + """ + super().__init__(LEVIRCDPlus, 1, num_workers, **kwargs) + + self.patch_size = _to_tuple(patch_size) + self.val_split_pct = val_split_pct + + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + data_keys=["image1", "image2", "mask"], + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit", "validate"]: + self.dataset = LEVIRCDPlus(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct + ) + if stage in ["test"]: + self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index f1e7bd390e7..6a87ec27105 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -156,7 +156,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -225,20 +225,34 @@ def plot( .. versionadded:: 0.2 """ - image1, image2, mask = (sample["image1"], sample["image2"], sample["mask"]) ncols = 3 + def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": + rgb_img = img.permute(1, 2, 0).float().numpy() + per02 = np.percentile(rgb_img, 2) + per98 = np.percentile(rgb_img, 98) + delta = per98 - per02 + epsilon = 1e-7 + norm_img: "np.typing.NDArray[np.uint8]" = ( + np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 + ).astype(np.uint8) + return norm_img + + image1 = get_rgb(sample["image1"]) + image2 = get_rgb(sample["image2"]) + mask = sample["mask"].numpy() + if "prediction" in sample: prediction = sample["prediction"] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) - axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].imshow(image1) axs[0].axis("off") - axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].imshow(image2) axs[1].axis("off") - axs[2].imshow(mask) + axs[2].imshow(mask, cmap="gray") axs[2].axis("off") if "prediction" in sample: