forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AsterGDEM dataset (microsoft#404)
* add astergdem dataset * add astergdem dataset * add plot method * typo * fix docs * requested changes * Update docs/api/datasets.rst Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/astergdem.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * split regex * split regex * split regex * regex Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
863924c
commit 6e974b4
Showing
8 changed files
with
270 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import hashlib | ||
import os | ||
import random | ||
import zipfile | ||
|
||
import numpy as np | ||
import rasterio | ||
|
||
np.random.seed(0) | ||
random.seed(0) | ||
|
||
SIZE = 64 | ||
|
||
files = [ | ||
{"image": "ASTGTMV003_N000000_dem.tif"}, | ||
{"image": "ASTGTMV003_N000010_dem.tif"}, | ||
] | ||
|
||
|
||
def create_file(path: str, dtype: str, num_channels: int) -> None: | ||
profile = {} | ||
profile["driver"] = "GTiff" | ||
profile["dtype"] = dtype | ||
profile["count"] = num_channels | ||
profile["crs"] = "epsg:4326" | ||
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) | ||
profile["height"] = SIZE | ||
profile["width"] = SIZE | ||
profile["compress"] = "lzw" | ||
profile["predictor"] = 2 | ||
|
||
Z = np.random.randint( | ||
np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] | ||
) | ||
src = rasterio.open(path, "w", **profile) | ||
src.write(Z) | ||
|
||
|
||
if __name__ == "__main__": | ||
zipfilename = "astergdem.zip" | ||
files_to_zip = [] | ||
|
||
for file_dict in files: | ||
path = file_dict["image"] | ||
# remove old data | ||
if os.path.exists(path): | ||
os.remove(path) | ||
# Create mask file | ||
create_file(path, dtype="int32", num_channels=1) | ||
files_to_zip.append(path) | ||
|
||
# Compress data | ||
with zipfile.ZipFile(zipfilename, "w") as zip: | ||
for file in files_to_zip: | ||
zip.write(file, arcname=file) | ||
|
||
# Compute checksums | ||
with open(zipfilename, "rb") as f: | ||
md5 = hashlib.md5(f.read()).hexdigest() | ||
print(f"{zipfilename}: {md5}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset | ||
|
||
|
||
class TestAsterGDEM: | ||
@pytest.fixture | ||
def dataset(self, tmp_path: Path) -> AsterGDEM: | ||
zipfile = os.path.join("tests", "data", "astergdem", "astergdem.zip") | ||
shutil.unpack_archive(zipfile, tmp_path, "zip") | ||
root = str(tmp_path) | ||
transforms = nn.Identity() # type: ignore[attr-defined] | ||
return AsterGDEM(root, transforms=transforms) | ||
|
||
def test_datasetmissing(self, tmp_path: Path) -> None: | ||
shutil.rmtree(tmp_path) | ||
os.makedirs(tmp_path) | ||
with pytest.raises(RuntimeError, match="Dataset not found in"): | ||
AsterGDEM(root=str(tmp_path)) | ||
|
||
def test_getitem(self, dataset: AsterGDEM) -> None: | ||
x = dataset[dataset.bounds] | ||
assert isinstance(x, dict) | ||
assert isinstance(x["crs"], CRS) | ||
assert isinstance(x["mask"], torch.Tensor) | ||
|
||
def test_and(self, dataset: AsterGDEM) -> None: | ||
ds = dataset & dataset | ||
assert isinstance(ds, IntersectionDataset) | ||
|
||
def test_or(self, dataset: AsterGDEM) -> None: | ||
ds = dataset | dataset | ||
assert isinstance(ds, UnionDataset) | ||
|
||
def test_plot(self, dataset: AsterGDEM) -> None: | ||
query = dataset.bounds | ||
x = dataset[query] | ||
dataset.plot(x, suptitle="Test") | ||
|
||
def test_plot_prediction(self, dataset: AsterGDEM) -> None: | ||
query = dataset.bounds | ||
x = dataset[query] | ||
x["prediction"] = x["mask"].clone() | ||
dataset.plot(x, suptitle="Prediction") | ||
|
||
def test_invalid_query(self, dataset: AsterGDEM) -> None: | ||
query = BoundingBox(100, 100, 100, 100, 0, 0) | ||
with pytest.raises( | ||
IndexError, match="query: .* not found in index with bounds:" | ||
): | ||
dataset[query] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""Aster Global Digital Evaluation Model dataset.""" | ||
|
||
import glob | ||
import os | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
import matplotlib.pyplot as plt | ||
from rasterio.crs import CRS | ||
from torch import Tensor | ||
|
||
from .geo import RasterDataset | ||
|
||
|
||
class AsterGDEM(RasterDataset): | ||
"""Aster Global Digital Evaluation Model Dataset. | ||
The `Aster Global Digital Evaluation Model | ||
<https://lpdaac.usgs.gov/products/astgtmv003/>`_ | ||
dataset is a Digital Elevation Model (DEM) on a global scale. | ||
The dataset can be downloaded from the | ||
`Earth Data website <https://search.earthdata.nasa.gov/search/>`_ | ||
after making an account. | ||
Dataset features: | ||
* DEMs at 30 m per pixel spatial resolution (3601x3601 px) | ||
* data collected from the `Aster | ||
<https://terra.nasa.gov/about/terra-instruments/aster>`_ instrument | ||
Dataset format: | ||
* DEMs are single-channel tif files | ||
.. versionadded:: 0.3 | ||
""" | ||
|
||
is_image = False | ||
filename_glob = "ASTGTMV003_*_dem*" | ||
filename_regex = r""" | ||
(?P<name>[ASTGTMV003]{10}) | ||
_(?P<id>[A-Z0-9]{7}) | ||
_(?P<data>[a-z]{3})* | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root: str = "data", | ||
crs: Optional[CRS] = None, | ||
res: Optional[float] = None, | ||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, | ||
cache: bool = True, | ||
) -> None: | ||
"""Initialize a new Dataset instance. | ||
Args: | ||
root: root directory where dataset can be found, here the collection of | ||
individual zip files for each tile should be found | ||
crs: :term:`coordinate reference system (CRS)` to warp to | ||
(defaults to the CRS of the first file found) | ||
res: resolution of the dataset in units of CRS | ||
(defaults to the resolution of the first file found) | ||
transforms: a function/transform that takes an input sample | ||
and returns a transformed version | ||
cache: if True, cache file handle to speed up repeated sampling | ||
Raises: | ||
FileNotFoundError: if no files are found in ``root`` | ||
RuntimeError: if dataset is missing | ||
""" | ||
self.root = root | ||
|
||
self._verify() | ||
|
||
super().__init__(root, crs, res, transforms, cache) | ||
|
||
def _verify(self) -> None: | ||
"""Verify the integrity of the dataset. | ||
Raises: | ||
RuntimeError: if dataset is missing | ||
""" | ||
# Check if the extracted files already exists | ||
pathname = os.path.join(self.root, self.filename_glob) | ||
if glob.glob(pathname): | ||
return | ||
|
||
raise RuntimeError( | ||
f"Dataset not found in `root={self.root}` " | ||
"either specify a different `root` directory or make sure you " | ||
"have manually downloaded dataset tiles as suggested in the documentation." | ||
) | ||
|
||
def plot( # type: ignore[override] | ||
self, | ||
sample: Dict[str, Tensor], | ||
show_titles: bool = True, | ||
suptitle: Optional[str] = None, | ||
) -> plt.Figure: | ||
"""Plot a sample from the dataset. | ||
Args: | ||
sample: a sample returned by :meth:`RasterDataset.__getitem__` | ||
show_titles: flag indicating whether to show titles above each panel | ||
suptitle: optional string to use as a suptitle | ||
Returns: | ||
a matplotlib Figure with the rendered sample | ||
""" | ||
mask = sample["mask"].squeeze() | ||
ncols = 1 | ||
|
||
showing_predictions = "prediction" in sample | ||
if showing_predictions: | ||
prediction = sample["prediction"].squeeze() | ||
ncols = 2 | ||
|
||
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) | ||
|
||
if showing_predictions: | ||
axs[0].imshow(mask) | ||
axs[0].axis("off") | ||
axs[1].imshow(prediction) | ||
axs[1].axis("off") | ||
if show_titles: | ||
axs[0].set_title("Mask") | ||
axs[1].set_title("Prediction") | ||
else: | ||
axs.imshow(mask) | ||
axs.axis("off") | ||
if show_titles: | ||
axs.set_title("Mask") | ||
|
||
if suptitle is not None: | ||
plt.suptitle(suptitle) | ||
|
||
return fig |