From 5f15b83c0841aefb7d88fa35ada8c65425fa708c Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Sun, 20 Feb 2022 20:45:10 +0100 Subject: [PATCH] Add band selection to EuroSat and adapt plot method (#397) * add band selection and adapt plot method to rgb * keep normalization in plotting method --- tests/datasets/test_eurosat.py | 17 ++++++- torchgeo/datasets/eurosat.py | 86 +++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index a8b47ea2561..61fff6112dc 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -58,7 +58,9 @@ def dataset( root = str(tmp_path) split = request.param transforms = nn.Identity() # type: ignore[attr-defined] - return EuroSAT(root, split, transforms, download=True, checksum=True) + return EuroSAT( + root=root, split=split, transforms=transforms, download=True, checksum=True + ) def test_getitem(self, dataset: EuroSAT) -> None: x = dataset[0] @@ -66,6 +68,14 @@ def test_getitem(self, dataset: EuroSAT) -> None: assert isinstance(x["image"], torch.Tensor) assert isinstance(x["label"], torch.Tensor) + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + EuroSAT(split="foo") + + def test_invalid_bands(self) -> None: + with pytest.raises(ValueError): + EuroSAT(bands=("OK", "BK")) + def test_len(self, dataset: EuroSAT) -> None: assert len(dataset) == 2 @@ -100,3 +110,8 @@ def test_plot(self, dataset: EuroSAT) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() + + def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None: + dataset = EuroSAT(root=str(tmp_path), bands=("B03",)) + with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"): + dataset.plot(dataset[0], suptitle="Single Band") diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index e1e14072039..970e56c33b7 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,10 +4,11 @@ """EuroSAT dataset.""" import os -from typing import Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, Sequence, cast import matplotlib.pyplot as plt import numpy as np +import torch from torch import Tensor from .geo import VisionClassificationDataset @@ -83,10 +84,31 @@ class EuroSAT(VisionClassificationDataset): "Forest", ] + all_band_names = ( + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B08A", + "B09", + "B10", + "B11", + "B12", + ) + + RGB_BANDS = ("B04", "B03", "B02") + + BAND_SETS = {"all": all_band_names, "rgb": RGB_BANDS} + def __init__( self, root: str = "data", split: str = "train", + bands: Sequence[str] = BAND_SETS["all"], transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, @@ -96,19 +118,31 @@ def __init__( Args: root: root directory where dataset can be found split: one of "train", "val", or "test" + bands: a sequence of band names to load transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: if ``split`` argument is invalid RuntimeError: if ``download=False`` and data is not found, or checksums don't match + """ self.root = root self.transforms = transforms self.download = download self.checksum = checksum + + assert split in ["train", "val", "test"] + + self._validate_bands(bands) + self.bands = bands + self.band_indices = Tensor( + [self.all_band_names.index(b) for b in bands if b in self.all_band_names] + ).long() + self._verify() valid_fns = set() @@ -124,6 +158,26 @@ def __init__( is_valid_file=is_in_split, ) + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + Returns: + data and label at that index + """ + image, label = self._load_image(index) + + image = torch.index_select( # type: ignore[attr-defined] + image, dim=0, index=self.band_indices + ) + sample = {"image": image, "label": label} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + def _check_integrity(self) -> bool: """Check integrity of dataset. @@ -184,6 +238,23 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def _validate_bands(self, bands: Sequence[str]) -> None: + """Validate list of bands. + + Args: + bands: user-provided sequence of bands to load + + Raises: + AssertionError: if ``bands`` is not a sequence + ValueError: if an invalid band name is provided + + .. versionadded:: 0.3 + """ + assert isinstance(bands, Sequence), "'bands' must be a sequence" + for band in bands: + if band not in self.all_band_names: + raise ValueError(f"'{band}' is an invalid band name.") + def plot( self, sample: Dict[str, Tensor], @@ -200,9 +271,20 @@ def plot( Returns: a matplotlib Figure with the rendered sample + Raises: + ValueError: if RGB bands are not found in dataset + .. versionadded:: 0.2 """ - image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy() + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0) + image = np.rollaxis(image, 0, 3) image = np.clip(image / 3000, 0, 1) label = cast(int, sample["label"].item())