Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the GlobBiomass dataset #395

Merged
merged 11 commits into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ Esri2020

.. autoclass:: Esri2020

GlobBiomass
^^^^^^^^^^^
.. autoclass:: GlobBiomass

Landsat
^^^^^^^

Expand Down
Binary file added tests/data/globbiomass/N00E020_agb.tif
Binary file not shown.
Binary file added tests/data/globbiomass/N00E020_agb.zip
Binary file not shown.
Binary file added tests/data/globbiomass/N00E020_agb_err.tif
Binary file not shown.
63 changes: 63 additions & 0 deletions tests/data/globbiomass/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random
import zipfile

import numpy as np
import rasterio

# Seed both RNGs so regenerated fixtures (and therefore their md5 checksums)
# are reproducible across runs.
np.random.seed(0)
random.seed(0)

# Height and width, in pixels, of each generated test raster.
SIZE = 64


# Rasters to generate: the biomass map and its error (uncertainty) map.
files = [{"image": "N00E020_agb.tif"}, {"image": "N00E020_agb_err.tif"}]


def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a small random GeoTIFF test fixture to *path*.

    Args:
        path: output file path for the raster
        dtype: integer dtype string understood by numpy/rasterio, e.g. "int32"
        num_channels: band count written into the GTiff profile

    Note:
        The pixel data written is always shaped ``(1, SIZE, SIZE)`` regardless
        of ``num_channels`` — callers in this script only ever pass 1.
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    # Random non-negative integers spanning the full positive range of dtype.
    Z = np.random.randint(
        np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"]
    )
    # Context manager guarantees the dataset is flushed and closed
    # (the previous version leaked the open handle).
    with rasterio.open(path, "w", **profile) as src:
        src.write(Z)


if __name__ == "__main__":
    zipfilename = "N00E020_agb.zip"
    files_to_zip = []

    for file_dict in files:
        path = file_dict["image"]
        # Remove any stale fixture left over from a previous run.
        if os.path.exists(path):
            os.remove(path)
        # Create the single-band mask file.
        create_file(path, dtype="int32", num_channels=1)
        files_to_zip.append(path)

    # Compress the rasters into the archive layout the dataset expects.
    # ("zf"/"fname" avoid shadowing the builtins ``zip`` and ``file``.)
    with zipfile.ZipFile(zipfilename, "w") as zf:
        for fname in files_to_zip:
            zf.write(fname, arcname=fname)

    # Print the archive's md5 so it can be pasted into the test's ``md5s``
    # dict; this is a fixture-integrity check, not a security measure.
    with open(zipfilename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print(f"{zipfilename}: {md5}")
83 changes: 83 additions & 0 deletions tests/datasets/test_globbiomass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GlobBiomass,
IntersectionDataset,
UnionDataset,
)


class TestGlobBiomass:
    """Unit tests for the GlobBiomass raster dataset."""

    @pytest.fixture
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
    ) -> GlobBiomass:
        """Stage the fixture zip in a temp dir and build a checksummed dataset."""
        zip_name = "N00E020_agb.zip"
        shutil.copy(os.path.join("tests", "data", "globbiomass", zip_name), tmp_path)
        # Point the class at the test fixture's checksum instead of the real one.
        monkeypatch.setattr(  # type: ignore[attr-defined]
            GlobBiomass, "md5s", {zip_name: "7b7b981149aa31a099f453fef32b644f"}
        )
        return GlobBiomass(
            str(tmp_path),
            transforms=nn.Identity(),  # type: ignore[attr-defined]
            checksum=True,
        )

    def test_getitem(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)
        assert isinstance(sample["error_mask"], torch.Tensor)

    def test_already_extracted(self, dataset: GlobBiomass) -> None:
        # Instantiating again over already-extracted data must succeed.
        GlobBiomass(root=dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            GlobBiomass(str(tmp_path), checksum=True)

    def test_corrupted(self, tmp_path: Path) -> None:
        bad_zip = os.path.join(tmp_path, "N00E020_agb.zip")
        with open(bad_zip, "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            GlobBiomass(root=str(tmp_path), checksum=True)

    def test_and(self, dataset: GlobBiomass) -> None:
        assert isinstance(dataset & dataset, IntersectionDataset)

    def test_or(self, dataset: GlobBiomass) -> None:
        assert isinstance(dataset | dataset, UnionDataset)

    def test_plot(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")

    def test_plot_prediction(self, dataset: GlobBiomass) -> None:
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")

    def test_invalid_query(self, dataset: GlobBiomass) -> None:
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[BoundingBox(100, 100, 100, 100, 0, 0)]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
VisionDataset,
)
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS
from .inria import InriaAerialImageLabeling
from .landcoverai import LandCoverAI
Expand Down Expand Up @@ -98,6 +99,7 @@
"ChesapeakeWV",
"ChesapeakeCVPR",
"Esri2020",
"GlobBiomass",
"Landsat",
"Landsat1",
"Landsat2",
Expand Down
Loading