Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ETCI2021 Dataset #119

Merged
merged 12 commits into from
Sep 12, 2021
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ CV4A Kenya Crop Type Competition

.. autoclass:: CV4AKenyaCropType

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ETCI2021

LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Binary file added tests/data/etci2021/test_without_ref_labels.zip
Binary file not shown.
Binary file added tests/data/etci2021/train.zip
Binary file not shown.
Binary file added tests/data/etci2021/val_with_ref_labels.zip
Binary file not shown.
86 changes: 86 additions & 0 deletions tests/datasets/test_etci2021.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.transforms import Identity


def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)


class TestETCI2021:
@pytest.fixture(params=["train", "val", "test"])
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> ETCI2021:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
)
data_dir = os.path.join("tests", "data", "etci2021")
metadata = {
"train": {
"filename": "train.zip",
"md5": "50c10eb07d6db9aee3ba36401e4a2c45",
"directory": "train",
"url": os.path.join(data_dir, "train.zip"),
},
"val": {
"filename": "val_with_ref_labels.zip",
"md5": "3e8b5a3cb95e6029e0e2c2d4b4ec6fba",
"directory": "test",
"url": os.path.join(data_dir, "val_with_ref_labels.zip"),
},
"test": {
"filename": "test_without_ref_labels.zip",
"md5": "c8ee1e5d3e478761cd00ebc6f28b0ae7",
"directory": "test_internal",
"url": os.path.join(data_dir, "test_without_ref_labels.zip"),
},
}
monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501
root = str(tmp_path)
split = request.param
transforms = Identity()
return ETCI2021(root, split, transforms, download=True, checksum=True)

def test_getitem(self, dataset: ETCI2021) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
assert x["image"].shape[0] == 6
assert x["image"].shape[-2:] == x["mask"].shape[-2:]

if dataset.split != "test":
assert x["mask"].shape[0] == 2
else:
assert x["mask"].shape[0] == 1

def test_len(self, dataset: ETCI2021) -> None:
assert len(dataset) == 2

def test_already_downloaded(self, dataset: ETCI2021) -> None:
ETCI2021(root=dataset.root, download=True)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
ETCI2021(split="foo")

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
ETCI2021(str(tmp_path))
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .etci2021 import ETCI2021
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .landcoverai import LandCoverAI
from .landsat import (
Expand Down Expand Up @@ -81,6 +82,7 @@
"COWCCounting",
"COWCDetection",
"CV4AKenyaCropType",
"ETCI2021",
"LandCoverAI",
"LEVIRCDPlus",
"PatternNet",
Expand Down
249 changes: 249 additions & 0 deletions torchgeo/datasets/etci2021.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""ETCI 2021 dataset."""

import glob
import os
import shutil
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import download_and_extract_archive


class ETCI2021(VisionDataset):
"""ETCI 2021 Flood Detection dataset.

The `ETCI2021 <https://nasa-impact.github.io/etci2021/>`_
dataset is a dataset for flood detection

Dataset features:
* 33,405 VV & VH Sentinel-1 Synthetic Aperture Radar (SAR) images
* 2 binary masks per image representing water body & flood, respectively
* 2 polarization band images (VV, VH) of 3 RGB channels per band
* 3 RGB channels per band generated by the Hybrid Pluggable Processing Pipeline 'hyp3'
* Images with 5x20m per pixel resolution (256x256) px) taken in Interferometric Wide Swath acquisition mode
* Flood events from 5 different regions

Dataset format:
* VV band three-channel png
* VH band three-channel png
* water body mask single-channel png where no water body = 0, water body = 255
* flood mask single-channel png where no flood = 0, flood = 255

Dataset classes:
1. no flood/water
2. flood/water

If you use this dataset in your research, please add the following to your
acknowledgements section:
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved

