Skip to content

Commit

Permalink
Add band selection to EuroSat and adapt plot method (microsoft#397)
Browse files Browse the repository at this point in the history
* add band selection and adapt plot method to rgb

* keep normalization in plotting method
  • Loading branch information
nilsleh authored Feb 20, 2022
1 parent b4d2014 commit 5f15b83
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
17 changes: 16 additions & 1 deletion tests/datasets/test_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,24 @@ def dataset(
root = str(tmp_path)
split = request.param
transforms = nn.Identity() # type: ignore[attr-defined]
return EuroSAT(root, split, transforms, download=True, checksum=True)
return EuroSAT(
root=root, split=split, transforms=transforms, download=True, checksum=True
)

def test_getitem(self, dataset: EuroSAT) -> None:
    """Check that indexing the dataset yields a dict holding tensors."""
    x = dataset[0]
    assert isinstance(x, dict)
    assert isinstance(x["image"], torch.Tensor)
    assert isinstance(x["label"], torch.Tensor)

def test_invalid_split(self) -> None:
    """Check that an unknown split name is rejected with AssertionError."""
    with pytest.raises(AssertionError):
        EuroSAT(split="foo")

def test_invalid_bands(self) -> None:
    """Check that unknown band names are rejected with ValueError."""
    with pytest.raises(ValueError):
        EuroSAT(bands=("OK", "BK"))

def test_len(self, dataset: EuroSAT) -> None:
    """Check the number of samples reported by the test dataset."""
    assert len(dataset) == 2

Expand Down Expand Up @@ -100,3 +110,8 @@ def test_plot(self, dataset: EuroSAT) -> None:
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()

def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
    """Check that plotting fails when the loaded bands lack an RGB band."""
    # NOTE(review): ``dataset`` is requested only for its setup -- presumably
    # the fixture has already placed the data under ``tmp_path``, which lets
    # the single-band instance below be built without ``download=True``.
    # A separate name is used so the fixture argument is not shadowed.
    single_band = EuroSAT(root=str(tmp_path), bands=("B03",))
    with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
        single_band.plot(single_band[0], suptitle="Single Band")
86 changes: 84 additions & 2 deletions torchgeo/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""EuroSAT dataset."""

import os
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Sequence, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor

from .geo import VisionClassificationDataset
Expand Down Expand Up @@ -83,10 +84,31 @@ class EuroSAT(VisionClassificationDataset):
"Forest",
]

# Every band name accepted by the ``bands`` argument. The position of a name
# in this tuple is the index used to pick that band out of a loaded image
# (see how ``band_indices`` is built in ``__init__``).
all_band_names = (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B08A",
"B09",
"B10",
"B11",
"B12",
)

# Bands that ``plot`` requires and stacks, in this order, to render an image
# (presumably red, green, blue -- confirm against the sensor band definitions).
RGB_BANDS = ("B04", "B03", "B02")

# Named band presets; ``BAND_SETS["all"]`` is the default for ``bands``.
BAND_SETS = {"all": all_band_names, "rgb": RGB_BANDS}

def __init__(
self,
root: str = "data",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
Expand All @@ -96,19 +118,31 @@ def __init__(
Args:
root: root directory where dataset can be found
split: one of "train", "val", or "test"
bands: a sequence of band names to load
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self.root = root
self.transforms = transforms
self.download = download
self.checksum = checksum

assert split in ["train", "val", "test"]

self._validate_bands(bands)
self.bands = bands
self.band_indices = Tensor(
[self.all_band_names.index(b) for b in bands if b in self.all_band_names]
).long()

self._verify()

valid_fns = set()
Expand All @@ -124,6 +158,26 @@ def __init__(
is_valid_file=is_in_split,
)

def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return the sample at the given index, restricted to the selected bands.

    Args:
        index: index of the sample to return

    Returns:
        a dict with the band-subset ``"image"`` and its ``"label"``, passed
        through ``self.transforms`` when a transform is set
    """
    image, label = self._load_image(index)

    # Keep only the entries of dimension 0 listed in ``self.band_indices``,
    # in the order those indices are stored.
    image = torch.index_select(  # type: ignore[attr-defined]
        image, dim=0, index=self.band_indices
    )
    sample = {"image": image, "label": label}

    if self.transforms is not None:
        sample = self.transforms(sample)

    return sample

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Expand Down Expand Up @@ -184,6 +238,23 @@ def _extract(self) -> None:
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)

def _validate_bands(self, bands: Sequence[str]) -> None:
    """Validate list of bands.

    Args:
        bands: user-provided sequence of bands to load

    Raises:
        AssertionError: if ``bands`` is not a sequence
        ValueError: if an invalid band name is provided

    .. versionadded:: 0.3
    """
    # NOTE: this check is an ``assert`` statement, so it is skipped when
    # Python runs with ``-O``; the band-name check below always runs.
    assert isinstance(bands, Sequence), "'bands' must be a sequence"

    # Report the first band that is not one of the known band names.
    for band in bands:
        if band not in self.all_band_names:
            raise ValueError(f"'{band}' is an invalid band name.")

def plot(
self,
sample: Dict[str, Tensor],
Expand All @@ -200,9 +271,20 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if RGB bands are not found in dataset
.. versionadded:: 0.2
"""
image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy()
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0)
image = np.rollaxis(image, 0, 3)
image = np.clip(image / 3000, 0, 1)

label = cast(int, sample["label"].item())
Expand Down

0 comments on commit 5f15b83

Please sign in to comment.