Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xView2 Dataset #236

Merged
merged 13 commits into from
Nov 15, 2021
6 changes: 6 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ UC Merced
.. autoclass:: UCMerced
.. autoclass:: UCMercedDataModule

xView2
^^^^^^

.. autoclass:: XView2
.. autoclass:: XView2DataModule

ZueriCrop
^^^^^^^^^

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
117 changes: 117 additions & 0 deletions tests/datasets/test_xview2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import XView2, XView2DataModule


class TestXView2:
@pytest.fixture(params=["train", "test"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.fixture(params=["train", "test"])
@pytest.fixture(scope="class", params=["train", "test"])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks the tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ScopeMismatch: You tried to access the 'function' scoped fixture 'monkeypatch' with a 'class' scoped request object, involved factories

def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
) -> XView2:
monkeypatch.setattr( # type: ignore[attr-defined]
XView2,
"metadata",
{
"train": {
"filename": "train_images_labels_targets.tar.gz",
"md5": "373e61d55c1b294aa76b94dbbd81332b",
"directory": "train",
},
"test": {
"filename": "test_images_labels_targets.tar.gz",
"md5": "bc6de81c956a3bada38b5b4e246266a1",
"directory": "test",
},
},
)
root = os.path.join("tests", "data", "xview2")
split = request.param
transforms = nn.Identity() # type: ignore[attr-defined]
return XView2(root, split, transforms, checksum=True)

def test_getitem(self, dataset: XView2) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_len(self, dataset: XView2) -> None:
assert len(dataset) == 2

def test_extract(self, tmp_path: Path) -> None:
shutil.copyfile(
os.path.join(
"tests", "data", "xview2", "train_images_labels_targets.tar.gz"
),
os.path.join(tmp_path, "train_images_labels_targets.tar.gz"),
)
shutil.copyfile(
os.path.join(
"tests", "data", "xview2", "test_images_labels_targets.tar.gz"
),
os.path.join(tmp_path, "test_images_labels_targets.tar.gz"),
)
XView2(root=str(tmp_path))

def test_corrupted(self, tmp_path: Path) -> None:
with open(
os.path.join(tmp_path, "train_images_labels_targets.tar.gz"), "w"
) as f:
f.write("bad")
with open(
os.path.join(tmp_path, "test_images_labels_targets.tar.gz"), "w"
) as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
XView2(root=str(tmp_path), checksum=True)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
XView2(split="foo")

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
XView2(str(tmp_path))

def test_plot(self, dataset: XView2) -> None:
x = dataset[0].copy()
XView2.plot(x, suptitle="Test")
XView2.plot(x, show_titles=False)
x["prediction"] = x["mask"][0].clone()
XView2.plot(x)


class TestXView2DataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> XView2DataModule:
root = os.path.join("tests", "data", "xview2")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = XView2DataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.test_dataloader()))
3 changes: 3 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet7
from .ucmerced import UCMerced, UCMercedDataModule
from .utils import BoundingBox, collate_dict
from .xview import XView2, XView2DataModule
from .zuericrop import ZueriCrop

__all__ = (
Expand Down Expand Up @@ -130,6 +131,8 @@
"UCMerced",
"UCMercedDataModule",
"VHR10",
"XView2",
"XView2DataModule",
"ZueriCrop",
# Base classes
"GeoDataset",
Expand Down
28 changes: 28 additions & 0 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks

__all__ = (
"check_integrity",
Expand All @@ -33,9 +34,13 @@
"rasterio_loader",
"dataset_split",
"sort_sentinel2_bands",
"draw_semantic_segmentation_masks",
)


ColorMap = Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class _rarfile:
class RarFile:
def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -431,3 +436,26 @@ def sort_sentinel2_bands(x: str) -> str:
if x == "B8A":
x = "B08A"
return x


def draw_semantic_segmentation_masks(
image: Tensor, mask: Tensor, alpha: float = 0.5, colors: Optional[ColorMap] = None
) -> np.ndarray: # type: ignore[type-arg]
"""Overlay a semantic segmentation mask onto an image.

Args:
image: tensor of shape (3, h, w)
mask: tensor of shape (h, w) with pixel values representing the classes
alpha: alpha blend factor
colors: (Optional) list of RGB int tuples, or color strings e.g. red, #FF00FF
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
classes = torch.unique(mask) # type: ignore[attr-defined]
classes = classes[1:]
class_masks = mask == classes[:, None, None]
img = draw_segmentation_masks(
image=image, masks=class_masks, alpha=alpha, colors=colors
)
img = img.permute((1, 2, 0)).numpy()
return img # type: ignore[no-any-return]
Loading