diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 9d5da998add..abc8465c625 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -222,6 +222,11 @@ UC Merced
 
 .. autoclass:: UCMerced
 
+USAVars
+^^^^^^^
+
+.. autoclass:: USAVars
+
 Vaihingen
 ^^^^^^^^^
 
diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py
new file mode 100644
index 00000000000..c392887b04e
--- /dev/null
+++ b/tests/data/usavars/data.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import glob
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import pandas as pd
+import rasterio
+
+data_dir = "uar"
+labels = [
+    "elevation",
+    "population",
+    "treecover",
+    "income",
+    "nightlights",
+    "housing",
+    "roads",
+]
+SIZE = 3
+
+
+def create_file(path: str, dtype: str, num_channels: int) -> None:
+    profile = {}
+    profile["driver"] = "GTiff"
+    profile["dtype"] = dtype
+    profile["count"] = num_channels
+    profile["crs"] = "epsg:4326"
+    profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
+    profile["height"] = SIZE
+    profile["width"] = SIZE
+    profile["compress"] = "lzw"
+    profile["predictor"] = 2
+
+    Z = np.random.randint(
+        np.iinfo(profile["dtype"]).max,
+        size=(num_channels, SIZE, SIZE),
+        dtype=profile["dtype"],
+    )
+    # Use a context manager so the file is flushed and closed properly
+    with rasterio.open(path, "w", **profile) as src:
+        src.write(Z)
+
+
+# Remove old data
+filename = f"{data_dir}.zip"
+csvs = glob.glob("*.csv")
+
+for csv in csvs:
+    os.remove(csv)
+if os.path.exists(filename):
+    os.remove(filename)
+if os.path.exists(data_dir):
+    shutil.rmtree(data_dir)
+
+# Create tifs (dtype is passed as a string to match the annotation):
+os.makedirs(data_dir)
+create_file(os.path.join(data_dir, "tile_0,0.tif"), "uint8", 4)
+create_file(os.path.join(data_dir, "tile_0,1.tif"), "uint8", 4)
+
+# Create labels:
+columns = [["ID", "lon", "lat", lab] for lab in labels]
+fake_vals = [["0,0", 0.0, 0.0, 0.0], ["0,1", 0.1, 0.1, 1.0]]
+for lab, cols in zip(labels, columns):
+    df = pd.DataFrame(fake_vals, columns=cols)
+    df.to_csv(lab + ".csv")
+
+# Compress data
+shutil.make_archive(data_dir, "zip", ".", data_dir)
+
+# Compute checksums
+filename = f"{data_dir}.zip"
+with open(filename, "rb") as f:
+    md5 = hashlib.md5(f.read()).hexdigest()
+    print(repr(filename) + ": " + repr(md5) + ",")
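+
+# Optional sanity check -- a minimal sketch, not required for data generation:
+# read one tile back and confirm its shape is (num_channels, SIZE, SIZE).
+with rasterio.open(os.path.join(data_dir, "tile_0,0.tif")) as src:
+    assert src.read().shape == (4, SIZE, SIZE)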
diff --git a/tests/data/usavars/elevation.csv b/tests/data/usavars/elevation.csv
new file mode 100644
index 00000000000..9e10a3141e9
--- /dev/null
+++ b/tests/data/usavars/elevation.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,elevation
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/housing.csv b/tests/data/usavars/housing.csv
new file mode 100644
index 00000000000..a99a21131ac
--- /dev/null
+++ b/tests/data/usavars/housing.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,housing
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/income.csv b/tests/data/usavars/income.csv
new file mode 100644
index 00000000000..3ac6b3d2f05
--- /dev/null
+++ b/tests/data/usavars/income.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,income
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/nightlights.csv b/tests/data/usavars/nightlights.csv
new file mode 100644
index 00000000000..9a55b7a8838
--- /dev/null
+++ b/tests/data/usavars/nightlights.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,nightlights
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/population.csv b/tests/data/usavars/population.csv
new file mode 100644
index 00000000000..5c53019d3fb
--- /dev/null
+++ b/tests/data/usavars/population.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,population
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/roads.csv b/tests/data/usavars/roads.csv
new file mode 100644
index 00000000000..352db63ec1b
--- /dev/null
+++ b/tests/data/usavars/roads.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,roads
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/treecover.csv b/tests/data/usavars/treecover.csv
new file mode 100644
index 00000000000..eb3d5030c81
--- /dev/null
+++ b/tests/data/usavars/treecover.csv
@@ -0,0 +1,3 @@
+,ID,lon,lat,treecover
+0,"0,0",0.0,0.0,0.0
+1,"0,1",0.1,0.1,1.0
diff --git a/tests/data/usavars/uar.zip b/tests/data/usavars/uar.zip
new file mode 100644
index 00000000000..a6e122f80ab
Binary files /dev/null and b/tests/data/usavars/uar.zip differ
diff --git a/tests/data/usavars/uar/tile_0,0.tif b/tests/data/usavars/uar/tile_0,0.tif
new file mode 100644
index 00000000000..372474660d3
Binary files /dev/null and b/tests/data/usavars/uar/tile_0,0.tif differ
diff --git a/tests/data/usavars/uar/tile_0,1.tif b/tests/data/usavars/uar/tile_0,1.tif
new file mode 100644
index 00000000000..c9703a26850
Binary files /dev/null and b/tests/data/usavars/uar/tile_0,1.tif differ
diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py
new file mode 100644
index 00000000000..5b203d5937d
--- /dev/null
+++ b/tests/datasets/test_usavars.py
@@ -0,0 +1,135 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import builtins
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Generator
+
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from _pytest.monkeypatch import MonkeyPatch
+from matplotlib import pyplot as plt
+from torch.utils.data import ConcatDataset
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import USAVars
+
+pytest.importorskip("pandas", minversion="0.19.1")
+
+
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+    shutil.copy(url, root)
+
+
+class TestUSAVars:
+    @pytest.fixture()
+    def dataset(
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        tmp_path: Path,
+        request: SubRequest,
+    ) -> USAVars:
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            torchgeo.datasets.usavars, "download_url", download_url
+        )
+
+        md5 = "b504580a00bdc27097d5421dec50481b"
+        monkeypatch.setattr(USAVars, "md5", md5)  # type: ignore[attr-defined]
+
+        data_url = os.path.join("tests", "data", "usavars", "uar.zip")
+        monkeypatch.setattr(USAVars, "data_url", data_url)  # type: ignore[attr-defined]
+
+        label_urls = {
+            "elevation": os.path.join("tests", "data", "usavars", "elevation.csv"),
+            "population": os.path.join("tests", "data", "usavars", "population.csv"),
+            "treecover": os.path.join("tests", "data", "usavars", "treecover.csv"),
+            "income": os.path.join("tests", "data", "usavars", "income.csv"),
+            "nightlights": os.path.join("tests", "data", "usavars", "nightlights.csv"),
+            "roads": os.path.join("tests", "data", "usavars", "roads.csv"),
+            "housing": os.path.join("tests", "data", "usavars", "housing.csv"),
+        }
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            USAVars, "label_urls", label_urls
+        )
+
+        root = str(tmp_path)
+        transforms = nn.Identity()  # type: ignore[attr-defined]
+
+        return USAVars(root, transforms=transforms, download=True, checksum=True)
+
+    def test_getitem(self, dataset: USAVars) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x["image"], torch.Tensor)
+        assert x["image"].ndim == 3
+        assert len(x.keys()) == 2  # image, labels
+        assert x["image"].shape[0] == 4  # R, G, B, NIR
+
+    def test_len(self, dataset: USAVars) -> None:
+        assert len(dataset) == 2
+
+    def test_add(self, dataset: USAVars) -> None:
+        ds = dataset + dataset
+        assert isinstance(ds, ConcatDataset)
+
+    def test_already_extracted(self, dataset: USAVars) -> None:
+        USAVars(root=dataset.root, download=True)
+
+    def test_already_downloaded(self, tmp_path: Path) -> None:
+        pathname = os.path.join("tests", "data", "usavars", "uar.zip")
+        root = str(tmp_path)
+        shutil.copy(pathname, root)
+        csvs = [
+            "elevation.csv",
+            "population.csv",
+            "treecover.csv",
+            "income.csv",
+            "nightlights.csv",
+            "roads.csv",
+            "housing.csv",
+        ]
+        for csv in csvs:
+            shutil.copy(os.path.join("tests", "data", "usavars", csv), root)
+
+        USAVars(root)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found"):
+            USAVars(str(tmp_path))
+
+    @pytest.fixture(params=["pandas"])
+    def mock_missing_module(
+        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
+    ) -> str:
+        import_orig = builtins.__import__
+        package = str(request.param)
+
+        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
+            if name == package:
+                raise ImportError()
+            return import_orig(name, *args, **kwargs)
+
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            builtins, "__import__", mocked_import
+        )
+        return package
+
+    def test_mock_missing_module(
+        self, dataset: USAVars, mock_missing_module: str
+    ) -> None:
+        package = mock_missing_module
+        if package == "pandas":
+            with pytest.raises(
+                ImportError,
+                match=f"{package} is not installed and is required to use this dataset",
+            ):
+                USAVars(dataset.root)
+
+    def test_plot(self, dataset: USAVars) -> None:
+        dataset.plot(dataset[0], suptitle="Test")
+        plt.close()
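+
+    def test_invalid_label(self, dataset: USAVars) -> None:
+        # A sketch of an extra guard-rail check: __init__ asserts that each
+        # requested label is in ALL_LABELS, so an unknown label name should
+        # raise an AssertionError.
+        with pytest.raises(AssertionError):
+            USAVars(dataset.root, labels=["not_a_label"])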
x["image"].ndim == 3 + assert len(x.keys()) == 2 # image, elevation, population, treecover + assert x["image"].shape[0] == 4 # R, G, B, Inf + + def test_len(self, dataset: USAVars) -> None: + assert len(dataset) == 2 + + def test_add(self, dataset: USAVars) -> None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + + def test_already_extracted(self, dataset: USAVars) -> None: + USAVars(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + pathname = os.path.join("tests", "data", "usavars", "uar.zip") + root = str(tmp_path) + shutil.copy(pathname, root) + csvs = [ + "elevation.csv", + "population.csv", + "treecover.csv", + "income.csv", + "nightlights.csv", + "roads.csv", + "housing.csv", + ] + for csv in csvs: + shutil.copy(os.path.join("tests", "data", "usavars", csv), root) + + USAVars(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + USAVars(str(tmp_path)) + + @pytest.fixture(params=["pandas"]) + def mock_missing_module( + self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest + ) -> str: + import_orig = builtins.__import__ + package = str(request.param) + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == package: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr( # type: ignore[attr-defined] + builtins, "__import__", mocked_import + ) + return package + + def test_mock_missing_module( + self, dataset: USAVars, mock_missing_module: str + ) -> None: + package = mock_missing_module + if package == "pandas": + with pytest.raises( + ImportError, + match=f"{package} is not installed and is required to use this dataset", + ): + USAVars(dataset.root) + + def test_plot(self, dataset: USAVars) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 548f7dc0d0b..d2ac09a16b4 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -72,6 +72,7 @@ from .so2sat import So2Sat from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 from .ucmerced import UCMerced +from .usavars import USAVars from .utils import ( BoundingBox, concat_samples, @@ -149,6 +150,7 @@ "SpaceNet7", "TropicalCycloneWindEstimation", "UCMerced", + "USAVars", "Vaihingen2D", "VHR10", "XView2", diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py new file mode 100644 index 00000000000..9d8118efddb --- /dev/null +++ b/torchgeo/datasets/usavars.py @@ -0,0 +1,257 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""USAVars dataset.""" + +import glob +import os +from typing import Callable, Dict, List, Optional, Sequence + +import matplotlib.pyplot as plt +import numpy as np +import rasterio +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .geo import VisionDataset +from .utils import download_url, extract_archive + + +class USAVars(VisionDataset): + """USAVars dataset. + + The USAVars dataset is reproduction of the dataset used in the paper "`A + generalizable and accessible approach to machine learning with global satellite + imagery <https://doi.org/10.1038/s41467-021-24638-z>`_". Specifically, this dataset + includes 1 sq km. crops of NAIP imagery resampled to 4m/px cenetered on ~100k points + that are sampled randomly from the contiguous states in the USA. 
+
+    .. versionadded:: 0.3
+    """
+
+    url_prefix = (
+        "https://files.codeocean.com/files/verified/"
+        + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications"
+    )
+    pop_csv_suffix = "CONTUS_16_640_POP_100000_0.csv?download"
+    uar_csv_suffix = "CONTUS_16_640_UAR_100000_0.csv?download"
+
+    data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip"
+    dirname = "uar"
+
+    md5 = "677e89fd20e5dd0fe4d29b61827c2456"
+
+    label_urls = {
+        "housing": f"{url_prefix}/housing/outcomes_sampled_housing_{pop_csv_suffix}",
+        "income": f"{url_prefix}/income/outcomes_sampled_income_{pop_csv_suffix}",
+        "roads": f"{url_prefix}/roads/outcomes_sampled_roads_{pop_csv_suffix}",
+        "nightlights": f"{url_prefix}/nightlights/"
+        + f"outcomes_sampled_nightlights_{pop_csv_suffix}",
+        "population": f"{url_prefix}/population/"
+        + f"outcomes_sampled_population_{uar_csv_suffix}",
+        "elevation": f"{url_prefix}/elevation/"
+        + f"outcomes_sampled_elevation_{uar_csv_suffix}",
+        "treecover": f"{url_prefix}/treecover/"
+        + f"outcomes_sampled_treecover_{uar_csv_suffix}",
+    }
+
+    ALL_LABELS = list(label_urls.keys())
+
+    def __init__(
+        self,
+        root: str = "data",
+        labels: Sequence[str] = ALL_LABELS,
+        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new USAVars dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            labels: list of labels to include
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: if invalid labels are provided
+            ImportError: if pandas is not installed
+            RuntimeError: if ``download=False`` and data is not found, or checksums
+                don't match
+        """
+        self.root = root
+
+        for lab in labels:
+            assert lab in self.ALL_LABELS
+
+        self.labels = labels
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+
+        self._verify()
+
+        try:
+            import pandas as pd  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "pandas is not installed and is required to use this dataset"
+            )
+
+        self.files = self._load_files()
+
+        self.label_dfs = {
+            lab: pd.read_csv(os.path.join(self.root, lab + ".csv"), index_col="ID")
+            for lab in self.labels
+        }
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        tif_file = self.files[index]
+        # strip the "tile_" prefix and ".tif" suffix to recover the point ID
+        id_ = tif_file[5:-4]
+
+        sample = {
+            "labels": Tensor(
+                [self.label_dfs[lab].loc[id_][lab] for lab in self.labels]
+            ),
+            "image": self._load_image(os.path.join(self.root, "uar", tif_file)),
+        }
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
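+    # Note: the values in sample["labels"] follow the order of the ``labels``
+    # sequence passed to __init__, e.g. labels=["treecover", "elevation"]
+    # yields a tensor whose first entry is tree cover.
+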
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.files)
+
+    def _load_files(self) -> List[str]:
+        """Load file names."""
+        file_path = os.path.join(self.root, "uar")
+        # sort for a deterministic ordering (os.listdir order is arbitrary)
+        files = sorted(os.listdir(file_path))
+        return files
+
+    def _load_image(self, path: str) -> Tensor:
+        """Load a single image.
+
+        Args:
+            path: path to the image
+
+        Returns:
+            the image
+        """
+        with rasterio.open(path) as f:
+            array: "np.typing.NDArray[np.int_]" = f.read()
+            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
+            return tensor
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset.
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+        """
+        # Check if the extracted files already exist
+        pathname = os.path.join(self.root, "uar")
+        csv_pathname = os.path.join(self.root, "*.csv")
+
+        if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
+            return
+
+        # Check if the zip file has already been downloaded
+        pathname = os.path.join(self.root, self.dirname + ".zip")
+        if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
+            self._extract()
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise RuntimeError(
+                f"Dataset not found in `root={self.root}` and `download=False`, "
+                "either specify a different `root` directory or use `download=True` "
+                "to automatically download the dataset."
+            )
+
+        self._download()
+        self._extract()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        for f_name in self.label_urls:
+            download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv")
+
+        download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None)
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        extract_archive(os.path.join(self.root, self.dirname + ".zip"))
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_labels: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_labels: flag indicating whether to show labels above panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        image = sample["image"][:3].numpy()  # get RGB indices
+        image = np.moveaxis(image, 0, 2)
+
+        fig, axs = plt.subplots(figsize=(10, 10))
+        axs.imshow(image)
+        axs.axis("off")
+
+        if show_labels:
+            # pair each value in sample["labels"] with its label name so the
+            # title shows every requested label, not just the first value
+            label_string = " ".join(
+                f"{lab}={round(val.item(), 2)}"
+                for lab, val in zip(self.labels, sample["labels"])
+            )
+            axs.set_title(label_string)
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig