Adding the GlobBiomass dataset (microsoft#395)
* GlobBiomass dataset

* add tests and test data

* add description and error messages

* doc correction

* added plot method

* fix plot figure orientation

* fix documentation

* add compression

* camel

* gsv fake data and filename glob

* 2 channel tensor and requested changes
nilsleh authored Feb 26, 2022
1 parent 47a8024 commit fb9989c
Showing 9 changed files with 440 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
@@ -57,6 +57,10 @@ Esri2020

.. autoclass:: Esri2020

GlobBiomass
^^^^^^^^^^^

.. autoclass:: GlobBiomass

Landsat
^^^^^^^

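As orientation for readers of this docs entry, here is a minimal usage sketch pieced together from the tests added below; the root path is a placeholder, and only the arguments actually exercised by the test fixture are shown:

from torchgeo.datasets import GlobBiomass

# "data/globbiomass" is a placeholder root that must already contain the
# downloaded tiles (e.g. N00E020_agb.zip); per the tests, the dataset raises
# "Dataset not found" rather than downloading them itself.
ds = GlobBiomass(root="data/globbiomass", checksum=True)

# GeoDatasets are indexed by bounding box; querying the full extent returns
# a sample dict with a CRS and a mask tensor (per the commit message, two
# channels: the measurement and its error map).
sample = ds[ds.bounds]
print(sample["crs"], sample["mask"].shape)

ds.plot(sample, suptitle="GlobBiomass")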
Binary file added tests/data/globbiomass/N00E020_agb.tif
Binary file added tests/data/globbiomass/N00E020_agb.zip
Binary file added tests/data/globbiomass/N00E020_agb_err.tif
Binary file added tests/data/globbiomass/N00E020_gsv.zip
66 changes: 66 additions & 0 deletions tests/data/globbiomass/data.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random
import zipfile

import numpy as np
import rasterio

np.random.seed(0)
random.seed(0)

SIZE = 64


files = {
    "agb": ["N00E020_agb.tif", "N00E020_agb_err.tif"],
    "gsv": ["N00E020_gsv.tif", "N00E020_gsv_err.tif"],
}


def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a GeoTIFF of random integers to use as a test fixture."""
    profile = {}
    profile["driver"] = "GTiff"
    profile["dtype"] = dtype
    profile["count"] = num_channels
    profile["crs"] = "epsg:4326"
    profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
    profile["height"] = SIZE
    profile["width"] = SIZE
    profile["compress"] = "lzw"
    profile["predictor"] = 2

    Z = np.random.randint(
        np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"]
    )
    # Use a context manager so the file is flushed and closed on exit
    with rasterio.open(path, "w", **profile) as src:
        src.write(Z)


if __name__ == "__main__":
    for measurement, file_paths in files.items():
        zipfilename = "N00E020_{}.zip".format(measurement)
        files_to_zip = []
        for path in file_paths:
            # Remove old data
            if os.path.exists(path):
                os.remove(path)
            # Create mask file
            create_file(path, dtype="int32", num_channels=1)
            files_to_zip.append(path)

        # Compress data
        with zipfile.ZipFile(zipfilename, "w") as zf:
            for file in files_to_zip:
                zf.write(file, arcname=file)

        # Compute checksums
        with open(zipfilename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{zipfilename}: {md5}")
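A quick, hypothetical cross-check of the printed checksums against the values pinned in the test fixture below; this only holds if data.py is re-run with the same seeds and the same rasterio/GDAL versions, since compressed output is not guaranteed byte-identical across versions:

import hashlib

def md5sum(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

# Expected values copied from tests/datasets/test_globbiomass.py below.
expected = {
    "N00E020_agb.zip": "22e11817ede672a2a76b8a5588bc4bf4",
    "N00E020_gsv.zip": "e79bf051ac5d659cb21c566c53ce7b98",
}
for name, md5 in expected.items():
    assert md5sum(name) == md5, f"stale fixture: {name}"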
88 changes: 88 additions & 0 deletions tests/datasets/test_globbiomass.py
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
    BoundingBox,
    GlobBiomass,
    IntersectionDataset,
    UnionDataset,
)


class TestGlobBiomass:
    @pytest.fixture
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
    ) -> GlobBiomass:
        shutil.copy(
            os.path.join("tests", "data", "globbiomass", "N00E020_agb.zip"), tmp_path
        )
        shutil.copy(
            os.path.join("tests", "data", "globbiomass", "N00E020_gsv.zip"), tmp_path
        )

        md5s = {
            "N00E020_agb.zip": "22e11817ede672a2a76b8a5588bc4bf4",
            "N00E020_gsv.zip": "e79bf051ac5d659cb21c566c53ce7b98",
        }

        monkeypatch.setattr(GlobBiomass, "md5s", md5s)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return GlobBiomass(root, transforms=transforms, checksum=True)

    def test_getitem(self, dataset: GlobBiomass) -> None:
        x = dataset[dataset.bounds]
        assert isinstance(x, dict)
        assert isinstance(x["crs"], CRS)
        assert isinstance(x["mask"], torch.Tensor)

    def test_already_extracted(self, dataset: GlobBiomass) -> None:
        GlobBiomass(root=dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            GlobBiomass(str(tmp_path), checksum=True)

    def test_corrupted(self, tmp_path: Path) -> None:
        with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            GlobBiomass(root=str(tmp_path), checksum=True)

    def test_and(self, dataset: GlobBiomass) -> None:
        ds = dataset & dataset
        assert isinstance(ds, IntersectionDataset)

    def test_or(self, dataset: GlobBiomass) -> None:
        ds = dataset | dataset
        assert isinstance(ds, UnionDataset)

    def test_plot(self, dataset: GlobBiomass) -> None:
        query = dataset.bounds
        x = dataset[query]
        dataset.plot(x, suptitle="Test")

    def test_plot_prediction(self, dataset: GlobBiomass) -> None:
        query = dataset.bounds
        x = dataset[query]
        x["prediction"] = x["mask"].clone()
        dataset.plot(x, suptitle="Prediction")

    def test_invalid_query(self, dataset: GlobBiomass) -> None:
        query = BoundingBox(100, 100, 100, 100, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]
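The test_and and test_or cases above rely on torchgeo's geospatial set operations; the same pattern outside the test harness looks roughly like this (the root is a placeholder, and the dataset is combined with itself purely to show the operators, as in the tests):

from torchgeo.datasets import GlobBiomass, IntersectionDataset, UnionDataset

agb = GlobBiomass(root="data/globbiomass")  # hypothetical local copy

intersection = agb & agb  # IntersectionDataset: samples where both overlap
union = agb | agb         # UnionDataset: samples from either dataset

assert isinstance(intersection, IntersectionDataset)
assert isinstance(union, UnionDataset)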
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -42,6 +42,7 @@
VisionDataset,
)
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS
from .inria import InriaAerialImageLabeling
from .landcoverai import LandCoverAI
@@ -102,6 +103,7 @@
"ChesapeakeCVPR",
"CMSGlobalMangroveCanopy",
"Esri2020",
"GlobBiomass",
"Landsat",
"Landsat1",
"Landsat2",
