diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 1e6131b60f2..08432e6bee6 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -138,6 +138,11 @@ RESISC45 (Remote Sensing Image Scene Classification) .. autoclass:: RESISC45 .. autoclass:: RESISC45DataModule +Seasonal Contrast +^^^^^^^^^^^^^^^^^ + +.. autoclass:: SeasonalContrastS2 + SEN12MS ^^^^^^^ diff --git a/tests/data/seco/seco_100k.zip b/tests/data/seco/seco_100k.zip new file mode 100644 index 00000000000..0e8c7db2761 Binary files /dev/null and b/tests/data/seco/seco_100k.zip differ diff --git a/tests/data/seco/seco_1m.zip b/tests/data/seco/seco_1m.zip new file mode 100644 index 00000000000..06e06364e79 Binary files /dev/null and b/tests/data/seco/seco_1m.zip differ diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py new file mode 100644 index 00000000000..a95b43ad7f4 --- /dev/null +++ b/tests/datasets/test_seco.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
import glob
import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import SeasonalContrastS2


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    # Offline stand-in for the real downloader: ``url`` is treated as a local
    # file path and copied into ``root``; any extra arguments are ignored.
    shutil.copy(url, root)


class TestSeasonalContrastS2:
    # Parametrized over two (version, bands) pairs:
    # ("100k", ["B1"]) and ("1m", SeasonalContrastS2.ALL_BANDS).
    @pytest.fixture(params=zip(["100k", "1m"], [["B1"], SeasonalContrastS2.ALL_BANDS]))
    def dataset(
        self,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
        request: SubRequest,
    ) -> SeasonalContrastS2:
        # Swap the downloader used by the seco module for the local copy above
        # so that no network access happens during the test.
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.seco, "download_url", download_url
        )
        # Replace the expected checksums (presumably the MD5s of the small
        # archives under tests/data/seco -- confirm if the archives change).
        monkeypatch.setattr(  # type: ignore[attr-defined]
            SeasonalContrastS2,
            "md5s",
            {
                "100k": "4d3e6e4afed7e581b7de1bfa2f7c29da",
                "1m": "3bb3fcf90f5de7d5781ce0cb85fd20af",
            },
        )
        # Point the download "URLs" at the archives checked into the repository.
        monkeypatch.setattr(  # type: ignore[attr-defined]
            SeasonalContrastS2,
            "urls",
            {
                "100k": os.path.join("tests", "data", "seco", "seco_100k.zip"),
                "1m": os.path.join("tests", "data", "seco", "seco_1m.zip"),
            },
        )
        root = str(tmp_path)
        version, bands = request.param
        transforms = nn.Identity()  # type: ignore[attr-defined]
        # Positional order matches the constructor: root, version, bands, transforms.
        return SeasonalContrastS2(
            root, version, bands, transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: SeasonalContrastS2) -> None:
        # A sample is a dict whose "image" entry is a tensor.
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)

    def test_len(self, dataset: SeasonalContrastS2) -> None:
        # Both test archives are expected to yield two entries.
        assert len(dataset) == 2

    def test_add(self, dataset: SeasonalContrastS2) -> None:
        # Adding two datasets is expected to produce a ConcatDataset of both.
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 4

    def test_already_extracted(self, dataset: SeasonalContrastS2) -> None:
        # NOTE(review): this uses the constructor's default version ("100k").
        # For the "1m" parametrization the "100k" directory is not present in
        # ``dataset.root``, so this goes through the patched download path
        # rather than the already-extracted early return -- confirm intent.
        SeasonalContrastS2(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        # Place the archives in the root up front; construction with the default
        # ``download=False`` must then succeed by extracting the existing zip.
        pathname = os.path.join("tests", "data", "seco", "*.zip")
        root = str(tmp_path)
        for zipfile in glob.iglob(pathname):
            shutil.copy(zipfile, root)
        SeasonalContrastS2(root)

    def test_invalid_version(self) -> None:
        # An unknown version string is rejected with an AssertionError.
        with pytest.raises(AssertionError):
            SeasonalContrastS2(version="foo")

    def test_invalid_band(self) -> None:
        # A band name outside ALL_BANDS is rejected with an AssertionError.
        with pytest.raises(AssertionError):
            SeasonalContrastS2(bands=["A1steaksauce"])

    def test_not_downloaded(self, tmp_path: Path) -> None:
        # Empty root and download=False: the dataset cannot be located.
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SeasonalContrastS2(str(tmp_path))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sentinel 2 imagery from the Seasonal Contrast paper."""

import os
from collections import defaultdict
from typing import Callable, Dict, List, Optional, cast

import numpy as np
import rasterio
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import download_url, extract_archive


class SeasonalContrastS2(VisionDataset):
    """Sentinel 2 imagery from the Seasonal Contrast paper.

    The Seasonal Contrast imagery dataset (archives hosted at
    https://zenodo.org/record/4728033) contains Sentinel 2 imagery patches sampled
    from different points in time around the 10k most populated cities on Earth.

    Dataset features:

    * Two versions: 100K and 1M patches
    * 12 band Sentinel 2 imagery from 5 points in time at each location

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/pdf/2103.16607.pdf
    """

    ALL_BANDS = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
        "B8A",
        "B9",
        "B11",
        "B12",
    ]
    RGB_BANDS = ["B4", "B3", "B2"]

    # Per-version download metadata; all four dicts share the same keys.
    urls = {
        # 7.3 GB
        "100k": "https://zenodo.org/record/4728033/files/seco_100k.zip?download=1",
        # 36.3 GB
        "1m": "https://zenodo.org/record/4728033/files/seco_1m.zip?download=1",
    }
    filenames = {
        "100k": "seco_100k.zip",
        "1m": "seco_1m.zip",
    }
    md5s = {
        "100k": "ebf2d5e03adc6e657f9a69a20ad863e0",
        "1m": "187963d852d4d3ce6637743ec3a4bd9e",
    }
    directory_names = {
        "100k": "seasonal_contrast_100k",
        "1m": "seasonal_contrast_1m",
    }

    def __init__(
        self,
        root: str = "data",
        version: str = "100k",
        bands: List[str] = RGB_BANDS,
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SeCo dataset instance.

        Args:
            root: root directory where dataset can be found
            version: one of "100k" or "1m" for the version of the dataset to use
            bands: names of the bands to load, each one of ``ALL_BANDS``
                (defaults to ``RGB_BANDS``)
            transforms: a function/transform that takes an input sample dict and
                returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``version`` or any entry of ``bands`` is invalid
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        # AssertionError is the documented contract here, so asserts are kept.
        assert version in self.urls, f"unknown version: {version!r}"
        for band in bands:
            assert band in self.ALL_BANDS, f"unknown band: {band!r}"

        self.root = root
        # Copy so the instance never aliases the caller's list or the class-level
        # ``RGB_BANDS`` default; later in-place edits then stay local.
        self.bands = list(bands)
        self.url = self.urls[version]
        self.filename = self.filenames[version]
        self.md5 = self.md5s[version]
        self.directory_name = self.directory_names[version]
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        self._verify()

        # TODO: This is slow, I think this should be generated on download and then
        # loaded in the constructor
        self.scene_to_patches = self._index_patches()
        self.scenes = sorted(self.scene_to_patches)

    def _index_patches(self) -> Dict[str, List[str]]:
        """Walk the extracted dataset and group patch directories by scene.

        Returns:
            mapping from scene directory name to the sorted names of its patch
            directories
        """
        index: Dict[str, List[str]] = defaultdict(list)
        dataset_dir = os.path.join(self.root, self.directory_name)
        for current_dir, subdirs, filenames in os.walk(dataset_dir):
            # Only leaf directories that actually hold files are patches
            if subdirs or not filenames:
                continue
            scene_dir, patch_name = os.path.split(current_dir)
            index[os.path.basename(scene_dir)].append(patch_name)

        for patch_names in index.values():
            patch_names.sort()
        return index

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index of the scene to return

        Returns:
            sample with an "image" in TxCxHxW format, where T indexes the same
            patch sampled at different points in time by the SeCo method (5 in
            the published dataset)
        """
        scene_name = self.scenes[index]
        patch_names = self.scene_to_patches[scene_name]

        imagery = [
            self._load_patch(scene_name, patch_name) for patch_name in patch_names
        ]

        sample = {"image": torch.stack(imagery, dim=0)}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset (number of scenes)
        """
        return len(self.scenes)

    def _load_patch(self, scene_name: str, patch_name: str) -> Tensor:
        """Load a single image patch.

        Args:
            scene_name: the name of the scene to load from, e.g. '019999'
            patch_name: the name of the patch to load, e.g.
                '20200713T075609_20200713T081050_T36QZH'

        Returns:
            the image with the subset of bands specified by ``self.bands``,
            stacked along the first dimension
        """
        # Bands stored at a smaller size are upsampled to this edge length so
        # that all requested bands can be stacked into one array.
        target_size = 264
        patch_dir = os.path.join(
            self.root, self.directory_name, scene_name, patch_name
        )

        all_data = []
        for band in self.bands:
            fn = os.path.join(patch_dir, f"{band}.tif")
            with rasterio.open(fn) as f:
                band_data = f.read(1)

            height, width = band_data.shape
            assert height == width
            if height < target_size:
                # TODO: PIL resize is much slower than cv2, we should check to see
                # what could be sped up throughout later. There is also a potential
                # slowdown here from converting to/from a PIL Image just to resize.
                # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f
                pil_image = Image.fromarray(band_data)
                band_data = np.array(
                    pil_image.resize(
                        (target_size, target_size), resample=Image.BILINEAR
                    )
                )
            all_data.append(band_data)

        image = torch.from_numpy(  # type: ignore[attr-defined]
            np.stack(all_data, axis=0)
        )
        return cast(Tensor, image)

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Looks for the extracted directory first, then for the zip archive, and
        only downloads when neither is present and ``download=True``.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the extracted files already exist
        directory_path = os.path.join(self.root, self.directory_name)
        if os.path.exists(directory_path):
            return

        # Check if the zip files have already been downloaded
        zip_path = os.path.join(self.root, self.filename)
        if os.path.exists(zip_path):
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the dataset archive into ``self.root``.

        The MD5 is only passed to the downloader when ``self.checksum`` is set.
        """
        download_url(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )

    def _extract(self) -> None:
        """Extract the dataset archive located in ``self.root``."""
        extract_archive(
            os.path.join(self.root, self.filename),
        )