Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Sep 29, 2023
1 parent af5f89d commit 6017797
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 39 deletions.
101 changes: 101 additions & 0 deletions tests/data/rwanda_field_boundary/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio

dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")
all_bands = ("B01", "B02", "B03", "B04")

SIZE = 32
NUM_SAMPLES = 5
np.random.seed(0)


def create_mask(fn: str) -> None:
profile = {
"driver": "GTiff",
"dtype": "uint8",
"nodata": 0.0,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": "epsg:3857",
"compress": "lzw",
"predictor": 2,
"transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
"blockysize": 32,
"tiled": False,
"interleave": "band",
}
with rasterio.open(fn, "w", **profile) as f:
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)


def create_img(fn: str) -> None:
profile = {
"driver": "GTiff",
"dtype": "uint16",
"nodata": 0.0,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": "epsg:3857",
"compress": "lzw",
"predictor": 2,
"blockysize": 16,
"transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
"tiled": False,
"interleave": "band",
}
with rasterio.open(fn, "w", **profile) as f:
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)


if __name__ == "__main__":
# Train and test images
for split in ("train", "test"):
for i in range(NUM_SAMPLES):
for date in dates:
directory = os.path.join(
f"nasa_rwanda_field_boundary_competition_source_{split}",
f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}",
)
os.makedirs(directory, exist_ok=True)
for band in all_bands:
create_img(os.path.join(directory, f"{band}.tif"))

# Create collections.json, this isn't used by the dataset but is checked to
# exist
with open(
f"nasa_rwanda_field_boundary_competition_source_{split}/collections.json",
"w",
) as f:
f.write("Not used")

# Train labels
for i in range(NUM_SAMPLES):
directory = os.path.join(
"nasa_rwanda_field_boundary_competition_labels_train",
f"nasa_rwanda_field_boundary_competition_labels_train_{i:02d}",
)
os.makedirs(directory, exist_ok=True)
create_mask(os.path.join(directory, "raster_labels.tif"))

