diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 70c88928fc0..7e04af7a20c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -297,6 +297,11 @@ RESISC45 .. autoclass:: RESISC45 +Rwanda Field Boundary +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RwandaFieldBoundary + Seasonal Contrast ^^^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index cf6c8ebec8d..1b745544dca 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -26,6 +26,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI `ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB +`Rwanda Field Boundary`_,S,Planetscope,70,2,256x256,4.7,RGB + NIR `Seasonal Contrast`_,T,Sentinel-2,100K--1M,-,264x264,10,MSI `SeasoNet`_,S,Sentinel-2,"1,759,830",33,120x120,10,MSI `SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI" diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py new file mode 100644 index 00000000000..7a23b385cf6 --- /dev/null +++ b/tests/data/rwanda_field_boundary/data.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import hashlib +import os +import shutil + +import numpy as np +import rasterio + +dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12") +all_bands = ("B01", "B02", "B03", "B04") + +SIZE = 32 +NUM_SAMPLES = 5 +np.random.seed(0) + + +def create_mask(fn: str) -> None: + profile = { + "driver": "GTiff", + "dtype": "uint8", + "nodata": 0.0, + "width": SIZE, + "height": SIZE, + "count": 1, + "crs": "epsg:3857", + "compress": "lzw", + "predictor": 2, + "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), + "blockysize": 32, + "tiled": False, + "interleave": "band", + } + with rasterio.open(fn, "w", **profile) as f: + f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1) + + +def create_img(fn: str) -> None: + profile = { + "driver": "GTiff", + "dtype": "uint16", + "nodata": 0.0, + "width": SIZE, + "height": SIZE, + "count": 1, + "crs": "epsg:3857", + "compress": "lzw", + "predictor": 2, + "blockysize": 16, + "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), + "tiled": False, + "interleave": "band", + } + with rasterio.open(fn, "w", **profile) as f: + f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1) + + +if __name__ == "__main__": + # Train and test images + for split in ("train", "test"): + for i in range(NUM_SAMPLES): + for date in dates: + directory = os.path.join( + f"nasa_rwanda_field_boundary_competition_source_{split}", + f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501 + ) + os.makedirs(directory, exist_ok=True) + for band in all_bands: + create_img(os.path.join(directory, f"{band}.tif")) + + # Create collections.json, this isn't used by the dataset but is checked to + # exist + with open( + f"nasa_rwanda_field_boundary_competition_source_{split}/collections.json", + "w", + ) as f: + f.write("Not used") + + # Train labels + for i in range(NUM_SAMPLES): + directory = os.path.join( + "nasa_rwanda_field_boundary_competition_labels_train", + 
f"nasa_rwanda_field_boundary_competition_labels_train_{i:02d}", + ) + os.makedirs(directory, exist_ok=True) + create_mask(os.path.join(directory, "raster_labels.tif")) + + # Create directories and compute checksums + for filename in [ + "nasa_rwanda_field_boundary_competition_source_train", + "nasa_rwanda_field_boundary_competition_source_test", + "nasa_rwanda_field_boundary_competition_labels_train", + ]: + shutil.make_archive(filename, "gztar", ".", filename) + # Compute checksums + with open(f"{filename}.tar.gz", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{filename}: {md5}") diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz new file mode 100644 index 00000000000..ffa98bb53d6 Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz new file mode 100644 index 00000000000..a834f66bf38 Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz new file mode 100644 index 00000000000..8239f70c200 Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz differ diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py new file mode 100644 index 00000000000..e0736b32e7c --- /dev/null +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -0,0 +1,140 @@ +# Copyright (c) 
Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torch.utils.data import ConcatDataset + +from torchgeo.datasets import RwandaFieldBoundary + + +class Collection: + def download(self, output_dir: str, **kwargs: str) -> None: + glob_path = os.path.join("tests", "data", "rwanda_field_boundary", "*.tar.gz") + for tarball in glob.iglob(glob_path): + shutil.copy(tarball, output_dir) + + +def fetch(dataset_id: str, **kwargs: str) -> Collection: + return Collection() + + +class TestRwandaFieldBoundary: + @pytest.fixture(params=["train", "test"]) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> RwandaFieldBoundary: + radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + monkeypatch.setattr( + RwandaFieldBoundary, "number_of_patches_per_split", {"train": 5, "test": 5} + ) + monkeypatch.setattr( + RwandaFieldBoundary, + "md5s", + { + "train_images": "af9395e2e49deefebb35fa65fa378ba3", + "test_images": "d104bb82323a39e7c3b3b7dd0156f550", + "train_labels": "6cceaf16a141cf73179253a783e7d51b", + }, + ) + + root = str(tmp_path) + split = request.param + transforms = nn.Identity() + return RwandaFieldBoundary( + root, split, transforms=transforms, api_key="", download=True, checksum=True + ) + + def test_getitem(self, dataset: RwandaFieldBoundary) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + if dataset.split == "train": + assert isinstance(x["mask"], torch.Tensor) + else: + assert "mask" not in x + + def test_len(self, dataset: RwandaFieldBoundary) -> None: + assert len(dataset) == 5 + + def test_add(self, dataset: RwandaFieldBoundary) -> 
None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + assert len(ds) == 10 + + def test_needs_extraction(self, tmp_path: Path) -> None: + root = str(tmp_path) + for fn in [ + "nasa_rwanda_field_boundary_competition_source_train.tar.gz", + "nasa_rwanda_field_boundary_competition_source_test.tar.gz", + "nasa_rwanda_field_boundary_competition_labels_train.tar.gz", + ]: + url = os.path.join("tests", "data", "rwanda_field_boundary", fn) + shutil.copy(url, root) + RwandaFieldBoundary(root, checksum=False) + + def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: + RwandaFieldBoundary(root=dataset.root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + RwandaFieldBoundary(str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + for fn in [ + "nasa_rwanda_field_boundary_competition_source_train.tar.gz", + "nasa_rwanda_field_boundary_competition_source_test.tar.gz", + "nasa_rwanda_field_boundary_competition_labels_train.tar.gz", + ]: + with open(os.path.join(tmp_path, fn), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + RwandaFieldBoundary(root=str(tmp_path), checksum=True) + + def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: + radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + monkeypatch.setattr( + RwandaFieldBoundary, + "md5s", + {"train_images": "bad", "test_images": "bad", "train_labels": "bad"}, + ) + root = str(tmp_path) + with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + RwandaFieldBoundary(root, "train", api_key="", download=True, checksum=True) + + def test_no_api_key(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Must provide an API key to download"): + RwandaFieldBoundary(str(tmp_path), api_key=None, 
download=True) + + def test_invalid_bands(self) -> None: + with pytest.raises(ValueError, match="is an invalid band name."): + RwandaFieldBoundary(bands=("foo", "bar")) + + def test_plot(self, dataset: RwandaFieldBoundary) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + + if dataset.split == "train": + x["prediction"] = x["mask"].clone() + dataset.plot(x) + plt.close() + + def test_failed_plot(self, dataset: RwandaFieldBoundary) -> None: + single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=("B01",)) + with pytest.raises(ValueError, match="Dataset doesn't contain"): + x = single_band_dataset[0].copy() + single_band_dataset.plot(x, suptitle="Test") diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 166ef64f9ea..4d32728349a 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -83,6 +83,7 @@ from .potsdam import Potsdam2D from .reforestree import ReforesTree from .resisc45 import RESISC45 +from .rwanda_field_boundary import RwandaFieldBoundary from .seasonet import SeasoNet from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS @@ -201,6 +202,7 @@ "Potsdam2D", "RESISC45", "ReforesTree", + "RwandaFieldBoundary", "SeasonalContrastS2", "SeasoNet", "SEN12MS", diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py new file mode 100644 index 00000000000..8d40960f8da --- /dev/null +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -0,0 +1,328 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
"""Rwanda Field Boundary Competition dataset."""

import os
from collections.abc import Sequence
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import rasterio.features
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


class RwandaFieldBoundary(NonGeoDataset):
    r"""Rwanda Field Boundary Competition dataset.

    This dataset contains field boundaries for smallholder farms in eastern Rwanda.
    The NASA Harvest program funded a team of annotators from TaQadam to label Planet
    imagery for the 2021 growing season for the purpose of conducting the Rwanda Field
    boundary detection Challenge. The dataset includes rasterized labeled field
    boundaries and time series satellite imagery from Planet's NICFI program.
    Planet's basemap imagery is provided for six months (March, April, August, October,
    November and December). Note: only fields that were big enough to be differentiated
    on the PlanetScope imagery and that were fully contained within the chips were
    labeled. The paired dataset is provided in 256x256 chips for a total of 70 tiles
    covering 1532 individual fields.

    The labels are provided as binary semantic segmentation labels:

    0. No field-boundary
    1. Field-boundary

    If you use this dataset in your research, please cite the following:

    * https://doi.org/10.34911/RDNT.G580WW

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub

    .. versionadded:: 0.5
    """

    dataset_id = "nasa_rwanda_field_boundary_competition"
    collection_ids = [
        "nasa_rwanda_field_boundary_competition_source_train",
        "nasa_rwanda_field_boundary_competition_labels_train",
        "nasa_rwanda_field_boundary_competition_source_test",
    ]
    number_of_patches_per_split = {"train": 57, "test": 13}

    filenames = {
        "train_images": "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
        "test_images": "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
        "train_labels": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
    }
    md5s = {
        "train_images": "1f9ec08038218e67e11f82a86849b333",
        "test_images": "17bb0e56eedde2e7a43c57aa908dc125",
        "train_labels": "10e4eb761523c57b6d3bdf9394004f5f",
    }

    dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")

    all_bands = ("B01", "B02", "B03", "B04")
    rgb_bands = ("B03", "B02", "B01")

    classes = ["No field-boundary", "Field-boundary"]

    splits = ["train", "test"]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        bands: Sequence[str] = all_bands,
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new RwandaFieldBoundary instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train" or "test"
            bands: the subset of bands to load
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` is not one of :attr:`splits`
            ValueError: if ``bands`` contains an invalid band name
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
                or if ``download=True`` and ``api_key=None``
        """
        self._validate_bands(bands)
        assert split in self.splits
        if download and api_key is None:
            raise RuntimeError("Must provide an API key to download the dataset")
        self.root = os.path.expanduser(root)
        self.bands = bands
        self.transforms = transforms
        self.split = split
        self.download = download
        self.api_key = api_key
        self.checksum = checksum
        self._verify()

        source_dir = f"{self.dataset_id}_source_{split}"
        labels_dir = f"{self.dataset_id}_labels_{split}"

        # image_filenames is indexed as [patch][date][band].
        self.image_filenames: list[list[list[str]]] = []
        self.mask_filenames: list[str] = []
        for i in range(self.number_of_patches_per_split[split]):
            self.image_filenames.append(
                [
                    [
                        os.path.join(
                            self.root,
                            source_dir,
                            f"{source_dir}_{i:02d}_{date}",
                            f"{band}.tif",
                        )
                        for band in self.bands
                    ]
                    for date in self.dates
                ]
            )
            # Only a train labels collection is listed in ``collection_ids``, so
            # don't record paths to a labels directory that never exists.
            if split == "train":
                self.mask_filenames.append(
                    os.path.join(
                        self.root,
                        labels_dir,
                        f"{labels_dir}_{i:02d}",
                        "raster_labels.tif",
                    )
                )

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index of the chip to return

        Returns:
            a dict with an "image" tensor stacked as (date, band, height, width)
            and, for the "train" split only, a "mask" tensor read from the
            rasterized labels
        """
        dates = []
        for date_fns in self.image_filenames[index]:
            bands = []
            for band_fn in date_fns:
                with rasterio.open(band_fn) as f:
                    bands.append(f.read(1).astype(np.int32))
            dates.append(bands)

        sample = {"image": torch.from_numpy(np.array(dates))}

        if self.split == "train":
            with rasterio.open(self.mask_filenames[index]) as f:
                sample["mask"] = torch.from_numpy(f.read(1))

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of chips in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.image_filenames)

    def _validate_bands(self, bands: Sequence[str]) -> None:
        """Validate list of bands.

        Args:
            bands: user-provided sequence of bands to load

        Raises:
            ValueError: if an invalid band name is provided
        """
        for band in bands:
            if band not in self.all_bands:
                raise ValueError(f"'{band}' is an invalid band name.")

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Looks for already-extracted source directories first, then for the
        tarballs (extracting any that are present), and finally falls back to
        downloading when ``download=True``.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the subdirectories already exist and have the correct number of files
        checks = []
        for split, num_patches in self.number_of_patches_per_split.items():
            path = os.path.join(self.root, f"{self.dataset_id}_source_{split}")
            if os.path.exists(path):
                # One directory per (patch, date) plus one collection-level JSON file
                expected = num_patches * len(self.dates) + 1
                checks.append(len(os.listdir(path)) == expected)
            else:
                checks.append(False)

        if all(checks):
            return

        # Check if tar file already exists (if so then extract)
        have_all_files = True
        for group in ("train_images", "train_labels", "test_images"):
            filepath = os.path.join(self.root, self.filenames[group])
            if os.path.exists(filepath):
                if self.checksum and not check_integrity(filepath, self.md5s[group]):
                    raise RuntimeError("Dataset found, but corrupted.")
                extract_archive(filepath, self.root)
            else:
                have_all_files = False
        if have_all_files:
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download and extract the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            RuntimeError: if download doesn't work correctly or checksums don't match
        """
        for collection_id in self.collection_ids:
            download_radiant_mlhub_collection(collection_id, self.root, self.api_key)

        for group in ("train_images", "train_labels", "test_images"):
            filepath = os.path.join(self.root, self.filenames[group])
            if self.checksum and not check_integrity(filepath, self.md5s[group]):
                raise RuntimeError("Dataset not found or corrupted.")
            extract_archive(filepath, self.root)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        time_step: int = 0,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            time_step: time step at which to access image, beginning with 0
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        Raises:
            ValueError: if the RGB bands are not included in ``self.bands``
        """
        if any(band not in self.bands for band in self.rgb_bands):
            raise ValueError("Dataset doesn't contain some of the RGB bands")
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]

        num_time_points = sample["image"].shape[0]
        assert time_step < num_time_points

        # (band, height, width) -> (height, width, band), then scale into [0, 1]
        image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3)
        image = np.clip(image / 2000, 0, 1)

        if "mask" in sample:
            mask = sample["mask"].numpy()
        else:
            # Samples without a mask (the test split) get a blank placeholder panel
            mask = np.zeros_like(image)

        num_panels = 2
        showing_predictions = "prediction" in sample
        if showing_predictions:
            predictions = sample["prediction"].numpy()
            num_panels += 1

        fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))

        axs[0].imshow(image)
        axs[0].axis("off")
        if show_titles:
            axs[0].set_title(f"t={time_step}")

        axs[1].imshow(mask, vmin=0, vmax=1, interpolation="none")
        axs[1].axis("off")
        if show_titles:
            axs[1].set_title("Mask")

        if showing_predictions:
            axs[2].imshow(predictions, vmin=0, vmax=1, interpolation="none")
            axs[2].axis("off")
            if show_titles:
                axs[2].set_title("Predictions")

        if suptitle is not None:
            # Title the figure created above rather than pyplot's "current" one
            fig.suptitle(suptitle)
        return fig