From 931ad77c4f98a7cc8b3648300fd130e1e920f08e Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:12:53 -0500 Subject: [PATCH 1/5] updated docs --- docs/api/datasets.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 034f3c7d463..ff85d597ab7 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -87,6 +87,11 @@ CV4A Kenya Crop Type Competition .. autoclass:: CV4AKenyaCropType +GID-15 (Gaofen Image Dataset) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: GID15 + LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From f25311d515bd02c83b296e5d5537107ff17aae89 Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:13:06 -0500 Subject: [PATCH 2/5] added dummy sample test data --- tests/data/gid15/gid-15.zip | Bin 0 -> 4240 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/data/gid15/gid-15.zip diff --git a/tests/data/gid15/gid-15.zip b/tests/data/gid15/gid-15.zip new file mode 100644 index 0000000000000000000000000000000000000000..c6d2d2711fc184e853a5bc03ed49dd300b2b21d5 GIT binary patch literal 4240 zcmcgvT}YEr7(V;4T#YadS`Z?nkWu!1bDLX(a%4%UInW^yw4eN8u!cV~f^G`RP@;xi z)NLhoV>c-wMB;^@hIL^PNI@D%Q3OGD69k=iXXpFQ&NhdGt4! z%-ayWZgn4MN&S(2Go!tz{#4!JK&u6ryrbSQto}>ek-_}Mgiz|uuRXoYJ#9@t9gfF? z?a^L>@(v@w%8K%bLn-1=E{0YYR{pj7W|#+jD;f|v8wKPSE>V9YX4^a(f4^}_){blTKcuxZ80v~ zESoo!+eW{3-Y`~HHSeBD4teG*wXPD=!WKJ}l$>@bHeI_5_DS9lQB!Y5+L!mQCQM8#H#wd5>NW^VG zT3#7{UWe!LHeO2FZ5E05nw`n2M6@$Ci@*enoBp4xZzD?~JcxK0+ruIA(<$5FtLfSK zDeK7lEBE%y4h)X@78`3176v|=rXTowlBH^S-}L=ha>6?7DVs73)FNgmeH*Ek8A`3% zAzUe}JXK3vI?K7Vlk07wqc1^*@&ztsQY=S5joLWjstk(vNvToX)vlV?w*23DTRB}P?pM$zMw4gqs~p2M{S#d>t70E?Aw=bcI& z#e!8%{XEjF0xnp2(oU?Z*kHZ-aj(@Iz_r54OT7kTsfoBHh9Vn2xZqWUzWd-eiAV(` zB5~FD8+0TwAhAeC(@1(;!ibW4fkhO<5{r0^LnOyEMwP}UsFQdk1}BP&(8$%-5>$SF zu}5}hj9N=mF6@^Lq~Eklph;k7j^NzciQ8hdq1wUi{l-joI;mI{+~QL2i7GMXP+aI~ zLW;}l${u!*g$*R1+UU&SCo`Jt~1X?}VF23oNq zh$T`*1W^nnETYX!6j4NN?@K}hnJD6u7*HrG_!N<#@?B#0J~^wfmUURgpuw^h6^L2{ zR%Xv@tdhvVo*rUgpse7@B*&VL4_J)kf@E7Lh6f4@EwpUO6qADG)re67(@J6lODl!| zmUbc+S}_u=rIu_%@iM2h(8R8-oZ9q%Mi^Ie)v|KLi=E;^7ssTyLK}r^etll?f@XQo e7s~XAy!>^Ic_ml1zo7uzf>1I3flifJ;L~4uEK5%S literal 0 HcmV?d00001 From b3dc2a26d34c1326783e5afa7ab5906f9a34f46a Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:13:14 -0500 Subject: [PATCH 3/5] added tests --- tests/datasets/test_gid15.py | 67 ++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/datasets/test_gid15.py diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py new file mode 100644 index 00000000000..7cc90edd924 --- /dev/null +++ b/tests/datasets/test_gid15.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path +from typing import Generator + +import pytest +import torch +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import GID15 +from torchgeo.transforms import Identity + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestGID15: + @pytest.fixture(params=["train", "val", "test"]) + def dataset( + self, + monkeypatch: Generator[MonkeyPatch, None, None], + tmp_path: Path, + request: SubRequest, + ) -> GID15: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.utils, "download_url", download_url + ) + md5 = "3d5b1373ef9a3084ec493b9b2056fe07" + monkeypatch.setattr(GID15, "md5", md5) # type: ignore[attr-defined] + url = os.path.join("tests", "data", "gid15", "gid-15.zip") + monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined] + root = str(tmp_path) + split = request.param + transforms = Identity() + return GID15(root, split, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: GID15) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert x["image"].shape[0] == 3 + + if dataset.split != "test": + assert isinstance(x["mask"], torch.Tensor) + assert x["image"].shape[-2:] == x["mask"].shape[-2:] + else: + assert "mask" not in x + + def test_len(self, dataset: GID15) -> None: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: GID15) -> None: + GID15(root=dataset.root, download=True) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + GID15(split="foo") + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + GID15(str(tmp_path)) From 7eed867dafb4b03f492fd8a86df96c3e5f8575f6 Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:13:29 -0500 Subject: [PATCH 4/5] added dataset --- torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/gid15.py | 237 ++++++++++++++++++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 torchgeo/datasets/gid15.py diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index d249b6072b4..680d597f102 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -23,6 +23,7 @@ from .cv4a_kenya_crop_type import CV4AKenyaCropType from .cyclone import TropicalCycloneWindEstimation from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset +from .gid15 import GID15 from .landcoverai import LandCoverAI from .landsat import ( Landsat, @@ -81,6 +82,7 @@ "COWCCounting", "COWCDetection", "CV4AKenyaCropType", + "GID15", "LandCoverAI", "LEVIRCDPlus", "PatternNet", diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py new file mode 100644 index 00000000000..f26aa643fdd --- /dev/null +++ b/torchgeo/datasets/gid15.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""GID-15 dataset.""" + +import glob +import os +from typing import Callable, Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from .geo import VisionDataset +from .utils import download_and_extract_archive + + +class GID15(VisionDataset): + """GID-15 dataset. + + The `GID-15 `_ + dataset is a dataset for semantic segmentation. + + Dataset features: + * images taken by the Gaofen-2 (GF-2) satellite over 60 cities in China + * masks representing 15 semantic categories + * three spectral bands - RGB + * 150 with 3 m per pixel resolution (6800x7200 px) + + Dataset format: + * images are three-channel pngs + * masks are single-channel pngs + * colormapped masks are 3 channel tifs + + Dataset classes: + 1. background + 2. industrial_land + 3. urban_residential + 4. rural_residential + 5. traffic_land + 6. paddy_field + 7. irrigated_land + 8. dry_cropland + 9. garden_plot + 10. arbor_woodland + 11. shrub_land + 12. natural_grassland + 13. artificial_grassland + 14. river + 15. lake + 16. pond + + If you use this dataset in your research, please cite the following paper: + * https://arxiv.org/abs/1807.05713 + """ + + url = "https://drive.google.com/file/d/1zbkCEXPEKEV6gq19OKmIbaT8bXXfWW6u" + md5 = "615682bf659c3ed981826c6122c10c83" + filename = "gid-15.zip" + directory = "GID" + splits = ["train", "val", "test"] + classes = [ + "background", + "industrial_land", + "urban_residential", + "rural_residential", + "traffic_land", + "paddy_field", + "irrigated_land", + "dry_cropland", + "garden_plot", + "arbor_woodland", + "shrub_land", + "natural_grassland", + "artificial_grassland", + "river", + "lake", + "pond", + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new GID-15 dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: if ``split`` argument is invalid + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + assert split in self.splits + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + if download: + self._download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. " + + "You can use download=True to download it" + ) + + self.files = self._load_files(self.root, self.split) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + files = self.files[index] + image = self._load_image(files["image"]) + + if self.split != "test": + mask = self._load_target(files["mask"]) + sample = {"image": image, "mask": mask} + else: + sample = {"image": image} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + split: subset of dataset, one of [train, val, test] + + Returns: + list of dicts containing paths for each pair of image1, image2, mask + """ + image_root = os.path.join(root, "GID", "img_dir") + images = glob.glob(os.path.join(image_root, split, "*.tif")) + images = sorted(images) + if split != "test": + masks = [ + image.replace("img_dir", "ann_dir").replace(".tif", "_15label.png") + for image in images + ] + else: + masks = [""] * len(images) + + files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] + return files + + def _load_image(self, path: str) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + filename = os.path.join(path) + with Image.open(filename) as img: + array = np.array(img.convert("RGB")) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target(self, path: str) -> Tensor: + """Load the target mask for a single image. + + Args: + path: path to the image + + Returns: + the target mask + """ + filename = os.path.join(path) + with Image.open(filename) as img: + array = np.array(img.convert("L")) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + tensor = tensor.to(torch.long) # type: ignore[attr-defined] + return tensor + + def _check_integrity(self) -> bool: + """Checks the integrity of the dataset structure. + + Returns: + True if the dataset directories and split files are found, else False + """ + filepath = os.path.join(self.root, self.directory) + if not os.path.exists(filepath): + return False + return True + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + AssertionError: if the checksum of split.py does not match + """ + if self._check_integrity(): + print("Files already downloaded and verified") + return + + download_and_extract_archive( + self.url, + self.root, + filename=self.filename, + md5=self.md5 if self.checksum else None, + ) From 71300a21c67cc887f61f8c73283a0eee11820108 Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:37:28 -0500 Subject: [PATCH 5/5] format and removed empty masks from file list --- torchgeo/datasets/gid15.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index f26aa643fdd..5e6ec381f53 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -34,22 +34,22 @@ class GID15(VisionDataset): * colormapped masks are 3 channel tifs Dataset classes: - 1. background - 2. industrial_land - 3. urban_residential - 4. rural_residential - 5. traffic_land - 6. paddy_field - 7. irrigated_land - 8. dry_cropland - 9. garden_plot - 10. arbor_woodland - 11. shrub_land - 12. natural_grassland - 13. artificial_grassland - 14. river - 15. lake - 16. pond + 1. background + 2. industrial_land + 3. urban_residential + 4. rural_residential + 5. traffic_land + 6. paddy_field + 7. irrigated_land + 8. dry_cropland + 9. garden_plot + 10. arbor_woodland + 11. shrub_land + 12. natural_grassland + 13. artificial_grassland + 14. river + 15. lake + 16. pond If you use this dataset in your research, please cite the following paper: * https://arxiv.org/abs/1807.05713 @@ -159,7 +159,7 @@ def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: split: subset of dataset, one of [train, val, test] Returns: - list of dicts containing paths for each pair of image1, image2, mask + list of dicts containing paths for each pair of image, mask """ image_root = os.path.join(root, "GID", "img_dir") images = glob.glob(os.path.join(image_root, split, "*.tif")) @@ -169,10 +169,10 @@ def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: image.replace("img_dir", "ann_dir").replace(".tif", "_15label.png") for image in images ] + files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] else: - masks = [""] * len(images) + files = [dict(image=image) for image in images] - files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] return files def _load_image(self, path: str) -> Tensor: