Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data module for LEVIR-CD+ dataset #1707

Merged
merged 36 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
872066b
add file
robmarkcole Oct 31, 2023
b7d8d54
Add to init
robmarkcole Oct 31, 2023
dc7cec9
refactor
robmarkcole Oct 31, 2023
068e455
format
robmarkcole Oct 31, 2023
66e2a77
isort
robmarkcole Oct 31, 2023
9e1a139
match to oscd
robmarkcole Oct 31, 2023
354c28a
Add test
robmarkcole Oct 31, 2023
f07780d
remove mean and std
robmarkcole Oct 31, 2023
d2e438e
Merge branch 'main' into issue-1706
robmarkcole Oct 31, 2023
c3aa462
update docstring with versionadded
robmarkcole Oct 31, 2023
ba7f04b
address test issues
robmarkcole Oct 31, 2023
b24828f
fix init
robmarkcole Oct 31, 2023
3b9b3aa
fix init dataset
robmarkcole Oct 31, 2023
269be37
fix type hint
robmarkcole Oct 31, 2023
31488a4
import
robmarkcole Oct 31, 2023
90d6cf3
add fixture
robmarkcole Oct 31, 2023
d35c4a2
import pytest
robmarkcole Oct 31, 2023
ce90650
make image float
robmarkcole Oct 31, 2023
5f9689d
fix plotting
robmarkcole Nov 1, 2023
876c99c
isort
robmarkcole Nov 1, 2023
0223488
mock download
robmarkcole Nov 1, 2023
6453f8b
fix import
robmarkcole Nov 1, 2023
2cac4e6
satisfy mypy
robmarkcole Nov 1, 2023
9271f62
Fix fixture for TestLEVIRCDPlusDataModule
robmarkcole Nov 1, 2023
71cb8c4
fix imports
robmarkcole Nov 1, 2023
8f18774
Fix test values
robmarkcole Nov 1, 2023
9604580
fix test values
robmarkcole Nov 1, 2023
e2ea210
add val_split_pct=0.5
robmarkcole Nov 1, 2023
f0976e3
Prevent divide by zero
robmarkcole Nov 1, 2023
83f24ac
Update torchgeo/datamodules/levircd.py
robmarkcole Nov 1, 2023
a77d763
Update torchgeo/datasets/levircd.py
robmarkcole Nov 1, 2023
65853f9
remove cast import
robmarkcole Nov 1, 2023
150cead
remove unused parameterization
robmarkcole Nov 1, 2023
2699908
Return cast
robmarkcole Nov 1, 2023
c5670fa
address mypy
robmarkcole Nov 1, 2023
42acb0f
try again mypy
robmarkcole Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions tests/datamodules/test_levircd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import pytest
from lightning.pytorch import Trainer
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datamodules import LEVIRCDPlusDataModule
from torchgeo.datasets import LEVIRCDPlus


def download_url(url: str, root: str, *args: str) -> None:
    """Stand-in downloader that copies a local file instead of fetching it.

    Args:
        url: Path of the local file to copy.
        root: Destination handed to :func:`shutil.copy`.
        *args: Extra positional arguments; accepted and ignored.
    """
    source, destination = url, root
    shutil.copy(source, destination)


