From 872066b8da4daf35740810ef7299e65f915b463e Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 12:46:36 +0000 Subject: [PATCH 01/35] add file --- torchgeo/datamodules/levircd.py | 97 +++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 torchgeo/datamodules/levircd.py diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py new file mode 100644 index 00000000000..edca6e53fc9 --- /dev/null +++ b/torchgeo/datamodules/levircd.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Levircd datamodule.""" + +from ..datasets import LEVIRCDPlus +from torchgeo.samplers.utils import _to_tuple +from torchgeo.datamodules.utils import dataset_split +from torchvision.transforms import Compose +import kornia.augmentation as K + +from .geo import NonGeoDataModule + +class LEVIRCDPlusDataModule(NonGeoDataModule): + def __init__( + self, + batch_size=8, + num_workers=0, + patch_size=256, + val_split_pct=0.2, + **kwargs + ): + super().__init__() + self.kwargs = kwargs + self.batch_size = batch_size + self.num_workers = num_workers + self.patch_size = _to_tuple(patch_size) + self.val_split_pct = val_split_pct + self.mean = torch.tensor([0.485, 0.456, 0.406]) + self.std = torch.tensor([0.229, 0.224, 0.225]) + + def preprocess(self, sample): + sample["image1"] = (sample["image1"] / 255.0).float() + sample["image2"] = (sample["image2"] / 255.0).float() + # Kornia adds batch dimension which we need to remove + sample["image1"] = K.Normalize(mean=self.mean, std=self.std)(sample["image1"]).squeeze(0) + sample["image2"] = K.Normalize(mean=self.mean, std=self.std)(sample["image2"]).squeeze(0) + sample["mask"] = sample["mask"].long() + return sample + + def train_augmentations(self, batch): + augmentations = AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomCrop(self.patch_size), + K.RandomSharpness(p=0.5), + data_keys=["image1", "image2", "mask"], + ) + return augmentations(batch) + + def on_after_batch_transfer(self, batch, batch_idx): + if self.trainer and self.trainer.training: + batch["mask"] = batch["mask"].float().unsqueeze(1) + batch = self.train_augmentations(batch) + batch["mask"] = batch["mask"].squeeze(1).long() + return batch + + def prepare_data(self): + LEVIRCDPlus(split="train", **self.kwargs) + LEVIRCDPlus(split="test", **self.kwargs) + + def setup(self, stage=None): + train_transforms = Compose([self.preprocess]) + test_transforms = Compose([self.preprocess]) + + train_dataset = LEVIRCDPlus(split="train", transforms=train_transforms, **self.kwargs) + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + train_dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = train_dataset + self.val_dataset = train_dataset + + self.test_dataset = LEVIRCDPlus( + split="test", transforms=test_transforms, **self.kwargs + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False + ) \ No newline at end of file From b7d8d54fe1a669f034320d1afa390a182fde812b Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 12:51:49 +0000 Subject: [PATCH 02/35] Add to init --- torchgeo/datamodules/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 8761e850e18..9acdc922429 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -18,6 +18,7 @@ from .l7irish import L7IrishDataModule from .l8biome import L8BiomeDataModule from .landcoverai import LandCoverAIDataModule +from .levircd import LEVIRCDPlusDataModule from .loveda import LoveDADataModule from .naip import NAIPChesapeakeDataModule from .nasa_marine_debris import NASAMarineDebrisDataModule From dc7cec919bee36adbf93c767e7af1cb5a3281f9f Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 13:42:43 +0000 Subject: [PATCH 03/35] refactor --- torchgeo/datamodules/levircd.py | 117 ++++++++++++-------------------- 1 file changed, 44 insertions(+), 73 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index edca6e53fc9..c34f9eef94c 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -6,92 +6,63 @@ from ..datasets import LEVIRCDPlus from torchgeo.samplers.utils import _to_tuple from torchgeo.datamodules.utils import dataset_split -from torchvision.transforms import Compose import kornia.augmentation as K +import torch from .geo import NonGeoDataModule class LEVIRCDPlusDataModule(NonGeoDataModule): + """LightningDataModule implementation for the LEVIR-CD+ dataset. + + Uses the train/test splits from the dataset, with val split + generated from the train split + + """ def __init__( self, - batch_size=8, - num_workers=0, - patch_size=256, - val_split_pct=0.2, + batch_size: int = 8, + patch_size: Union[tuple[int, int], int] = 256, + val_split_pct: float = 0.2, + num_workers: int = 0, **kwargs - ): - super().__init__() - self.kwargs = kwargs - self.batch_size = batch_size - self.num_workers = num_workers + ) -> None: + """Initialize a new LEVIRCDPlusDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.LEVIRCDPlus`. + """ + super().__init__(LEVIRCDPlusDataModule, batch_size, num_workers, **kwargs) + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.mean = torch.tensor([0.485, 0.456, 0.406]) self.std = torch.tensor([0.229, 0.224, 0.225]) - - def preprocess(self, sample): - sample["image1"] = (sample["image1"] / 255.0).float() - sample["image2"] = (sample["image2"] / 255.0).float() - # Kornia adds batch dimension which we need to remove - sample["image1"] = K.Normalize(mean=self.mean, std=self.std)(sample["image1"]).squeeze(0) - sample["image2"] = K.Normalize(mean=self.mean, std=self.std)(sample["image2"]).squeeze(0) - sample["mask"] = sample["mask"].long() - return sample - def train_augmentations(self, batch): - augmentations = AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomCrop(self.patch_size), - K.RandomSharpness(p=0.5), + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image1", "image2", "mask"], ) - return augmentations(batch) - - def on_after_batch_transfer(self, batch, batch_idx): - if self.trainer and self.trainer.training: - batch["mask"] = batch["mask"].float().unsqueeze(1) - batch = self.train_augmentations(batch) - batch["mask"] = batch["mask"].squeeze(1).long() - return batch - - def prepare_data(self): - LEVIRCDPlus(split="train", **self.kwargs) - LEVIRCDPlus(split="test", **self.kwargs) - - def setup(self, stage=None): - train_transforms = Compose([self.preprocess]) - test_transforms = Compose([self.preprocess]) - - train_dataset = LEVIRCDPlus(split="train", transforms=train_transforms, **self.kwargs) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = train_dataset - self.val_dataset = train_dataset - - self.test_dataset = LEVIRCDPlus( - split="test", transforms=test_transforms, **self.kwargs - ) - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False - ) + + def setup(self, stage: str) -> None: + """Set up datasets. - def test_dataloader(self): - return DataLoader( - self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False - ) \ No newline at end of file + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit", "validate"]: + self.train_dataset = LEVIRCDPlus(split="train", **self.kwargs) + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + self.train_dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + if stage in ["test"]: + self.test_dataset = LEVIRCDPlus( + split="test", **self.kwargs + ) \ No newline at end of file From 068e455cb5e009785bd9395fccb6ce5a0438aed0 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 13:50:05 +0000 Subject: [PATCH 04/35] format --- torchgeo/datamodules/levircd.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index c34f9eef94c..2f07e40ef9e 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -1,30 +1,36 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Levircd datamodule.""" +"""LEVIR-CD+ datamodule.""" + +from typing import Union -from ..datasets import LEVIRCDPlus -from torchgeo.samplers.utils import _to_tuple -from torchgeo.datamodules.utils import dataset_split import kornia.augmentation as K import torch +from torchgeo.datamodules.utils import dataset_split +from torchgeo.samplers.utils import _to_tuple + +from ..datasets import LEVIRCDPlus +from ..transforms import AugmentationSequential from .geo import NonGeoDataModule + class LEVIRCDPlusDataModule(NonGeoDataModule): """LightningDataModule implementation for the LEVIR-CD+ dataset. - Uses the train/test splits from the dataset, with val split + Uses the train/test splits from the dataset, with val split generated from the train split """ + def __init__( - self, - batch_size: int = 8, - patch_size: Union[tuple[int, int], int] = 256, - val_split_pct: float = 0.2, - num_workers: int = 0, - **kwargs + self, + batch_size: int = 8, + patch_size: Union[tuple[int, int], int] = 256, + val_split_pct: float = 0.2, + num_workers: int = 0, + **kwargs, ) -> None: """Initialize a new LEVIRCDPlusDataModule instance. @@ -49,7 +55,7 @@ def __init__( _RandomNCrop(self.patch_size, batch_size), data_keys=["image1", "image2", "mask"], ) - + def setup(self, stage: str) -> None: """Set up datasets. @@ -63,6 +69,4 @@ def setup(self, stage: str) -> None: self.train_dataset, val_pct=self.val_split_pct, test_pct=0.0 ) if stage in ["test"]: - self.test_dataset = LEVIRCDPlus( - split="test", **self.kwargs - ) \ No newline at end of file + self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) From 66e2a77b885b4cfa5cd56235f95e1eb9db4ca6b7 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 13:54:13 +0000 Subject: [PATCH 05/35] isort --- torchgeo/datamodules/levircd.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 2f07e40ef9e..ff68869e663 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -4,15 +4,14 @@ """LEVIR-CD+ datamodule.""" from typing import Union - +from ..datasets import LEVIRCDPlus +from ..transforms import AugmentationSequential +from ..transforms.transforms import _RandomNCrop +from torchgeo.samplers.utils import _to_tuple +from torchgeo.datamodules.utils import dataset_split import kornia.augmentation as K import torch -from torchgeo.datamodules.utils import dataset_split -from torchgeo.samplers.utils import _to_tuple - -from ..datasets import LEVIRCDPlus -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule From 9e1a139823fc725f949e5df195ba7f46812e8b87 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 13:58:31 +0000 Subject: [PATCH 06/35] match to oscd --- torchgeo/datamodules/levircd.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index ff68869e663..116b1432f1c 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -4,22 +4,24 @@ """LEVIR-CD+ datamodule.""" from typing import Union -from ..datasets import LEVIRCDPlus -from ..transforms import AugmentationSequential -from ..transforms.transforms import _RandomNCrop -from torchgeo.samplers.utils import _to_tuple -from torchgeo.datamodules.utils import dataset_split + import kornia.augmentation as K import torch +from torchgeo.datamodules.utils import dataset_split +from torchgeo.samplers.utils import _to_tuple + +from ..datasets import LEVIRCDPlus +from ..transforms import AugmentationSequential +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule class LEVIRCDPlusDataModule(NonGeoDataModule): """LightningDataModule implementation for the LEVIR-CD+ dataset. - Uses the train/test splits from the dataset, with val split - generated from the train split + Uses the train/test splits from the dataset and further splits + the train split into train/val splits. """ @@ -42,10 +44,11 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.LEVIRCDPlus`. """ - super().__init__(LEVIRCDPlusDataModule, batch_size, num_workers, **kwargs) + super().__init__(LEVIRCDPlusDataModule, 1, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct + self.mean = torch.tensor([0.485, 0.456, 0.406]) self.std = torch.tensor([0.229, 0.224, 0.225]) @@ -62,10 +65,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - self.train_dataset = LEVIRCDPlus(split="train", **self.kwargs) - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - self.train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) + self.dataset = LEVIRCDPlus(split="train", **self.kwargs) + self.train_dataset, self.val_dataset, _ = dataset_split( + self.dataset, val_pct=self.val_split_pct + ) if stage in ["test"]: self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) From 354c28a77a2c6c8d8f438c358cf814665d56584b Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:21:00 +0000 Subject: [PATCH 07/35] Add test --- tests/datamodules/test_levircd.py | 61 +++++++++++++++++++++++++++++++ torchgeo/datamodules/levircd.py | 2 +- 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tests/datamodules/test_levircd.py diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py new file mode 100644 index 00000000000..cea23993dc6 --- /dev/null +++ b/tests/datamodules/test_levircd.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest +from lightning.pytorch import Trainer + +from torchgeo.datamodules import LEVIRCDPlusDataModule +from torchgeo.datasets import LEVIRCDPlus + + +class TestLEVIRCDPlusDataModule: + def datamodule(self, request: SubRequest) -> OSCDDataModule: + bands = request.param + root = os.path.join("tests", "data", "LEVIR-CD+") + dm = LEVIRCDPlusDataModule(root=root, download=True, num_workers=0) + dm.prepare_data() + dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + return dm + + def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("fit") + if datamodule.trainer: + datamodule.trainer.training = True + batch = next(iter(datamodule.train_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("validate") + if datamodule.trainer: + datamodule.trainer.validating = True + batch = next(iter(datamodule.val_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + if datamodule.val_split_pct > 0.0: + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 + + def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: + datamodule.setup("test") + if datamodule.trainer: + datamodule.trainer.testing = True + batch = next(iter(datamodule.test_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[1] == 3 + assert batch["image2"].shape[1] == 3 diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 116b1432f1c..6b4a129e433 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -67,7 +67,7 @@ def setup(self, stage: str) -> None: if stage in ["fit", "validate"]: self.dataset = LEVIRCDPlus(split="train", **self.kwargs) self.train_dataset, self.val_dataset, _ = dataset_split( - self.dataset, val_pct=self.val_split_pct + self.dataset, val_pct=self.val_split_pct, test_pct=0 ) if stage in ["test"]: self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) From f07780d968ab118d06257f2d8b83c57b4397cca3 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:24:34 +0000 Subject: [PATCH 08/35] remove mean and std --- torchgeo/datamodules/levircd.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 6b4a129e433..35fb3f06da0 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -49,9 +49,6 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.mean = torch.tensor([0.485, 0.456, 0.406]) - self.std = torch.tensor([0.229, 0.224, 0.225]) - self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), From c3aa4629fe6a7cd5a147b89df0eb85d1df975891 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:29:25 +0000 Subject: [PATCH 09/35] update docstring with versionadded --- torchgeo/datamodules/levircd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 35fb3f06da0..8686af2b1b8 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -23,6 +23,7 @@ class LEVIRCDPlusDataModule(NonGeoDataModule): Uses the train/test splits from the dataset and further splits the train split into train/val splits. + .. versionadded:: 0.6 """ def __init__( From ba7f04bdb86b297e9dd737e7ac15634d2c10ff4b Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:32:02 +0000 Subject: [PATCH 10/35] address test issues --- tests/datamodules/test_levircd.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index cea23993dc6..4ca5b4c845c 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -3,17 +3,14 @@ import os -import pytest from _pytest.fixtures import SubRequest from lightning.pytorch import Trainer from torchgeo.datamodules import LEVIRCDPlusDataModule -from torchgeo.datasets import LEVIRCDPlus class TestLEVIRCDPlusDataModule: - def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands = request.param + def datamodule(self, request: SubRequest) -> LEVIRCDPlusDataModule: root = os.path.join("tests", "data", "LEVIR-CD+") dm = LEVIRCDPlusDataModule(root=root, download=True, num_workers=0) dm.prepare_data() From b24828f0cfe69392203f24ef273f79f7845bf60d Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:34:27 +0000 Subject: [PATCH 11/35] fix init --- torchgeo/datamodules/__init__.py | 1 + torchgeo/datamodules/levircd.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 9acdc922429..66555c7b978 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -57,6 +57,7 @@ "GID15DataModule", "InriaAerialImageLabelingDataModule", "LandCoverAIDataModule", + "LEVIRCDPlusDataModule", "LoveDADataModule", "NASAMarineDebrisDataModule", "OSCDDataModule", diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 8686af2b1b8..70b4f982e24 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -6,7 +6,6 @@ from typing import Union import kornia.augmentation as K -import torch from torchgeo.datamodules.utils import dataset_split from torchgeo.samplers.utils import _to_tuple From 3b9b3aa2da418f5183856da2a361339518d518a4 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:41:51 +0000 Subject: [PATCH 12/35] fix init dataset --- torchgeo/datamodules/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 70b4f982e24..20822e2fe3d 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -44,7 +44,7 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.LEVIRCDPlus`. """ - super().__init__(LEVIRCDPlusDataModule, 1, num_workers, **kwargs) + super().__init__(LEVIRCDPlus, 1, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct From 269be37a82cbf9474fb05ac716de8e6bb7ef1ace Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 14:58:32 +0000 Subject: [PATCH 13/35] fix type hint --- torchgeo/datamodules/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 20822e2fe3d..443fed7b991 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -31,7 +31,7 @@ def __init__( patch_size: Union[tuple[int, int], int] = 256, val_split_pct: float = 0.2, num_workers: int = 0, - **kwargs, + **kwargs: Any, ) -> None: """Initialize a new LEVIRCDPlusDataModule instance. From 31488a4dd28eabe0ba209db29db0a3a4cdf488bd Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 15:03:06 +0000 Subject: [PATCH 14/35] import --- torchgeo/datamodules/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 443fed7b991..4b87fb9d0e2 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -3,7 +3,7 @@ """LEVIR-CD+ datamodule.""" -from typing import Union +from typing import Union, Any import kornia.augmentation as K From 90d6cf3d2080cea350935493b88923c0543564e1 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 15:36:40 +0000 Subject: [PATCH 15/35] add fixture --- tests/datamodules/test_levircd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 4ca5b4c845c..51b2c03bd0b 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -10,6 +10,7 @@ class TestLEVIRCDPlusDataModule: + @pytest.fixture def datamodule(self, request: SubRequest) -> LEVIRCDPlusDataModule: root = os.path.join("tests", "data", "LEVIR-CD+") dm = LEVIRCDPlusDataModule(root=root, download=True, num_workers=0) From d35c4a2ce3f9cb03512a2bef58ec4cff0ef4aaf8 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 15:37:40 +0000 Subject: [PATCH 16/35] import pytest --- tests/datamodules/test_levircd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 51b2c03bd0b..a7b9bc1a096 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -3,6 +3,7 @@ import os +import pytest from _pytest.fixtures import SubRequest from lightning.pytorch import Trainer From ce90650a24eabcc7595fe6e2c62905a854281751 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 31 Oct 2023 15:42:56 +0000 Subject: [PATCH 17/35] make image float --- torchgeo/datasets/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index f1e7bd390e7..9fc1cd543b0 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -156,7 +156,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor From 5f9689da745d2aec7f34d3cc2d6ed181d5d3cffc Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 08:52:12 +0000 Subject: [PATCH 18/35] fix plotting --- torchgeo/datasets/levircd.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 9fc1cd543b0..f92c0981b6f 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import download_and_extract_archive, draw_semantic_segmentation_masks class LEVIRCDPlus(NonGeoDataset): @@ -225,20 +225,33 @@ def plot( .. versionadded:: 0.2 """ - image1, image2, mask = (sample["image1"], sample["image2"], sample["mask"]) ncols = 3 + def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": + img = img.permute(1, 2, 0) + rgb_img = img.float().numpy() + per02 = np.percentile(rgb_img, 2) + per98 = np.percentile(rgb_img, 98) + rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype( + np.uint8 + ) + return rgb_img + + image1 = get_rgb(sample["image1"]) + image2 = get_rgb(sample["image2"]) + mask = sample["mask"].numpy() + if "prediction" in sample: prediction = sample["prediction"] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) - axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].imshow(image1) axs[0].axis("off") - axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].imshow(image2) axs[1].axis("off") - axs[2].imshow(mask) + axs[2].imshow(mask, cmap="gray") axs[2].axis("off") if "prediction" in sample: From 876c99c9fb6a258a36af57f462ede69b37990923 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:02:26 +0000 Subject: [PATCH 19/35] isort --- torchgeo/datamodules/levircd.py | 2 +- torchgeo/datasets/levircd.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 4b87fb9d0e2..9675e733772 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -3,7 +3,7 @@ """LEVIR-CD+ datamodule.""" -from typing import Union, Any +from typing import Any, Union import kornia.augmentation as K diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index f92c0981b6f..468d8eba5d1 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import download_and_extract_archive, draw_semantic_segmentation_masks +from .utils import download_and_extract_archive class LEVIRCDPlus(NonGeoDataset): From 02234881c49887c6a493fbf1ec09f7fde31fdb5d Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:18:33 +0000 Subject: [PATCH 20/35] mock download --- tests/datamodules/test_levircd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index a7b9bc1a096..296cb7d523a 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -6,13 +6,18 @@ import pytest from _pytest.fixtures import SubRequest from lightning.pytorch import Trainer +from pytest import MonkeyPatch from torchgeo.datamodules import LEVIRCDPlusDataModule class TestLEVIRCDPlusDataModule: @pytest.fixture - def datamodule(self, request: SubRequest) -> LEVIRCDPlusDataModule: + def datamodule( + self, monkeypatch: MonkeyPatch, request: SubRequest + ) -> LEVIRCDPlusDataModule: + monkeypatch.setattr(LEVIRCDPlusDataModule, "download", Mock(return_value=True)) + root = os.path.join("tests", "data", "LEVIR-CD+") dm = LEVIRCDPlusDataModule(root=root, download=True, num_workers=0) dm.prepare_data() From 6453f8b16884d03a14ae304bf0549ff7882f9415 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:20:32 +0000 Subject: [PATCH 21/35] fix import --- tests/datamodules/test_levircd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 296cb7d523a..b7ec09f85da 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os +from unittest.mock import Mock import pytest from _pytest.fixtures import SubRequest From 2cac4e6d88ec4ddbc490b98fa468836d2f82bdef Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:30:12 +0000 Subject: [PATCH 22/35] satisfy mypy --- torchgeo/datasets/levircd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 468d8eba5d1..55feae5a2fc 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -235,7 +235,7 @@ def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype( np.uint8 ) - return rgb_img + return cast("np.typing.NDArray[np.uint8]", rgb_img) image1 = get_rgb(sample["image1"]) image2 = get_rgb(sample["image2"]) From 9271f62fbad7ce4d1ac0ff17be7f0c9c8ea5197c Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:55:51 +0000 Subject: [PATCH 23/35] Fix fixture for TestLEVIRCDPlusDataModule --- tests/datamodules/test_levircd.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index b7ec09f85da..fa5ce742354 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -2,25 +2,35 @@ # Licensed under the MIT License. import os -from unittest.mock import Mock +import shutil +from pathlib import Path import pytest -from _pytest.fixtures import SubRequest from lightning.pytorch import Trainer from pytest import MonkeyPatch from torchgeo.datamodules import LEVIRCDPlusDataModule +from torchgeo.datasets import LEVIRCDPlus + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) class TestLEVIRCDPlusDataModule: - @pytest.fixture + @pytest.fixture(params=["train", "validate", "test"]) def datamodule( - self, monkeypatch: MonkeyPatch, request: SubRequest + self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> LEVIRCDPlusDataModule: - monkeypatch.setattr(LEVIRCDPlusDataModule, "download", Mock(return_value=True)) + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + monkeypatch.setattr(LEVIRCDPlus, "md5", md5) + url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") + monkeypatch.setattr(LEVIRCDPlus, "url", url) - root = os.path.join("tests", "data", "LEVIR-CD+") - dm = LEVIRCDPlusDataModule(root=root, download=True, num_workers=0) + root = str(tmp_path) + dm = LEVIRCDPlusDataModule( + root=root, download=True, num_workers=0, checksum=True + ) dm.prepare_data() dm.trainer = Trainer(accelerator="cpu", max_epochs=1) return dm From 71cb8c46f27a0293e8249985795da9a06aad36f9 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 09:58:02 +0000 Subject: [PATCH 24/35] fix imports --- tests/datamodules/test_levircd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index fa5ce742354..7ea47c52e0c 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -9,6 +9,7 @@ from lightning.pytorch import Trainer from pytest import MonkeyPatch +import torchgeo.datasets.utils from torchgeo.datamodules import LEVIRCDPlusDataModule from torchgeo.datasets import LEVIRCDPlus @@ -23,6 +24,7 @@ def datamodule( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> LEVIRCDPlusDataModule: monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + md5 = "1adf156f628aa32fb2e8fe6cada16c04" monkeypatch.setattr(LEVIRCDPlus, "md5", md5) url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") monkeypatch.setattr(LEVIRCDPlus, "url", url) From 8f18774b239d98958fafd013458af7735006f751 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 10:10:37 +0000 Subject: [PATCH 25/35] Fix test values --- tests/datamodules/test_levircd.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 7ea47c52e0c..76bf3fc6aec 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -43,10 +43,10 @@ def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 4 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 @@ -57,10 +57,10 @@ def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 4 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 @@ -70,9 +70,9 @@ def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 From 9604580d15bf7ada60068a3e9a43e79f74988e8d Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 12:33:49 +0000 Subject: [PATCH 26/35] fix test values --- tests/datamodules/test_levircd.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 76bf3fc6aec..832e5b1e8bf 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -44,9 +44,9 @@ def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 @@ -58,9 +58,9 @@ def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 4 + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 @@ -71,8 +71,8 @@ def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 assert batch["image1"].shape[1] == 3 assert batch["image2"].shape[1] == 3 From e2ea210532c1d0ba77eb3fb81cbb55f4933602b0 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:01:39 +0000 Subject: [PATCH 27/35] add val_split_pct=0.5 --- tests/datamodules/test_levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 832e5b1e8bf..671f14a8226 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -31,7 +31,7 @@ def datamodule( root = str(tmp_path) dm = LEVIRCDPlusDataModule( - root=root, download=True, num_workers=0, checksum=True + root=root, download=True, num_workers=0, checksum=True, val_split_pct=0.5 ) dm.prepare_data() dm.trainer = Trainer(accelerator="cpu", max_epochs=1) From f0976e3fb6be9d8391fa4a0dad0b87dc42fd88ba Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:14:48 +0000 Subject: [PATCH 28/35] Prevent divide by zero --- torchgeo/datasets/levircd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 55feae5a2fc..9ef81fcd4f9 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -232,9 +232,11 @@ def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": rgb_img = img.float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) - rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype( - np.uint8 - ) + delta = per98 - per02 + epsilon = 1e-7 + rgb_img = ( + np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 + ).astype(np.uint8) return cast("np.typing.NDArray[np.uint8]", rgb_img) image1 = get_rgb(sample["image1"]) From 83f24ac13a0f0c6671331a988bb5d7f76a4a30cd Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:36:42 +0000 Subject: [PATCH 29/35] Update torchgeo/datamodules/levircd.py Co-authored-by: Adam J. Stewart --- torchgeo/datamodules/levircd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 9675e733772..b021d8c860b 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -63,8 +63,8 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = LEVIRCDPlus(split="train", **self.kwargs) - self.train_dataset, self.val_dataset, _ = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=0 + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct ) if stage in ["test"]: self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) From a77d763dcebfc5b777ce08d12fab52c195b24fa1 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:36:48 +0000 Subject: [PATCH 30/35] Update torchgeo/datasets/levircd.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/levircd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 9ef81fcd4f9..2f433cd265c 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -234,10 +234,10 @@ def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": per98 = np.percentile(rgb_img, 98) delta = per98 - per02 epsilon = 1e-7 - rgb_img = ( + rgb_img: "np.typing.NDArray[np.uint8]" = ( np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 ).astype(np.uint8) - return cast("np.typing.NDArray[np.uint8]", rgb_img) + return rgb_img image1 = get_rgb(sample["image1"]) image2 = get_rgb(sample["image2"]) From 65853f92d460a0b106770ee927f188f619e0bd18 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:39:04 +0000 Subject: [PATCH 31/35] remove cast import --- torchgeo/datasets/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 2f433cd265c..bbbf108f312 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional, cast +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np From 150cead56d1818674ae065cbdf68abe1f138a151 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:46:45 +0000 Subject: [PATCH 32/35] remove unused parameterization --- tests/datamodules/test_levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 671f14a8226..8e67152346a 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -19,7 +19,7 @@ def download_url(url: str, root: str, *args: str) -> None: class TestLEVIRCDPlusDataModule: - @pytest.fixture(params=["train", "validate", "test"]) + @pytest.fixture def datamodule( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> LEVIRCDPlusDataModule: From 2699908868e9a7d8a56d0a020a3eb20babf3332f Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 14:52:54 +0000 Subject: [PATCH 33/35] Return cast --- torchgeo/datasets/levircd.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index bbbf108f312..b356a089fd1 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -228,16 +228,15 @@ def plot( ncols = 3 def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": - img = img.permute(1, 2, 0) - rgb_img = img.float().numpy() - per02 = np.percentile(rgb_img, 2) - per98 = np.percentile(rgb_img, 98) + img = img.permute(1, 2, 0).float().numpy() + per02 = np.percentile(img, 2) + per98 = np.percentile(img, 98) delta = per98 - per02 epsilon = 1e-7 rgb_img: "np.typing.NDArray[np.uint8]" = ( - np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 + np.clip((img - per02) / (delta + epsilon), 0, 1) * 255 ).astype(np.uint8) - return rgb_img + return cast("np.typing.NDArray[np.uint8]", rgb_img) image1 = get_rgb(sample["image1"]) image2 = get_rgb(sample["image2"]) From c5670fa4fd5b3f29bfaa1d8b186fa410b3581169 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 15:06:18 +0000 Subject: [PATCH 34/35] address mypy --- torchgeo/datasets/levircd.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index b356a089fd1..582e798c624 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional, cast +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -227,16 +227,15 @@ def plot( """ ncols = 3 - def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": - img = img.permute(1, 2, 0).float().numpy() - per02 = np.percentile(img, 2) - per98 = np.percentile(img, 98) + def get_rgb(img: Tensor) -> np.ndarray[np.uint8]: + rgb_img = img.permute(1, 2, 0).float().numpy() + per02 = np.percentile(rgb_img, 2) + per98 = np.percentile(rgb_img, 98) delta = per98 - per02 epsilon = 1e-7 - rgb_img: "np.typing.NDArray[np.uint8]" = ( - np.clip((img - per02) / (delta + epsilon), 0, 1) * 255 - ).astype(np.uint8) - return cast("np.typing.NDArray[np.uint8]", rgb_img) + return (np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255).astype( + np.uint8 + ) image1 = get_rgb(sample["image1"]) image2 = get_rgb(sample["image2"]) From 42acb0fa90f30c63a25f645b7fdb87fb986f5037 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 1 Nov 2023 15:14:47 +0000 Subject: [PATCH 35/35] try again mypy --- torchgeo/datasets/levircd.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 582e798c624..6a87ec27105 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -227,15 +227,16 @@ def plot( """ ncols = 3 - def get_rgb(img: Tensor) -> np.ndarray[np.uint8]: + def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]": rgb_img = img.permute(1, 2, 0).float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) delta = per98 - per02 epsilon = 1e-7 - return (np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255).astype( - np.uint8 - ) + norm_img: "np.typing.NDArray[np.uint8]" = ( + np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255 + ).astype(np.uint8) + return norm_img image1 = get_rgb(sample["image1"]) image2 = get_rgb(sample["image2"])