# Create directories and compute checksums
for filename in [
"nasa_rwanda_field_boundary_competition_source_train",
"nasa_rwanda_field_boundary_competition_source_test",
"nasa_rwanda_field_boundary_competition_labels_train",
]:
shutil.make_archive(filename, "gztar", ".", filename)
# Compute checksums
with open(f"{filename}.tar.gz", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{filename}: {md5}")
Binary file not shown.
Binary file not shown.
Binary file not shown.
140 changes: 140 additions & 0 deletions tests/datasets/test_rwanda_field_boundary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

from torchgeo.datasets import RwandaFieldBoundary


class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join("tests", "data", "rwanda_field_boundary", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestRwandaFieldBoundary:
@pytest.fixture(params=["train", "test"])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> RwandaFieldBoundary:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
monkeypatch.setattr(
RwandaFieldBoundary, "number_of_patches_per_split", {"train": 5, "test": 5}
)
monkeypatch.setattr(
RwandaFieldBoundary,
"md5s",
{
"train_images": "af9395e2e49deefebb35fa65fa378ba3",
"test_images": "d104bb82323a39e7c3b3b7dd0156f550",
"train_labels": "6cceaf16a141cf73179253a783e7d51b",
},
)

root = str(tmp_path)
split = request.param
transforms = nn.Identity()
return RwandaFieldBoundary(
root, split, transforms=transforms, api_key="", download=True, checksum=True
)

def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
if dataset.split == "train":
assert isinstance(x["mask"], torch.Tensor)
else:
assert "mask" not in x

def test_len(self, dataset: RwandaFieldBoundary) -> None:
assert len(dataset) == 5

def test_add(self, dataset: RwandaFieldBoundary) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 10

def test_needs_extraction(self, tmp_path: Path) -> None:
root = str(tmp_path)
for fn in [
"nasa_rwanda_field_boundary_competition_source_train.tar.gz",
"nasa_rwanda_field_boundary_competition_source_test.tar.gz",
"nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
]:
url = os.path.join("tests", "data", "rwanda_field_boundary", fn)
shutil.copy(url, root)
RwandaFieldBoundary(root, checksum=False)

def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
RwandaFieldBoundary(root=dataset.root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
RwandaFieldBoundary(str(tmp_path))

def test_corrupted(self, tmp_path: Path) -> None:
for fn in [
"nasa_rwanda_field_boundary_competition_source_train.tar.gz",
"nasa_rwanda_field_boundary_competition_source_test.tar.gz",
"nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
]:
with open(os.path.join(tmp_path, fn), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
RwandaFieldBoundary(root=str(tmp_path), checksum=True)

def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
monkeypatch.setattr(
RwandaFieldBoundary,
"md5s",
{"train_images": "bad", "test_images": "bad", "train_labels": "bad"},
)
root = str(tmp_path)
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
RwandaFieldBoundary(root, "train", api_key="", download=True, checksum=True)

def test_no_api_key(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Must provide an API key to download"):
RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)

def test_invalid_bands(self) -> None:
with pytest.raises(ValueError, match="is an invalid band name."):
RwandaFieldBoundary(bands=("foo", "bar"))

def test_plot(self, dataset: RwandaFieldBoundary) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()

if dataset.split == "train":
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()

def test_failed_plot(self, dataset: RwandaFieldBoundary) -> None:
single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=("B01",))
with pytest.raises(ValueError, match="Dataset doesn't contain"):
x = single_band_dataset[0].copy()
single_band_dataset.plot(x, suptitle="Test")
54 changes: 15 additions & 39 deletions torchgeo/datasets/rwanda_field_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,17 @@ class RwandaFieldBoundary(NonGeoDataset):
]
number_of_patches_per_split = {"train": 57, "test": 13}

image_meta = {
"filename": "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
"md5": "1f9ec08038218e67e11f82a86849b333",
filenames = {
"train_images": "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
"test_images": "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
"train_labels": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
}
target_meta = {
"filename": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
"md5": "10e4eb761523c57b6d3bdf9394004f5f",
}
image_test_meta = {
"filename": "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
"md5": "17bb0e56eedde2e7a43c57aa908dc125",
md5s = {
"train_images": "1f9ec08038218e67e11f82a86849b333",
"test_images": "17bb0e56eedde2e7a43c57aa908dc125",
"train_labels": "10e4eb761523c57b6d3bdf9394004f5f",
}

dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")

all_bands = ("B01", "B02", "B03", "B04")
Expand Down Expand Up @@ -223,10 +222,10 @@ def _verify(self) -> None:

# Check if tar file already exists (if so then extract)
have_all_files = True
for group in [self.image_meta, self.target_meta, self.image_test_meta]:
filepath = os.path.join(self.root, group["filename"])
for group in ["train_images", "train_labels", "test_images"]:
filepath = os.path.join(self.root, self.filenames[group])
if os.path.exists(filepath):
if self.checksum and not check_integrity(filepath, group["md5"]):
if self.checksum and not check_integrity(filepath, self.md5s[group]):
raise RuntimeError("Dataset found, but corrupted.")
extract_archive(filepath)
else:
Expand All @@ -245,29 +244,6 @@ def _verify(self) -> None:
# Download and extract the dataset
self._download()

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Returns:
True if dataset files are found and/or MD5s match, else False
"""
images: bool = check_integrity(
os.path.join(self.root, self.image_meta["filename"]),
self.image_meta["md5"] if self.checksum else None,
)

targets: bool = check_integrity(
os.path.join(self.root, self.target_meta["filename"]),
self.target_meta["md5"] if self.checksum else None,
)

test_images: bool = check_integrity(
os.path.join(self.root, self.image_test_meta["filename"]),
self.image_test_meta["md5"] if self.checksum else None,
)

return images and targets and test_images

def _download(self) -> None:
"""Download the dataset and extract it.
Expand All @@ -277,9 +253,9 @@ def _download(self) -> None:
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)

for group in [self.image_meta, self.target_meta, self.image_test_meta]:
filepath = os.path.join(self.root, group["filename"])
if self.checksum and not check_integrity(filepath, group["md5"]):
for group in ["train_images", "train_labels", "test_images"]:
filepath = os.path.join(self.root, self.filenames[group])
if self.checksum and not check_integrity(filepath, self.md5s[group]):
raise RuntimeError("Dataset not found or corrupted.")
extract_archive(filepath, self.root)

Expand Down

0 comments on commit 6017797

Please sign in to comment.