From 03392c797cfa77ba6dd5f4c88c034d9a19b71993 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 10 Mar 2022 22:07:41 +0100 Subject: [PATCH 01/11] add dataset --- torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/forestdamage.py | 268 ++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 torchgeo/datasets/forestdamage.py diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 8d28060b4cf..47b72e57760 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -34,6 +34,7 @@ from .eudem import EUDEM from .eurosat import EuroSAT from .fair1m import FAIR1M +from .forestdamage import ForestDamage from .geo import ( GeoDataset, IntersectionDataset, @@ -138,6 +139,7 @@ "ETCI2021", "EuroSAT", "FAIR1M", + "ForestDamage", "GID15", "IDTReeS", "InriaAerialImageLabeling", diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py new file mode 100644 index 00000000000..8d5a4d499bc --- /dev/null +++ b/torchgeo/datasets/forestdamage.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Forest Damage dataset.""" + +import glob +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from xml.etree import ElementTree + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from .geo import VisionDataset + + +def parse_pascal_voc(path: str) -> Dict[str, Any]: + """Read a PASCAL VOC annotation file. + + Args: + path: path to xml file + + Returns: + dict of image filename, points, and class labels + """ + et = ElementTree.parse(path) + element = et.getroot() + filename = element.find("filename").text # type: ignore[union-attr] + labels, bboxes = [], [] + for obj in element.findall("object"): + bndbox = obj.find("bndbox") + bbox = [ + int(bndbox.find("xmin").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("ymin").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("xmax").text), # type: ignore[union-attr, arg-type] + int(bndbox.find("ymax").text), # type: ignore[union-attr, arg-type] + ] + label = obj.find("damage").text # type: ignore[union-attr] + bboxes.append(bbox) + labels.append(label) + return dict(filename=filename, bboxes=bboxes, labels=labels) + + +class ForestDamage(VisionDataset): + """Forest Damage dataset. + + The `ForestDamage + `_ + dataset contains drone imagery that can be used for tree identification, + as well as tree damage classification for larch trees. + + Dataset features: + + * 1543 images + * 101,878 tree annotations + * subset of 840 images contain 44,522 annotations about tree health + (Healthy (H), Light Damage (LD), High Damage (HD)), all other + images have "other" as damage level + + Dataset format: + + * images are three-channel jpgs + * annotations are in `Pascal VOC XML format + `_n + + If you use this dataset in your research, please use the following citation: + + * Swedish Forest Agency (2021): Forest Damages – Larch Casebearer 1.0. + National Forest Data Lab. Dataset. + + .. versionadded:: 0.3 + """ + + classes = ["H", "LD", "HD", "other"] + + data_dir = "Data_Set_Larch_Casebearer" + + def __init__( + self, + root: str = "data", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new ForestDamage dataset instance. 
+ + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.transforms = transforms + self.checksum = checksum + self.download = download + + self.files = self._load_files(self.root) + + self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + files = self.files[index] + parsed = parse_pascal_voc(files["annotation"]) + image = self._load_image(files["image"]) + + boxes, labels = self._load_target(parsed["bboxes"], parsed["labels"]) + + sample = {"image": image, "boxes": boxes, "label": labels} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def _load_files(self, root: str) -> List[Dict[str, str]]: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + + Returns: + list of dicts containing paths for each pair of image, audio, label + """ + images = sorted( + glob.glob(os.path.join(root, self.data_dir, "**", "Images", "*.JPG")) + ) + annotations = sorted( + glob.glob(os.path.join(root, self.data_dir, "**", "Annotations", "*.xml")) + ) + + files = [ + dict(image=image, annotation=annotation) + for image, annotation in zip(images, annotations) + ] + + return files + + def _load_image(self, path: str) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + with Image.open(path) as img: + array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target( + self, bboxes: List[List[int]], labels: List[str] + ) -> Tuple[Tensor, Tensor]: + """Load the target mask for a single image. + + Args: + bboxes: list of bbox coordinats [xmin, ymin, xmax, ymax] + labels: list of class labels + + Returns: + the target bounding boxes and labels + """ + labels_list = [self.class_to_idx[label] for label in labels] + boxes = torch.tensor(bboxes).to(torch.float) # type: ignore[attr-defined] + labels = torch.tensor(labels_list) # type: ignore[attr-defined] + return boxes, cast(Tensor, labels) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. 
+ + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = sample["image"].permute((1, 2, 0)).numpy() + + ncols = 1 + showing_predictions = "prediction_boxes" in sample + if showing_predictions: + ncols += 1 + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + if not showing_predictions: + axs = [axs] + + axs[0].imshow(image) + axs[0].axis("off") + + bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["boxes"].numpy() + ] + for bbox in bboxes: + axs[0].add_patch(bbox) + + if show_titles: + axs[0].set_title("Ground Truth") + + if showing_predictions: + axs[1].imshow(image) + axs[1].axis("off") + + pred_bboxes = [ + patches.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + linewidth=1, + edgecolor="r", + facecolor="none", + ) + for bbox in sample["prediction_boxes"].numpy() + ] + for bbox in pred_bboxes: + axs[1].add_patch(bbox) + + if show_titles: + axs[1].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig From 31908c5bf9f6c629b1d7c4f3bfe17b665686b654 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 11 Mar 2022 10:24:00 +0100 Subject: [PATCH 02/11] md5 --- torchgeo/datasets/forestdamage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 8d5a4d499bc..ef9b95c002e 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -78,6 +78,7 @@ class ForestDamage(VisionDataset): classes = ["H", "LD", "HD", "other"] data_dir = "Data_Set_Larch_Casebearer" + md5 = "907815bcc739bff89496fac8f8ce63d7" def __init__( self, From 7ade643674ab77f057f3c5c3607dd98164a73beb Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 11 Mar 2022 17:07:13 +0100 Subject: [PATCH 03/11] added tests and data --- docs/api/datasets.rst | 7 +- .../Data_Set_Larch_Casebearer.zip | Bin 0 -> 3961 bytes .../Annotations/B01_0004.xml | 1 + .../Annotations/B01_0005.xml | 1 + .../Bebehojd_20190527/Images/B01_0004.JPG | Bin 0 -> 1254 bytes .../Bebehojd_20190527/Images/B01_0005.JPG | Bin 0 -> 1231 bytes tests/data/forestdamage/data.py | 80 ++++++++++++++++ tests/datasets/test_forestdamage.py | 86 ++++++++++++++++++ torchgeo/datasets/forestdamage.py | 69 +++++++++++++- 9 files changed, 238 insertions(+), 6 deletions(-) create mode 100644 tests/data/forestdamage/Data_Set_Larch_Casebearer.zip create mode 100644 tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0004.xml create mode 100644 tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml create mode 100644 tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG create mode 100644 tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG create mode 100644 tests/data/forestdamage/data.py create mode 100644 tests/datasets/test_forestdamage.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index e18d187865e..f428f64608b 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -155,7 +155,12 @@ EuroSAT FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. 
autoclass:: FAIR1M +.. autoclass:: FAIR1Ms + +Forest Damage +^^^^^^^^^^^^^ + +.. autoclass:: ForestDamage GID-15 (Gaofen Image Dataset) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip b/tests/data/forestdamage/Data_Set_Larch_Casebearer.zip new file mode 100644 index 0000000000000000000000000000000000000000..d472a94a0f14186e9ac72037bd51fa1401092465 GIT binary patch literal 3961 zcmeHKdpMJCA4iNLC6Q9(70V%WNR(5g`Gp}uXf!s7Ml`g9a`=@)NhRkw6ka(@h*&6@ zg&`C(ZAxs6ZADoTd#B!x@2{?_U++Kfb-mx`dY=1!{&-*YYjJ z<`};NA142PY~YRNFci$g0fF+ch6Q=~cdKm8M z4MTX*pN)VPxl&CkUPjmE(8Ii` z>Q}iwkMOHhTlejSWF7JxS##O`rY~bdhO2U+H%44#-yDH$jqzJs28ew|>8F%DZC87oQ3F^+sU$DB{fY1+N$2|w%|!(Y0% zMsO}6E;cqjcGD;bTo`moz*O34@17kBgx#|(W-z7eU~!3}I8i;h(E~fg=Qd{?KqXsQ z<45xt7MBL@87p>qLE!JiSXwbn{#q|Jbp#ZXTrSu%hgy2VP21&2BqKbIO2SLJ0+QZE%*zqn5diea8oEyznhrTU&B9ZE7IZ&?$cDgAh3 z4h@5;ff~N3&ns=aKHlbOIG~YXG-6`({J4$pn_*!&VeqcC_~ea}sKwxb>HoQCRaM9Yr^LJL?L8 z4RkUkoF-RyoKR{p5ob9mr@!SEo(RPkYj&P5dUFl7+uDXZg!ab?V1=+dnugtiG7gdB(-stV7{EWW+V`@auk~K4W<| zH(eN>c;Jaj*2bNdV6L%Al?$5D&8jhg_;xZ0Y*78&iNPlU-1P_QrL9Wu7%`76&d_g< zba*$ohPrVJ83eWN2%Oxj0hO$l5CT+@L!<1-maU%&PLq)&5hV{nh$|Jn{N`C`Mp@O) zxpz=GdJG_Lz2wC~A>$4QQmc=fAz{nYsT<;YwJk%T+G6q{MFlZ(;}b@gDqM4Km(X9W zKzpTdv}3iRw4piMY3LmNe+CaJfnxLM6E>6Mn}pK_>jw}9jZcGLjWM5+a7_UL)B(h3 z6Uye@wU@zIzZZqfM)wBtA{WBVL6Pt2T|*m%jE<&A;%ZsBNZ%`>Ylc^ge(5SRo0yi3LeQ7 z$f1N2(z2#DRN-4@x= zAAL&8;7w7b|IpRY%eVXpbpkfC#`2s;o1A65mEr1vY@6zcjPeGro*S1!mxE+EBJUiH z)W!O0PSztz=rMbk>>2txlV`&QN#U+1xufz5D%xoKgH5-zBnkZ}d zK#F|%{R!G8HgsYv>;%$raM5vpParOYzJitI_(`+zODSB&^IyVb;&&AG=6@{f~2~Y=Mi> zuS*MYla&tGF5fWyve7MA!SE(0sR<;kxrEQOj_#j!THnILgQfPz?T06Z`&HX&8x?s@(q z@ihL#%jYrf0GSts3+^6~_T#HPd)x5$?$qy%nr#N&Hc=~Rh0cnOJGb_QTb`~{aWNLl z1y_uCXZjW>D%W9+asjzJHgrNXO3E5Av0jzaTa6(1W7=AX52-@NQNFFXl8l5RomDS< zrMG&gO6?w6LDuLQj`sLg>j6po`v0O}iAeOz9w=36jb8h_>`evGr^CO*az&aXOxpY3o7 z>kD!eZ~EgM{-B01_0004.xml32323other882424 \ No newline at end of file diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml new file mode 100644 index 00000000000..f6e8957d533 --- /dev/null +++ b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Annotations/B01_0005.xml @@ -0,0 +1 @@ +B01_0005.xml32323other882424 \ No newline at end of file diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0004.JPG new file mode 100644 index 0000000000000000000000000000000000000000..afc980b49ddbdb1fa90b5b2944bd2c2e52cc9418 GIT binary patch literal 1254 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L(=`2Q>k5-%AK=(R;AxKW^rG!sxRW{zJ&Lm zPB8L@yjkbAWU|iw)cX>>v5P0PPAIdipUTL})w|@$cAp2YqNC4Omt8k@&DEXlyTkZM z^qHiS`!gRI&)LXhHfOozlgs%k-pC*y{Sb6UZL zUIyXiFaG^4yI=ic{pJgSOEq@8`(Ms}eLs5Jx^1_E>^Hmg7u{NN@n+nr4PEKmU9|+m z3c0;>4v2=`yC1N#ONKV|d6o*1^D7Q7b^9MT9{ zuH(q?$%Cz1ZA#v^?Tc^Sx4X4&du^1b>+TnCtGBFKTf5!&#g|W2rJHk2PkPU_yZg*N za^a3ykJ1yB);(}YVvCu(ic?{l5l`za!+@Ba{8ZCSMxWj4_V#pBkDCFrn1q%?5Szrc7{TbRd!~Ipwzg`*ikYdZ zQ#d(tl#Kp;erB`s)SumLcbe3i+mGL0boJ-5tKa@JsB|rsSaEyW+5G3vYsHt%dbRGh z{!P`7SDyWLJ>suZIEn3!yGHKKyKOo*bm!|ZF($B?C%BrJ8%D;~_TBT;TW53h^SeCG z2@Mi6CLFU1e3qxMjnTGvj!C%LNgZx(iMQ`$7(VPyO9?&D(|aIcwLoiE&*K&H#wzb+ z3S$Iix;|;l^FDD<>$Xc?3PaD+oz~Gi4luKIugi9Px$EM~+qtoKf8XA`ugl1P>74Q{ S;p<-4#!cIOy?V9%|C<14Z~>_R literal 0 
HcmV?d00001 diff --git a/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG b/tests/data/forestdamage/Data_Set_Larch_Casebearer/Bebehojd_20190527/Images/B01_0005.JPG new file mode 100644 index 0000000000000000000000000000000000000000..8d742f2c51b455ab2839b083325bbb563d95aee5 GIT binary patch literal 1231 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L;a)+r*p1l+7wM;@>u&+%kO#6i{3w4UXwgz z*e8L|tf`TK~TKO@Z|N>AU*@ None: + root = ET.Element("annotation") + + ET.SubElement(root, "filename").text = os.path.basename(path) + + size = ET.SubElement(root, "size") + + ET.SubElement(size, "width").text = str(SIZE) + ET.SubElement(size, "height").text = str(SIZE) + ET.SubElement(size, "depth").text = str(3) + + annotation = ET.SubElement(root, "object") + + ET.SubElement(annotation, "damage").text = "other" + + bbox = ET.SubElement(annotation, "bndbox") + ET.SubElement(bbox, "xmin").text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, "ymin").text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, "xmax").text = str(SIZE - int(SIZE / 4)) + ET.SubElement(bbox, "ymax").text = str(SIZE - int(SIZE / 4)) + + tree = ET.ElementTree(root) + tree.write(path) + + +def create_file(path: str) -> None: + Z = np.random.rand(SIZE, SIZE, 3) * 255 + img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img.save(path) + + +if __name__ == "__main__": + data_root = "Data_Set_Larch_Casebearer" + # remove old data + if os.path.isdir(data_root): + shutil.rmtree(data_root) + else: + os.makedirs(data_root) + + for path in PATHS["images"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_file(os.path.join(data_root, path)) + + for path in PATHS["annotations"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_annotation(os.path.join(data_root, path)) + + # compress data + shutil.make_archive(data_root, "zip", ".", data_root) + + # Compute checksums + with open(data_root + ".zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{data_root}: {md5}") diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py new file mode 100644 index 00000000000..afc9705e9f2 --- /dev/null +++ b/tests/datasets/test_forestdamage.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +from pathlib import Path +from typing import Generator + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import ForestDamage + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestForestDamage: + @pytest.fixture + def dataset( + self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path + ) -> ForestDamage: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.utils, "download_url", download_url + ) + data_dir = os.path.join("tests", "data", "forestdamage") + + url = os.path.join(data_dir, "Data_Set_Larch_Casebearer.zip") + + md5 = "a6adc19879c1021cc1ba8d424e19c9e0" + + monkeypatch.setattr(ForestDamage, "url", url) # type: ignore[attr-defined] + monkeypatch.setattr(ForestDamage, "md5", md5) # type: ignore[attr-defined] + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + return ForestDamage( + root=root, transforms=transforms, download=True, checksum=True + ) + + def test_already_downloaded(self, dataset: ForestDamage) -> None: + ForestDamage(root=dataset.root, download=True) + + def test_getitem(self, dataset: ForestDamage) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + + def test_len(self, dataset: ForestDamage) -> None: + assert len(dataset) == 2 + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join( + "tests", "data", "forestdamage", "Data_Set_Larch_Casebearer.zip" + ) + shutil.copy(url, tmp_path) + ForestDamage(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "Data_Set_Larch_Casebearer.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + ForestDamage(root=str(tmp_path), checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in."): + ForestDamage(str(tmp_path)) + + def test_plot(self, dataset: ForestDamage) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: ForestDamage) -> None: + x = dataset[0].copy() + x["prediction_boxes"] = x["boxes"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index ef9b95c002e..35e6ed98158 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -16,6 +16,7 @@ from torch import Tensor from .geo import VisionDataset +from .utils import check_integrity, download_and_extract_archive, extract_archive def parse_pascal_voc(path: str) -> Dict[str, Any]: @@ -65,18 +66,33 @@ class ForestDamage(VisionDataset): * images are three-channel jpgs * annotations are in `Pascal VOC XML format - `_n + `_ + + Dataset Classes: + + * other (0) + * healthy (1) + * light damage (2) + * high damage (3) + + If the download fails or takes too long, it is recommended to try azcopy + as suggested `here `_. It is expected that the + downloaded data file with name `Data_Set_Larch_Casebearer` + can be found in `root`. 
If you use this dataset in your research, please use the following citation: - * Swedish Forest Agency (2021): Forest Damages – Larch Casebearer 1.0. + * Swedish Forest Agency (2021): Forest Damages - Larch Casebearer 1.0. National Forest Data Lab. Dataset. .. versionadded:: 0.3 """ - classes = ["H", "LD", "HD", "other"] - + classes = ["other", "H", "LD", "HD"] + url = ( + "https://lilablobssc.blob.core.windows.net/larch-casebearer/" + "Data_Set_Larch_Casebearer.zip" + ) data_dir = "Data_Set_Larch_Casebearer" md5 = "907815bcc739bff89496fac8f8ce63d7" @@ -105,6 +121,8 @@ def __init__( self.checksum = checksum self.download = download + self._verify() + self.files = self._load_files(self.root) self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} @@ -146,7 +164,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: root: root dir of dataset Returns: - list of dicts containing paths for each pair of image, audio, label + list of dicts containing paths for each pair of image, annotation """ images = sorted( glob.glob(os.path.join(root, self.data_dir, "**", "Images", "*.JPG")) @@ -195,6 +213,47 @@ def _load_target( labels = torch.tensor(labels_list) # type: ignore[attr-defined] return boxes, cast(Tensor, labels) + def _verify(self) -> None: + """Checks the integrity of the dataset structure. + + Returns: + True if the dataset directories are found, else False + """ + filepath = os.path.join(self.root, self.data_dir) + if os.path.isdir(filepath): + return + + filepath = os.path.join(self.root, self.data_dir + ".zip") + if os.path.isfile(filepath): + if self.checksum and not check_integrity(filepath, self.md5): + raise RuntimeError("Dataset found, but corrupted.") + extract_archive(filepath) + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + "Dataset not found in `root` directory, either specify a different" + + " `root` directory or manually download " + + "the dataset to this directory." + ) + + # else download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + AssertionError: if the checksum does not match + """ + download_and_extract_archive( + self.url, + self.root, + filename=self.data_dir + ".zip", + md5=self.md5 if self.checksum else None, + ) + def plot( self, sample: Dict[str, Tensor], From d3effe67f03c44701ff4c550f2f165b62d112976 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Sun, 13 Mar 2022 15:13:04 +0100 Subject: [PATCH 04/11] test --- torchgeo/datasets/forestdamage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 35e6ed98158..55b7f9bb7db 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -75,7 +75,7 @@ class ForestDamage(VisionDataset): * light damage (2) * high damage (3) - If the download fails or takes too long, it is recommended to try azcopy + If the download fails or stalls, it is recommended to try azcopy as suggested `here `_. It is expected that the downloaded data file with name `Data_Set_Larch_Casebearer` can be found in `root`. 
From 9fc245a7595d2b4d13cb6ef58cb4995ababb2204 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 21 Mar 2022 10:05:52 +0100 Subject: [PATCH 05/11] remove type --- tests/datasets/test_forestdamage.py | 2 +- torchgeo/datasets/forestdamage.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index afc9705e9f2..14f4e44558f 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -37,7 +37,7 @@ def dataset( monkeypatch.setattr(ForestDamage, "url", url) # type: ignore[attr-defined] monkeypatch.setattr(ForestDamage, "md5", md5) # type: ignore[attr-defined] root = str(tmp_path) - transforms = nn.Identity() # type: ignore[attr-defined] + transforms = nn.Identity() # type: ignore[no-untyped-call] return ForestDamage( root=root, transforms=transforms, download=True, checksum=True ) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 55b7f9bb7db..51decbfea32 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, Tuple from xml.etree import ElementTree import matplotlib.patches as patches @@ -191,27 +191,26 @@ def _load_image(self, path: str) -> Tensor: """ with Image.open(path) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + tensor: Tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor def _load_target( - self, bboxes: List[List[int]], labels: List[str] + self, bboxes: List[List[int]], labels_list: List[str] ) -> Tuple[Tensor, Tensor]: """Load the target mask for a single image. Args: bboxes: list of bbox coordinats [xmin, ymin, xmax, ymax] - labels: list of class labels + labels_list: list of class labels Returns: the target bounding boxes and labels """ - labels_list = [self.class_to_idx[label] for label in labels] - boxes = torch.tensor(bboxes).to(torch.float) # type: ignore[attr-defined] - labels = torch.tensor(labels_list) # type: ignore[attr-defined] - return boxes, cast(Tensor, labels) + labels = torch.tensor([self.class_to_idx[label] for label in labels_list]) + boxes = torch.tensor(bboxes).to(torch.float) + return boxes, labels def _verify(self) -> None: """Checks the integrity of the dataset structure. From ff5c105a9e73fc18c57216f6d20c96e64840f69e Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 21 Mar 2022 10:52:15 +0100 Subject: [PATCH 06/11] fix docs --- docs/api/datasets.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index f428f64608b..31c5a059911 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -155,7 +155,7 @@ EuroSAT FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: FAIR1Ms +.. 
autoclass:: FAIR1M Forest Damage ^^^^^^^^^^^^^ From 0009beafaf3b88a6b60ce2bf2c7721a46fbd9c02 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 21 Mar 2022 11:04:43 +0100 Subject: [PATCH 07/11] fix docs --- torchgeo/datasets/forestdamage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 51decbfea32..6edae115c61 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -76,7 +76,7 @@ class ForestDamage(VisionDataset): * high damage (3) If the download fails or stalls, it is recommended to try azcopy - as suggested `here `_. It is expected that the + as suggested `here `__. It is expected that the downloaded data file with name `Data_Set_Larch_Casebearer` can be found in `root`. From cd7173b1e9b0411dc0253e61a91a2fd4fff905dc Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 30 Mar 2022 09:58:54 +0200 Subject: [PATCH 08/11] requested changes --- tests/datasets/test_forestdamage.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 14f4e44558f..5db2e160d1f 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -4,7 +4,6 @@ import os import shutil from pathlib import Path -from typing import Generator import matplotlib.pyplot as plt import pytest @@ -22,20 +21,16 @@ def download_url(url: str, root: str, *args: str) -> None: class TestForestDamage: @pytest.fixture - def dataset( - self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path - ) -> ForestDamage: - monkeypatch.setattr( # type: ignore[attr-defined] - torchgeo.datasets.utils, "download_url", download_url - ) + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage: + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) data_dir = os.path.join("tests", "data", "forestdamage") url = os.path.join(data_dir, "Data_Set_Larch_Casebearer.zip") md5 = "a6adc19879c1021cc1ba8d424e19c9e0" - monkeypatch.setattr(ForestDamage, "url", url) # type: ignore[attr-defined] - monkeypatch.setattr(ForestDamage, "md5", md5) # type: ignore[attr-defined] + monkeypatch.setattr(ForestDamage, "url", url) + monkeypatch.setattr(ForestDamage, "md5", md5) root = str(tmp_path) transforms = nn.Identity() # type: ignore[no-untyped-call] return ForestDamage( From 4f435b4fdb2b2c3f7b923e4bb654e09ee6141e3d Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 31 Mar 2022 10:26:08 +0200 Subject: [PATCH 09/11] fix documentation and pyupgrade --- tests/data/forestdamage/data.py | 2 +- torchgeo/datasets/forestdamage.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py index 93d16170ea5..0926c52b4c3 100644 --- a/tests/data/forestdamage/data.py +++ b/tests/data/forestdamage/data.py @@ -2,7 +2,7 @@ import os import random import shutil -import xml.etree.cElementTree as ET +import xml.etree.ElementTree as ET import numpy as np from PIL import Image diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 6edae115c61..35676a4d277 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -70,15 +70,15 @@ class ForestDamage(VisionDataset): Dataset Classes: - * other (0) - * healthy (1) - * light damage (2) - * high damage (3) + 0. other + 1. healthy + 2. light damage + 3. 
high damage If the download fails or stalls, it is recommended to try azcopy as suggested `here `__. It is expected that the - downloaded data file with name `Data_Set_Larch_Casebearer` - can be found in `root`. + downloaded data file with name ``Data_Set_Larch_Casebearer`` + can be found in ``root``. If you use this dataset in your research, please use the following citation: From 3d4a73db7aa0ccc525b760ab0bb4be7d611c0515 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 31 Mar 2022 17:11:34 +0200 Subject: [PATCH 10/11] remove random --- tests/data/forestdamage/data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py index 0926c52b4c3..aa1f23dae1c 100644 --- a/tests/data/forestdamage/data.py +++ b/tests/data/forestdamage/data.py @@ -1,6 +1,5 @@ import hashlib import os -import random import shutil import xml.etree.ElementTree as ET @@ -10,7 +9,6 @@ SIZE = 32 np.random.seed(0) -random.seed(0) PATHS = { "images": [ From f6abd479f4ef3a8d88330fd475c8777b2cf508f8 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Sat, 2 Apr 2022 17:00:20 +0200 Subject: [PATCH 11/11] missing license header --- tests/data/forestdamage/data.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py index aa1f23dae1c..bd026a1c204 100644 --- a/tests/data/forestdamage/data.py +++ b/tests/data/forestdamage/data.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + import hashlib import os import shutil
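
A minimal usage sketch of the dataset this patch series adds, assuming the final state of the series (ForestDamage exported from torchgeo.datasets, and the Data_Set_Larch_Casebearer archive either downloadable or already placed under ``root``). The sample keys and the plot signature follow the code in torchgeo/datasets/forestdamage.py above; the output file name passed to savefig is only illustrative.

    from torchgeo.datasets import ForestDamage

    # Downloads and extracts Data_Set_Larch_Casebearer.zip into `root` if needed;
    # per the class docstring, use azcopy and place the data manually if the
    # download stalls.
    ds = ForestDamage(root="data", download=True, checksum=True)
    print(len(ds))  # number of image/annotation pairs found under root

    sample = ds[0]
    print(sample["image"].shape)  # CxHxW RGB tensor
    print(sample["boxes"].shape)  # Nx4 float boxes, [xmin, ymin, xmax, ymax]
    print(sample["label"])        # N class indices into ForestDamage.classes

    # Render ground-truth boxes; add a "prediction_boxes" key to the sample to
    # draw a second panel with predictions, as the tests above do.
    fig = ds.plot(sample, suptitle="ForestDamage sample")
    fig.savefig("forestdamage_sample.png")  # hypothetical output path

As in the tests, a transforms callable can also be passed to the constructor; it receives the full sample dict and must return one with the same keys.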