Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot method and band selection to Zueri crop dataset #334

Merged
merged 9 commits into from
Dec 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified tests/data/zuericrop/ZueriCrop.hdf5
Binary file not shown.
23 changes: 21 additions & 2 deletions tests/datasets/test_zuericrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any, Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -35,12 +36,12 @@ def dataset(
os.path.join(data_dir, "ZueriCrop.hdf5"),
os.path.join(data_dir, "labels.csv"),
]
md5s = ["8c0ca5ad53903aeba8a1d06bba50a5ec", "d41d8cd98f00b204e9800998ecf8427e"]
md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"]
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return ZueriCrop(root, transforms, download=True, checksum=True)
return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)

@pytest.fixture
def mock_missing_module(
Expand Down Expand Up @@ -100,3 +101,21 @@ def test_mock_missing_module(
match="h5py is not installed and is required to use this dataset",
):
ZueriCrop(dataset.root, download=True, checksum=True)

def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
ZueriCrop(bands=("OK", "BK"))

def test_plot(self, dataset: ZueriCrop) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()

sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, suptitle="prediction")
plt.close()

def test_plot_rgb(self, dataset: ZueriCrop) -> None:
dataset = ZueriCrop(root=dataset.root, bands=("B02",))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], time_step=0, suptitle="Single Band")
100 changes: 98 additions & 2 deletions torchgeo/datasets/zuericrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
"""ZueriCrop dataset."""

import os
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import torch
from torch import Tensor

from .geo import VisionDataset
from .utils import download_url
from .utils import download_url, percentile_normalization


class ZueriCrop(VisionDataset):
Expand Down Expand Up @@ -56,9 +57,13 @@ class ZueriCrop(VisionDataset):
md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"]
filenames = ["ZueriCrop.hdf5", "labels.csv"]

band_names = ("NIR", "B03", "B02", "B04", "B05", "B06", "B07", "B11", "B12")
RGB_BANDS = ["B04", "B03", "B02"]

def __init__(
self,
root: str = "data",
bands: Sequence[str] = band_names,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
Expand All @@ -67,6 +72,7 @@ def __init__(

Args:
root: root directory where dataset can be found
bands: the subset of bands to load
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
Expand All @@ -76,7 +82,13 @@ def __init__(
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self._validate_bands(bands)
self.band_indices = torch.tensor( # type: ignore[attr-defined]
[self.band_names.index(b) for b in bands]
).long()

self.root = root
self.bands = bands
self.transforms = transforms
self.download = download
self.checksum = checksum
Expand Down Expand Up @@ -139,6 +151,9 @@ def _load_image(self, index: int) -> Tensor:
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from TxHxWxC to TxCxHxW
tensor = tensor.permute((0, 3, 1, 2))
tensor = torch.index_select( # type: ignore[attr-defined]
tensor, dim=1, index=self.band_indices
)
return tensor

def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
Expand Down Expand Up @@ -230,3 +245,84 @@ def _download(self) -> None:
filename=filename,
md5=md5 if self.checksum else None,
)

def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.

Args:
bands: user-provided sequence of bands to load
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided

.. versionadded:: 0.2
"""
assert isinstance(bands, Sequence), "'bands' must be a sequence"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")

def plot(
self,
sample: Dict[str, Tensor],
time_step: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
time_step: time step at which to access image, beginning with 0
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure

Returns:
a matplotlib Figure with the rendered sample

.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

ncols = 2
image, mask = sample["image"][time_step, rgb_indices], sample["mask"]

image = torch.tensor( # type: ignore[attr-defined]
percentile_normalization(image.numpy()) * 255,
dtype=torch.uint8, # type: ignore[attr-defined]
)

mask = torch.argmax(mask, dim=0) # type: ignore[attr-defined]

if "prediction" in sample:
ncols += 1
preds = torch.argmax( # type: ignore[attr-defined]
sample["prediction"], dim=0
)

fig, axs = plt.subplots(ncols=ncols, figsize=(10, 10 * ncols))

axs[0].imshow(image.permute(1, 2, 0))
axs[0].axis("off")
axs[1].imshow(mask)
axs[1].axis("off")

if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")

if "prediction" in sample:
axs[2].imshow(preds)
axs[2].axis("off")
if show_titles:
axs[2].set_title("Prediction")

if suptitle is not None:
plt.suptitle(suptitle)

return fig