forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the GlobBiomass dataset (microsoft#395)
* globBiomass Dataset * add tests and testdata * add description and error messages * doc correction * added plot method * orientation plot figure * fix documentation * add compression * camel * gsv fake data and filename glob * 2 channel tensor and requested changes
- Loading branch information
Showing
9 changed files
with
440 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import hashlib | ||
import os | ||
import random | ||
import zipfile | ||
|
||
import numpy as np | ||
import rasterio | ||
|
||
# Seed both RNGs so regenerated test fixtures are byte-for-byte reproducible
# (the printed md5 checksums below stay stable across runs).
np.random.seed(0)
random.seed(0)

# Height/width in pixels of each generated raster.
SIZE = 64


# Filenames (value raster + error raster) bundled into each measurement's zip.
# NOTE(review): "agb"/"gsv" presumably stand for above-ground biomass and
# growing stock volume — confirm against the GlobBiomass dataset docs.
files = {
    "agb": ["N00E020_agb.tif", "N00E020_agb_err.tif"],
    "gsv": ["N00E020_gsv.tif", "N00E020_gsv_err.tif"],
}
|
||
|
||
def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a small random GeoTIFF fixture to *path*.

    Args:
        path: output filename for the GeoTIFF
        dtype: integer numpy dtype string (e.g. ``"int32"``) for the pixels
        num_channels: number of bands to write
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    # Random raster bounded by the dtype's maximum representable value.
    # Fix: honor num_channels (the original always wrote a (1, SIZE, SIZE)
    # array even though profile["count"] was num_channels).
    Z = np.random.randint(
        np.iinfo(profile["dtype"]).max,
        size=(num_channels, SIZE, SIZE),
        dtype=profile["dtype"],
    )
    # Fix: use a context manager so the dataset is closed and data is
    # flushed to disk (the original leaked the open file handle).
    with rasterio.open(path, "w", **profile) as src:
        src.write(Z)
|
||
|
||
if __name__ == "__main__":
    for measurement, file_paths in files.items():
        zipfilename = f"N00E020_{measurement}.zip"
        files_to_zip = []
        for path in file_paths:
            # Remove stale fixtures from a previous run.
            if os.path.exists(path):
                os.remove(path)
            # Create mask file
            create_file(path, dtype="int32", num_channels=1)
            files_to_zip.append(path)

        # Compress data. Renamed loop variables: the originals shadowed
        # the builtins `zip` and `file`.
        with zipfile.ZipFile(zipfilename, "w") as archive:
            for fname in files_to_zip:
                archive.write(fname, arcname=fname)

        # Compute checksums (to be pasted into the dataset's md5s table).
        with open(zipfilename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{zipfilename}: {md5}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
from pathlib import Path | ||
from typing import Generator | ||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import ( | ||
BoundingBox, | ||
GlobBiomass, | ||
IntersectionDataset, | ||
UnionDataset, | ||
) | ||
|
||
|
||
class TestGlobBiomass:
    """Tests for the GlobBiomass dataset backed by small fake zip fixtures."""

    @pytest.fixture
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
    ) -> GlobBiomass:
        """Stage the fixture zips in ``tmp_path`` and construct a dataset."""
        fixture_dir = os.path.join("tests", "data", "globbiomass")
        for zipname in ("N00E020_agb.zip", "N00E020_gsv.zip"):
            shutil.copy(os.path.join(fixture_dir, zipname), tmp_path)

        # Checksums of the fake archives, patched over the real ones so
        # checksum=True validates against the fixtures.
        md5s = {
            "N00E020_agb.zip": "22e11817ede672a2a76b8a5588bc4bf4",
            "N00E020_gsv.zip": "e79bf051ac5d659cb21c566c53ce7b98",
        }
        monkeypatch.setattr(GlobBiomass, "md5s", md5s)  # type: ignore[attr-defined]

        transforms = nn.Identity()  # type: ignore[attr-defined]
        return GlobBiomass(str(tmp_path), transforms=transforms, checksum=True)

    def test_getitem(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_already_extracted(self, dataset: GlobBiomass) -> None:
        # Re-instantiating over already-extracted data must not raise.
        GlobBiomass(root=dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            GlobBiomass(str(tmp_path), checksum=True)

    def test_corrupted(self, tmp_path: Path) -> None:
        # A zip whose contents do not match the expected md5 must be rejected.
        with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            GlobBiomass(root=str(tmp_path), checksum=True)

    def test_and(self, dataset: GlobBiomass) -> None:
        assert isinstance(dataset & dataset, IntersectionDataset)

    def test_or(self, dataset: GlobBiomass) -> None:
        assert isinstance(dataset | dataset, UnionDataset)

    def test_plot(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")

    def test_plot_prediction(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")

    def test_invalid_query(self, dataset: GlobBiomass) -> None:
        # A bounding box outside the index must raise IndexError.
        query = BoundingBox(100, 100, 100, 100, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.