'The authors would like to thank the NASA Earth Science Data Systems Program,
NASA Digital Transformation AI/ML thrust, and IEEE GRSS for organizing the ETCI competition'.
""" # noqa: E501
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved

splits = ["train", "val", "test"]
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
bands = ["VV", "VH"]
masks = ["flood", "water_body"]
metadata = {
"train": {
"filename": "train.zip",
"md5": "1e95792fe0f6e3c9000abdeab2a8ab0f",
"directory": "train",
"url": "https://drive.google.com/file/d/14HqNW5uWLS92n7KrxKgDwUTsSEST6LCr",
},
"val": {
"filename": "val_with_ref_labels.zip",
"md5": "fd18cecb318efc69f8319f90c3771bdf",
"directory": "test",
"url": "https://drive.google.com/file/d/19sriKPHCZLfJn_Jmk3Z_0b3VaCBVRVyn",
},
"test": {
"filename": "test_without_ref_labels.zip",
"md5": "da9fa69e1498bd49d5c766338c6dac3d",
"directory": "test_internal",
"url": "https://drive.google.com/file/d/1rpMVluASnSHBfm2FhpPDio0GyCPOqg7E",
},
}

def __init__(
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new ETCI 2021 dataset instance.

Args:
root: root directory where dataset can be found
split: one of "train", "val", or "test"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)

Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
assert split in self.splits

self.root = root
self.split = split
self.transforms = transforms
self.checksum = checksum

if download:
self._download()

if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)

self.files = self._load_files(self.root, self.split)

def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.

Args:
index: index to return

Returns:
data and label at that index
"""
files = self.files[index]
vv = self._load_image(files["vv"])
vh = self._load_image(files["vh"])
water_mask = self._load_target(files["water_mask"])

if self.split != "test":
flood_mask = self._load_target(files["flood_mask"])
mask = torch.stack(tensors=[water_mask, flood_mask], dim=0)
else:
mask = water_mask.unsqueeze(0)

image = torch.cat(tensors=[vv, vh], dim=0) # type: ignore[attr-defined]
sample = {"image": image, "mask": mask}

if self.transforms is not None:
sample = self.transforms(sample)

return sample

def __len__(self) -> int:
"""Return the number of data points in the dataset.

Returns:
length of the dataset
"""
return len(self.files)

def _load_files(self, root: str, split: str) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.

Args:
root: root dir of dataset
split: subset of dataset, one of [train, val, test]

Returns:
list of dicts containing paths for each pair of vv, vh,
water body mask, flood mask (train/val only)
"""
files = []
directory = self.metadata[split]["directory"]
folders = sorted(glob.glob(os.path.join(root, directory, "*")))
folders = [os.path.join(folder, "tiles") for folder in folders]
for folder in folders:
vvs = glob.glob(os.path.join(folder, "vv", "*.png"))
vhs = glob.glob(os.path.join(folder, "vh", "*.png"))
water_masks = glob.glob(os.path.join(folder, "water_body_label", "*.png"))

if split == "test":
flood_masks = [""] * len(water_masks)
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
else:
flood_masks = glob.glob(os.path.join(folder, "flood_label", "*.png"))

for vv, vh, flood_mask, water_mask in zip(
vvs, vhs, flood_masks, water_masks
):
files.append(
dict(vv=vv, vh=vh, flood_mask=flood_mask, water_mask=water_mask)
)
return files

def _load_image(self, path: str) -> Tensor:
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
"""Load a single image.

Args:
path: path to the image

Returns:
the image
"""
filename = os.path.join(path)
with Image.open(filename) as img:
array = np.array(img.convert("RGB"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor

def _load_target(self, path: str) -> Tensor:
"""Load the target mask for a single image.

Args:
path: path to the image

Returns:
the target mask
"""
filename = os.path.join(path)
with Image.open(filename) as img:
array = np.array(img.convert("L"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.clip(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
return tensor

def _check_integrity(self) -> bool:
"""Checks the integrity of the dataset structure.

Returns:
True if the dataset directories and split files are found, else False
"""
directory = self.metadata[self.split]["directory"]
dirpath = os.path.join(self.root, directory)
if not os.path.exists(dirpath):
return False
return True
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved

def _download(self) -> None:
"""Download the dataset and extract it.

Raises:
AssertionError: if the checksum of split.py does not match
"""
if self._check_integrity():
print("Files already downloaded and verified")
return

download_and_extract_archive(
self.metadata[self.split]["url"],
self.root,
filename=self.metadata[self.split]["filename"],
md5=self.metadata[self.split]["md5"] if self.checksum else None,
)

if os.path.exists(os.path.join(self.root, "__MACOSX")):
shutil.rmtree(os.path.join(self.root, "__MACOSX"))
Comment on lines +255 to +256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isaaccorley is this strictly needed? It's currently not covered by our unit tests and I'm wondering if we can just remove it.