diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index f4e90812e07..239d822954d 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -157,6 +157,11 @@ FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
.. autoclass:: FAIR1M
+Forest Damage
+^^^^^^^^^^^^^
+
+.. autoclass:: ForestDamage
+
GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip b/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip
new file mode 100644
index 00000000000..d472a94a0f1
Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip differ
diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml
new file mode 100644
index 00000000000..bf5a549fb68
--- /dev/null
+++ b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml
@@ -0,0 +1 @@
+B01_0004.xml32323
\ No newline at end of file
diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml
new file mode 100644
index 00000000000..f6e8957d533
--- /dev/null
+++ b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml
@@ -0,0 +1 @@
+B01_0005.xml32323
\ No newline at end of file
diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG
new file mode 100644
index 00000000000..afc980b49dd
Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG differ
diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG
new file mode 100644
index 00000000000..8d742f2c51b
Binary files /dev/null and b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG differ
diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py
new file mode 100644
index 00000000000..bd026a1c204
--- /dev/null
+++ b/tests/data/forestdamage/data.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+import xml.etree.ElementTree as ET
+
+import numpy as np
+from PIL import Image
+
+SIZE = 32
+
+np.random.seed(0)
+
+PATHS = {
+ "images": [
+ "Bebehojd_20190527/Images/B01_0004.JPG",
+ "Bebehojd_20190527/Images/B01_0005.JPG",
+ ],
+ "annotations": [
+ "Bebehojd_20190527/Annotations/B01_0004.xml",
+ "Bebehojd_20190527/Annotations/B01_0005.xml",
+ ],
+}
+
+
+def create_annotation(path: str) -> None:
+ root = ET.Element("annotation")
+
+ ET.SubElement(root, "filename").text = os.path.basename(path)
+
+ size = ET.SubElement(root, "size")
+
+ ET.SubElement(size, "width").text = str(SIZE)
+ ET.SubElement(size, "height").text = str(SIZE)
+ ET.SubElement(size, "depth").text = str(3)
+
+ annotation = ET.SubElement(root, "object")
+
+ ET.SubElement(annotation, "damage").text = "other"
+
+ bbox = ET.SubElement(annotation, "bndbox")
+ ET.SubElement(bbox, "xmin").text = str(0 + int(SIZE / 4))
+ ET.SubElement(bbox, "ymin").text = str(0 + int(SIZE / 4))
+ ET.SubElement(bbox, "xmax").text = str(SIZE - int(SIZE / 4))
+ ET.SubElement(bbox, "ymax").text = str(SIZE - int(SIZE / 4))
+
+ tree = ET.ElementTree(root)
+ tree.write(path)
+
+
+def create_file(path: str) -> None:
+ Z = np.random.rand(SIZE, SIZE, 3) * 255
+ img = Image.fromarray(Z.astype("uint8")).convert("RGB")
+ img.save(path)
+
+
+if __name__ == "__main__":
+ data_root = "Data_Set_Larch_Casebearer"
+ # remove old data
+ if os.path.isdir(data_root):
+ shutil.rmtree(data_root)
+ else:
+ os.makedirs(data_root)
+
+ for path in PATHS["images"]:
+ os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True)
+ create_file(os.path.join(data_root, path))
+
+ for path in PATHS["annotations"]:
+ os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True)
+ create_annotation(os.path.join(data_root, path))
+
+ # compress data
+ shutil.make_archive(data_root, "zip", ".", data_root)
+
+ # Compute checksums
+ with open(data_root + ".zip", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{data_root}: {md5}")
diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py
new file mode 100644
index 00000000000..5db2e160d1f
--- /dev/null
+++ b/tests/datasets/test_forestdamage.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.monkeypatch import MonkeyPatch
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import ForestDamage
+
+
+def download_url(url: str, root: str, *args: str) -> None:
+ shutil.copy(url, root)
+
+
+class TestForestDamage:
+ @pytest.fixture
+ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage:
+ monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
+ data_dir = os.path.join("tests", "data", "forestdamage")
+
+ url = os.path.join(data_dir, "Data_Set_Larch_Casebearer.zip")
+
+ md5 = "a6adc19879c1021cc1ba8d424e19c9e0"
+
+ monkeypatch.setattr(ForestDamage, "url", url)
+ monkeypatch.setattr(ForestDamage, "md5", md5)
+ root = str(tmp_path)
+ transforms = nn.Identity() # type: ignore[no-untyped-call]
+ return ForestDamage(
+ root=root, transforms=transforms, download=True, checksum=True
+ )
+
+ def test_already_downloaded(self, dataset: ForestDamage) -> None:
+ ForestDamage(root=dataset.root, download=True)
+
+ def test_getitem(self, dataset: ForestDamage) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["label"], torch.Tensor)
+ assert isinstance(x["boxes"], torch.Tensor)
+ assert x["image"].shape[0] == 3
+ assert x["image"].ndim == 3
+
+ def test_len(self, dataset: ForestDamage) -> None:
+ assert len(dataset) == 2
+
+ def test_not_extracted(self, tmp_path: Path) -> None:
+ url = os.path.join(
+ "tests", "data", "forestdamage", "Data_Set_Larch_Casebearer.zip"
+ )
+ shutil.copy(url, tmp_path)
+ ForestDamage(root=str(tmp_path))
+
+ def test_corrupted(self, tmp_path: Path) -> None:
+ with open(os.path.join(tmp_path, "Data_Set_Larch_Casebearer.zip"), "w") as f:
+ f.write("bad")
+ with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
+ ForestDamage(root=str(tmp_path), checksum=True)
+
+ def test_not_found(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found in."):
+ ForestDamage(str(tmp_path))
+
+ def test_plot(self, dataset: ForestDamage) -> None:
+ x = dataset[0].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+
+ def test_plot_prediction(self, dataset: ForestDamage) -> None:
+ x = dataset[0].copy()
+ x["prediction_boxes"] = x["boxes"].clone()
+ dataset.plot(x, suptitle="Prediction")
+ plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 0bbdb94b01b..3b7eebd6a34 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -34,6 +34,7 @@
from .eudem import EUDEM
from .eurosat import EuroSAT
from .fair1m import FAIR1M
+from .forestdamage import ForestDamage
from .geo import (
GeoDataset,
IntersectionDataset,
@@ -146,6 +147,7 @@
"ETCI2021",
"EuroSAT",
"FAIR1M",
+ "ForestDamage",
"GID15",
"IDTReeS",
"InriaAerialImageLabeling",
diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py
new file mode 100644
index 00000000000..35676a4d277
--- /dev/null
+++ b/torchgeo/datasets/forestdamage.py
@@ -0,0 +1,327 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Forest Damage dataset."""
+
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from xml.etree import ElementTree
+
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image
+from torch import Tensor
+
+from .geo import VisionDataset
+from .utils import check_integrity, download_and_extract_archive, extract_archive
+
+
+def parse_pascal_voc(path: str) -> Dict[str, Any]:
+ """Read a PASCAL VOC annotation file.
+
+ Args:
+ path: path to xml file
+
+ Returns:
+ dict of image filename, points, and class labels
+ """
+ et = ElementTree.parse(path)
+ element = et.getroot()
+ filename = element.find("filename").text # type: ignore[union-attr]
+ labels, bboxes = [], []
+ for obj in element.findall("object"):
+ bndbox = obj.find("bndbox")
+ bbox = [
+ int(bndbox.find("xmin").text), # type: ignore[union-attr, arg-type]
+ int(bndbox.find("ymin").text), # type: ignore[union-attr, arg-type]
+ int(bndbox.find("xmax").text), # type: ignore[union-attr, arg-type]
+ int(bndbox.find("ymax").text), # type: ignore[union-attr, arg-type]
+ ]
+ label = obj.find("damage").text # type: ignore[union-attr]
+ bboxes.append(bbox)
+ labels.append(label)
+ return dict(filename=filename, bboxes=bboxes, labels=labels)
+
+
+class ForestDamage(VisionDataset):
+ """Forest Damage dataset.
+
+ The `ForestDamage
+ `_
+ dataset contains drone imagery that can be used for tree identification,
+ as well as tree damage classification for larch trees.
+
+ Dataset features:
+
+ * 1543 images
+ * 101,878 tree annotations
+ * subset of 840 images contain 44,522 annotations about tree health
+ (Healthy (H), Light Damage (LD), High Damage (HD)), all other
+ images have "other" as damage level
+
+ Dataset format:
+
+ * images are three-channel jpgs
+ * annotations are in `Pascal VOC XML format
+ `_
+
+ Dataset Classes:
+
+ 0. other
+ 1. healthy
+ 2. light damage
+ 3. high damage
+
+ If the download fails or stalls, it is recommended to try azcopy
+ as suggested `here `__. It is expected that the
+ downloaded data file with name ``Data_Set_Larch_Casebearer``
+ can be found in ``root``.
+
+ If you use this dataset in your research, please use the following citation:
+
+ * Swedish Forest Agency (2021): Forest Damages - Larch Casebearer 1.0.
+ National Forest Data Lab. Dataset.
+
+ .. versionadded:: 0.3
+ """
+
+ classes = ["other", "H", "LD", "HD"]
+ url = (
+ "https://lilablobssc.blob.core.windows.net/larch-casebearer/"
+ "Data_Set_Larch_Casebearer.zip"
+ )
+ data_dir = "Data_Set_Larch_Casebearer"
+ md5 = "907815bcc739bff89496fac8f8ce63d7"
+
+ def __init__(
+ self,
+ root: str = "data",
+ transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new ForestDamage dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ RuntimeError: if ``download=False`` and data is not found, or checksums
+ don't match
+ """
+ self.root = root
+ self.transforms = transforms
+ self.checksum = checksum
+ self.download = download
+
+ self._verify()
+
+ self.files = self._load_files(self.root)
+
+ self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}
+
+ def __getitem__(self, index: int) -> Dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ files = self.files[index]
+ parsed = parse_pascal_voc(files["annotation"])
+ image = self._load_image(files["image"])
+
+ boxes, labels = self._load_target(parsed["bboxes"], parsed["labels"])
+
+ sample = {"image": image, "boxes": boxes, "label": labels}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.files)
+
+ def _load_files(self, root: str) -> List[Dict[str, str]]:
+ """Return the paths of the files in the dataset.
+
+ Args:
+ root: root dir of dataset
+
+ Returns:
+ list of dicts containing paths for each pair of image, annotation
+ """
+ images = sorted(
+ glob.glob(os.path.join(root, self.data_dir, "**", "Images", "*.JPG"))
+ )
+ annotations = sorted(
+ glob.glob(os.path.join(root, self.data_dir, "**", "Annotations", "*.xml"))
+ )
+
+ files = [
+ dict(image=image, annotation=annotation)
+ for image, annotation in zip(images, annotations)
+ ]
+
+ return files
+
+ def _load_image(self, path: str) -> Tensor:
+ """Load a single image.
+
+ Args:
+ path: path to the image
+
+ Returns:
+ the image
+ """
+ with Image.open(path) as img:
+ array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
+ tensor: Tensor = torch.from_numpy(array)
+ # Convert from HxWxC to CxHxW
+ tensor = tensor.permute((2, 0, 1))
+ return tensor
+
+ def _load_target(
+ self, bboxes: List[List[int]], labels_list: List[str]
+ ) -> Tuple[Tensor, Tensor]:
+ """Load the target mask for a single image.
+
+ Args:
+ bboxes: list of bbox coordinats [xmin, ymin, xmax, ymax]
+ labels_list: list of class labels
+
+ Returns:
+ the target bounding boxes and labels
+ """
+ labels = torch.tensor([self.class_to_idx[label] for label in labels_list])
+ boxes = torch.tensor(bboxes).to(torch.float)
+ return boxes, labels
+
+ def _verify(self) -> None:
+ """Checks the integrity of the dataset structure.
+
+ Returns:
+ True if the dataset directories are found, else False
+ """
+ filepath = os.path.join(self.root, self.data_dir)
+ if os.path.isdir(filepath):
+ return
+
+ filepath = os.path.join(self.root, self.data_dir + ".zip")
+ if os.path.isfile(filepath):
+ if self.checksum and not check_integrity(filepath, self.md5):
+ raise RuntimeError("Dataset found, but corrupted.")
+ extract_archive(filepath)
+ return
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise RuntimeError(
+ "Dataset not found in `root` directory, either specify a different"
+ + " `root` directory or manually download "
+ + "the dataset to this directory."
+ )
+
+ # else download the dataset
+ self._download()
+
+ def _download(self) -> None:
+ """Download the dataset and extract it.
+
+ Raises:
+ AssertionError: if the checksum does not match
+ """
+ download_and_extract_archive(
+ self.url,
+ self.root,
+ filename=self.data_dir + ".zip",
+ md5=self.md5 if self.checksum else None,
+ )
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ image = sample["image"].permute((1, 2, 0)).numpy()
+
+ ncols = 1
+ showing_predictions = "prediction_boxes" in sample
+ if showing_predictions:
+ ncols += 1
+
+ fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
+ if not showing_predictions:
+ axs = [axs]
+
+ axs[0].imshow(image)
+ axs[0].axis("off")
+
+ bboxes = [
+ patches.Rectangle(
+ (bbox[0], bbox[1]),
+ bbox[2] - bbox[0],
+ bbox[3] - bbox[1],
+ linewidth=1,
+ edgecolor="r",
+ facecolor="none",
+ )
+ for bbox in sample["boxes"].numpy()
+ ]
+ for bbox in bboxes:
+ axs[0].add_patch(bbox)
+
+ if show_titles:
+ axs[0].set_title("Ground Truth")
+
+ if showing_predictions:
+ axs[1].imshow(image)
+ axs[1].axis("off")
+
+ pred_bboxes = [
+ patches.Rectangle(
+ (bbox[0], bbox[1]),
+ bbox[2] - bbox[0],
+ bbox[3] - bbox[1],
+ linewidth=1,
+ edgecolor="r",
+ facecolor="none",
+ )
+ for bbox in sample["prediction_boxes"].numpy()
+ ]
+ for bbox in pred_bboxes:
+ axs[1].add_patch(bbox)
+
+ if show_titles:
+ axs[1].set_title("Predictions")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+
+ return fig