class TestLEVIRCDPlusDataModule:
    """Smoke tests for the dataloaders exposed by ``LEVIRCDPlusDataModule``."""

    # Deliberately not parametrized: the fixture body never reads
    # ``request.param``, so parameters would only repeat identical tests.
    @pytest.fixture
    def datamodule(
        self, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> LEVIRCDPlusDataModule:
        """Build a datamodule that "downloads" the archive from the test data dir.

        ``download_url`` is patched with a local file copy, and the dataset's
        ``url`` and ``md5`` are pointed at the test archive, so no network
        access is needed.

        Args:
            monkeypatch: pytest fixture used to patch the download helper and
                the dataset's class attributes.
            tmp_path: Per-test temporary directory used as the dataset root.

        Returns:
            A datamodule whose data has been prepared and which has a CPU
            ``Trainer`` attached.
        """
        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
        md5 = "1adf156f628aa32fb2e8fe6cada16c04"
        monkeypatch.setattr(LEVIRCDPlus, "md5", md5)
        url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip")
        monkeypatch.setattr(LEVIRCDPlus, "url", url)

        root = str(tmp_path)
        dm = LEVIRCDPlusDataModule(
            root=root, download=True, num_workers=0, checksum=True
        )
        dm.prepare_data()
        dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
        return dm

    def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
        """Check shapes of the first augmented training batch."""
        datamodule.setup("fit")
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
        assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
        assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
        assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
        assert batch["image1"].shape[1] == 3
        assert batch["image2"].shape[1] == 3

    def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
        """Check shapes of the first augmented validation batch."""
        datamodule.setup("validate")
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        # Shape checks only apply when a validation subset was requested.
        if datamodule.val_split_pct > 0.0:
            assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
            assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
            assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
            assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
            assert batch["image1"].shape[1] == 3
            assert batch["image2"].shape[1] == 3

    def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
        """Check shapes of the first augmented test batch."""
        datamodule.setup("test")
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
        assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
        assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
        assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
        assert batch["image1"].shape[1] == 3
        assert batch["image2"].shape[1] == 3
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .l7irish import L7IrishDataModule
from .l8biome import L8BiomeDataModule
from .landcoverai import LandCoverAIDataModule
from .levircd import LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
Expand Down Expand Up @@ -56,6 +57,7 @@
"GID15DataModule",
"InriaAerialImageLabelingDataModule",
"LandCoverAIDataModule",
"LEVIRCDPlusDataModule",
"LoveDADataModule",
"NASAMarineDebrisDataModule",
"OSCDDataModule",
Expand Down
70 changes: 70 additions & 0 deletions torchgeo/datamodules/levircd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""LEVIR-CD+ datamodule."""

from typing import Any, Union

import kornia.augmentation as K

from torchgeo.datamodules.utils import dataset_split
from torchgeo.samplers.utils import _to_tuple

from ..datasets import LEVIRCDPlus
from ..transforms import AugmentationSequential
from ..transforms.transforms import _RandomNCrop
from .geo import NonGeoDataModule


class LEVIRCDPlusDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the LEVIR-CD+ dataset.

    The dataset's own "train" split is divided again into training and
    validation subsets; its "test" split is used as-is for testing.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 8,
        patch_size: Union[tuple[int, int], int] = 256,
        val_split_pct: float = 0.2,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new LEVIRCDPlusDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
                Should be a multiple of 32 for most segmentation architectures.
            val_split_pct: Percentage of the dataset to use as a validation set.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.LEVIRCDPlus`.
        """
        # The parent is given a batch size of 1, while ``batch_size`` goes to
        # ``_RandomNCrop`` below -- presumably the crop step is what produces
        # the final mini-batch (TODO confirm against ``_RandomNCrop``).
        super().__init__(LEVIRCDPlus, 1, num_workers, **kwargs)

        self.patch_size = _to_tuple(patch_size)
        self.val_split_pct = val_split_pct

        normalize = K.Normalize(mean=self.mean, std=self.std)
        random_crops = _RandomNCrop(self.patch_size, batch_size)
        self.aug = AugmentationSequential(
            normalize, random_crops, data_keys=["image1", "image2", "mask"]
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage == "fit" or stage == "validate":
            train_split = LEVIRCDPlus(split="train", **self.kwargs)
            self.dataset = train_split
            # A three-way split is requested with ``test_pct=0``; the third
            # part is discarded.
            parts = dataset_split(
                train_split, val_pct=self.val_split_pct, test_pct=0
            )
            self.train_dataset, self.val_dataset, _ = parts
        if stage == "test":
            self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs)
25 changes: 19 additions & 6 deletions torchgeo/datasets/levircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Callable, Optional
from typing import Callable, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -156,7 +156,7 @@ def _load_image(self, path: str) -> Tensor:
filename = os.path.join(path)
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
Expand Down Expand Up @@ -225,20 +225,33 @@ def plot(

.. versionadded:: 0.2
"""
image1, image2, mask = (sample["image1"], sample["image2"], sample["mask"])
ncols = 3

def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]":
    """Convert an image tensor to a contrast-stretched uint8 array for plotting.

    The first axis is moved to the end (channels-first to channels-last), then
    values are linearly rescaled so that the 2nd percentile maps to 0 and the
    98th percentile maps to 255, with everything outside that range clipped.

    Args:
        img: Image tensor with the channel axis first.

    Returns:
        Channels-last ``uint8`` array with values in ``[0, 255]``.
    """
    rgb_img = img.permute(1, 2, 0).float().numpy()
    per02 = np.percentile(rgb_img, 2)
    per98 = np.percentile(rgb_img, 98)
    # A constant image has per98 == per02; dividing by that zero spread would
    # give NaNs and an undefined cast to uint8. Use a unit divisor instead so
    # such an image maps to all zeros.
    spread = per98 - per02
    if spread == 0:
        spread = 1.0
    stretched: "np.typing.NDArray[np.uint8]" = (
        np.clip((rgb_img - per02) / spread, 0, 1) * 255
    ).astype(np.uint8)
    return stretched

image1 = get_rgb(sample["image1"])
image2 = get_rgb(sample["image2"])
mask = sample["mask"].numpy()

if "prediction" in sample:
prediction = sample["prediction"]
ncols += 1

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))

axs[0].imshow(image1.permute(1, 2, 0))
axs[0].imshow(image1)
axs[0].axis("off")
axs[1].imshow(image2.permute(1, 2, 0))
axs[1].imshow(image2)
axs[1].axis("off")
axs[2].imshow(mask)
axs[2].imshow(mask, cmap="gray")
axs[2].axis("off")

if "prediction" in sample:
Expand Down
Loading