diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index f4e90812e07..239d822954d 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -157,6 +157,11 @@ FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) .. autoclass:: FAIR1M +Forest Damage +^^^^^^^^^^^^^ + +.. autoclass:: ForestDamage + GID-15 (Gaofen Image Dataset) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip b/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip new file mode 100644 index 00000000000..d472a94a0f1 Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip differ diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml new file mode 100644 index 00000000000..bf5a549fb68 --- /dev/null +++ b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml @@ -0,0 +1 @@ +B01_0004.xml32323other882424 \ No newline at end of file diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml new file mode 100644 index 00000000000..f6e8957d533 --- /dev/null +++ b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml @@ -0,0 +1 @@ +B01_0005.xml32323other882424 \ No newline at end of file diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG new file mode 100644 index 00000000000..afc980b49dd Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG differ diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG new file mode 100644 index 00000000000..8d742f2c51b Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG differ diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py new file mode 100644 index 00000000000..bd026a1c204 --- /dev/null +++ b/tests/data/forestdamage/data.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import xml.etree.ElementTree as ET + +import numpy as np +from PIL import Image + +SIZE = 32 + +np.random.seed(0) + +PATHS = { + "images": [ + "Bebehojd_20190527/Images/B01_0004.JPG", + "Bebehojd_20190527/Images/B01_0005.JPG", + ], + "annotations": [ + "Bebehojd_20190527/Annotations/B01_0004.xml", + "Bebehojd_20190527/Annotations/B01_0005.xml", + ], +} + + +def create_annotation(path: str) -> None: + root = ET.Element("annotation") + + ET.SubElement(root, "filename").text = os.path.basename(path) + + size = ET.SubElement(root, "size") + + ET.SubElement(size, "width").text = str(SIZE) + ET.SubElement(size, "height").text = str(SIZE) + ET.SubElement(size, "depth").text = str(3) + + annotation = ET.SubElement(root, "object") + + ET.SubElement(annotation, "damage").text = "other" + + bbox = ET.SubElement(annotation, "bndbox") + ET.SubElement(bbox, "xmin").text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, "ymin").text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, "xmax").text = str(SIZE - int(SIZE / 4)) + ET.SubElement(bbox, "ymax").text = str(SIZE - int(SIZE / 4)) + + tree = ET.ElementTree(root) + tree.write(path) + + +def create_file(path: str) -> None: + Z = np.random.rand(SIZE, SIZE, 3) * 255 + img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img.save(path) + + +if __name__ == "__main__": + data_root = "Data_Set_Larch_Casebearer" + # remove old data + if os.path.isdir(data_root): + shutil.rmtree(data_root) + else: + os.makedirs(data_root) + + for path in PATHS["images"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_file(os.path.join(data_root, path)) + + for path in PATHS["annotations"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_annotation(os.path.join(data_root, path)) + + # compress data + shutil.make_archive(data_root, "zip", ".", data_root) + + # Compute checksums + with open(data_root + ".zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{data_root}: {md5}") diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py new file mode 100644 index 00000000000..5db2e160d1f --- /dev/null +++ b/tests/datasets/test_forestdamage.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import ForestDamage + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestForestDamage: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage: + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + data_dir = os.path.join("tests", "data", "forestdamage") + + url = os.path.join(data_dir, "Data_Set_Larch_Casebearer.zip") + + md5 = "a6adc19879c1021cc1ba8d424e19c9e0" + + monkeypatch.setattr(ForestDamage, "url", url) + monkeypatch.setattr(ForestDamage, "md5", md5) + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[no-untyped-call] + return ForestDamage( + root=root, transforms=transforms, download=True, checksum=True + ) + + def test_already_downloaded(self, dataset: ForestDamage) -> None: + ForestDamage(root=dataset.root, download=True) + + def test_getitem(self, dataset: ForestDamage) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + + def test_len(self, dataset: ForestDamage) -> None: + assert len(dataset) == 2 + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join( + "tests", "data", "forestdamage", "Data_Set_Larch_Casebearer.zip" + ) + shutil.copy(url, tmp_path) + ForestDamage(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "Data_Set_Larch_Casebearer.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + ForestDamage(root=str(tmp_path), checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in."): + ForestDamage(str(tmp_path)) + + def test_plot(self, dataset: ForestDamage) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: ForestDamage) -> None: + x = dataset[0].copy() + x["prediction_boxes"] = x["boxes"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0bbdb94b01b..3b7eebd6a34 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -34,6 +34,7 @@ from .eudem import EUDEM from .eurosat import EuroSAT from .fair1m import FAIR1M +from .forestdamage import ForestDamage from .geo import ( GeoDataset, IntersectionDataset, @@ -146,6 +147,7 @@ "ETCI2021", "EuroSAT", "FAIR1M", + "ForestDamage", "GID15", "IDTReeS", "InriaAerialImageLabeling", diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py new file mode 100644 index 00000000000..35676a4d277 --- /dev/null +++ b/torchgeo/datasets/forestdamage.py @@ -0,0 +1,327 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Forest Damage dataset.""" + +import glob +import os +from typing import Any, Callable, Dict, List, Optional, Tuple +from xml.etree import ElementTree + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from .geo import VisionDataset +from .utils import check_integrity, download_and_extract_archive, extract_archive + + +def parse_pascal_voc(path: str) -> Dict[str, Any]: + """Read a PASCAL VOC annotation file. + + Args: + path: path to xml file + + Returns: + dict of image filename, points, and class labels + """ + et = ElementTree.parse(path) + element = et.getroot() + filename = element.find("filename").text # type: ignore[union-attr] + labels, bboxes = [], [] + for obj in element.findall("object"): + bndbox = obj.find("bndbox") + bbox = [ + int(bndbox.find("xmin").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("ymin").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("xmax").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("ymax").text), # type: ignore[union-attr, arg-type] + ] + label = obj.find("damage").text # type: ignore[union-attr] + bboxes.append(bbox) + labels.append(label) + return dict(filename=filename, bboxes=bboxes, labels=labels) + + +class ForestDamage(VisionDataset): + """Forest Damage dataset. + + The `ForestDamage + `_ + dataset contains drone imagery that can be used for tree identification, + as well as tree damage classification for larch trees. + + Dataset features: + + * 1543 images + * 101,878 tree annotations + * subset of 840 images contain 44,522 annotations about tree health + (Healthy (H), Light Damage (LD), High Damage (HD)), all other + images have "other" as damage level + + Dataset format: + + * images are three-channel jpgs + * annotations are in `Pascal VOC XML format + `_ + + Dataset Classes: + + 0. other + 1. healthy + 2. light damage + 3. high damage + + If the download fails or stalls, it is recommended to try azcopy + as suggested `here `__. It is expected that the + downloaded data file with name ``Data_Set_Larch_Casebearer`` + can be found in ``root``. + + If you use this dataset in your research, please use the following citation: + + * Swedish Forest Agency (2021): Forest Damages - Larch Casebearer 1.0. + National Forest Data Lab. Dataset. + + .. versionadded:: 0.3 + """ + + classes = ["other", "H", "LD", "HD"] + url = ( + "https://lilablobssc.blob.core.windows.net/larch-casebearer/" + "Data_Set_Larch_Casebearer.zip" + ) + data_dir = "Data_Set_Larch_Casebearer" + md5 = "907815bcc739bff89496fac8f8ce63d7" + + def __init__( + self, + root: str = "data", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new ForestDamage dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.transforms = transforms + self.checksum = checksum + self.download = download + + self._verify() + + self.files = self._load_files(self.root) + + self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + files = self.files[index] + parsed = parse_pascal_voc(files["annotation"]) + image = self._load_image(files["image"]) + + boxes, labels = self._load_target(parsed["bboxes"], parsed["labels"]) + + sample = {"image": image, "boxes": boxes, "label": labels} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def _load_files(self, root: str) -> List[Dict[str, str]]: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + + Returns: + list of dicts containing paths for each pair of image, annotation + """ + images = sorted( + glob.glob(os.path.join(root, self.data_dir, "**", "Images", "*.JPG")) + ) + annotations = sorted( + glob.glob(os.path.join(root, self.data_dir, "**", "Annotations", "*.xml")) + ) + + files = [ + dict(image=image, annotation=annotation) + for image, annotation in zip(images, annotations) + ] + + return files + + def _load_image(self, path: str) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + with Image.open(path) as img: + array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + tensor: Tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target( + self, bboxes: List[List[int]], labels_list: List[str] + ) -> Tuple[Tensor, Tensor]: + """Load the target mask for a single image. + + Args: + bboxes: list of bbox coordinats [xmin, ymin, xmax, ymax] + labels_list: list of class labels + + Returns: + the target bounding boxes and labels + """ + labels = torch.tensor([self.class_to_idx[label] for label in labels_list]) + boxes = torch.tensor(bboxes).to(torch.float) + return boxes, labels + + def _verify(self) -> None: + """Checks the integrity of the dataset structure. + + Returns: + True if the dataset directories are found, else False + """ + filepath = os.path.join(self.root, self.data_dir) + if os.path.isdir(filepath): + return + + filepath = os.path.join(self.root, self.data_dir + ".zip") + if os.path.isfile(filepath): + if self.checksum and not check_integrity(filepath, self.md5): + raise RuntimeError("Dataset found, but corrupted.") + extract_archive(filepath) + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + "Dataset not found in `root` directory, either specify a different" + + " `root` directory or manually download " + + "the dataset to this directory." + ) + + # else download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + AssertionError: if the checksum does not match + """ + download_and_extract_archive( + self.url, + self.root, + filename=self.data_dir + ".zip", + md5=self.md5 if self.checksum else None, + ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = sample["image"].permute((1, 2, 0)).numpy() + + ncols = 1 + showing_predictions = "prediction_boxes" in sample + if showing_predictions: + ncols += 1 + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + if not showing_predictions: + axs = [axs] + + axs[0].imshow(image) + axs[0].axis("off") + + bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["boxes"].numpy() + ] + for bbox in bboxes: + axs[0].add_patch(bbox) + + if show_titles: + axs[0].set_title("Ground Truth") + + if showing_predictions: + axs[1].imshow(image) + axs[1].axis("off") + + pred_bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["prediction_boxes"].numpy() + ] + for bbox in pred_bboxes: + axs[1].add_patch(bbox) + + if show_titles: + axs[1].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig