PASTIS dataset (#315)
* draft

* add dataset to __init__

* reorganize datasets and datamodules

* fix mypy errors

* refactor

* Adding docs

* Adding plotting, cleaning up some stuff

* Black and isort

* Fix the datamodule import

* Pyupgrade

* Fixing some docstrings

* Flake8

* Isort

* Fix docstrings in datamodules

* Fixing fns and docstring

* Trying to fix the docs

* Trying to fix docs

* Adding tests

* Black

* newline

* Made the test dataset larger

* Remove the datamodules

* Update docs/api/non_geo_datasets.csv

Co-authored-by: Isaac Corley <22203655+isaaccorley@users.noreply.github.com>

* Update torchgeo/datasets/pastis.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/pastis.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/pastis.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Updating cmap

* Describe the different band combinations

* Merging datasets

* Handle the instance segmentation case in plotting

* Update torchgeo/datasets/pastis.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Made some code prettier

* Adding instance plotting

---------

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
3 people authored Aug 3, 2023
1 parent f0cacd5 commit 711a576
Showing 7 changed files with 614 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
@@ -272,6 +272,11 @@ OSCD

.. autoclass:: OSCD

PASTIS
^^^^^^

.. autoclass:: PASTIS

PatternNet
^^^^^^^^^^

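The autoclass entry above pulls the full API documentation from the class docstring. As a quick orientation for readers of this diff, a minimal usage sketch might look as follows; the argument names and fold values mirror the test fixture later in this commit, and the root path is hypothetical, so the real PASTIS-R archive and its fold numbering may differ.

# Minimal usage sketch (assumptions: argument names taken from the tests in this
# commit, hypothetical local path, folds copied from the test fixture).
from torchgeo.datasets import PASTIS

ds = PASTIS(
    root="data/pastis",   # directory containing the extracted PASTIS-R data
    folds=(0, 1),         # subset of cross-validation folds
    bands="s2",           # "s2", "s1a", or "s1d", as exercised by the tests
    mode="semantic",      # or "instance" for instance segmentation targets
)
sample = ds[0]
print(sample["image"].shape, sample["mask"].shape)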
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
@@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
`PASTIS`_,I,Sentinel-1/2,"2,433",19,128x128xT,10,MSI
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
Binary file added tests/data/pastis/PASTIS-R.zip
91 changes: 91 additions & 0 deletions tests/data/pastis/data.py
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Union

import fiona
import numpy as np

SIZE = 32
NUM_SAMPLES = 5
MAX_NUM_TIME_STEPS = 10
np.random.seed(0)

FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

filenames: FILENAME_HIERARCHY = {
    "DATA_S2": ["S2"],
    "DATA_S1A": ["S1A"],
    "DATA_S1D": ["S1D"],
    "ANNOTATIONS": ["TARGET"],
    "INSTANCE_ANNOTATIONS": ["INSTANCES"],
}


def create_file(path: str) -> None:
    for i in range(NUM_SAMPLES):
        new_path = f"{path}_{i}.npy"
        fn = os.path.basename(new_path)
        t = np.random.randint(1, MAX_NUM_TIME_STEPS)
        if fn.startswith("S2"):
            data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16)
        elif fn.startswith("S1A"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("S1D"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("TARGET"):
            data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8)
        elif fn.startswith("INSTANCES"):
            data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64)
        np.save(new_path, data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    if isinstance(hierarchy, dict):
        # Recursive case
        for key, value in hierarchy.items():
            path = os.path.join(directory, key)
            os.makedirs(path, exist_ok=True)
            create_directory(path, value)
    else:
        # Base case
        for value in hierarchy:
            path = os.path.join(directory, value)
            create_file(path)


if __name__ == "__main__":
    create_directory("PASTIS-R", filenames)

    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
    with fiona.open(
        os.path.join("PASTIS-R", "metadata.geojson"),
        "w",
        "GeoJSON",
        crs="EPSG:4326",
        schema=schema,
    ) as f:
        for i in range(NUM_SAMPLES):
            f.write(
                {
                    "geometry": {
                        "type": "Polygon",
                        "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
                    },
                    "id": str(i),
                    "properties": {"Fold": i % 5, "ID_PATCH": i},
                }
            )

    filename = "PASTIS-R.zip"
    shutil.make_archive(filename.replace(".zip", ""), "zip", ".", "PASTIS-R")

    # Compute checksums
    with open(filename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{filename}: {md5}")
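For reviewers who want to confirm what the script produces, the sketch below loads two of the generated arrays and prints their shapes; it assumes data.py has already been run from tests/data/pastis/ and simply restates the shapes chosen in create_file() above.

# Sanity-check sketch for the generated fixture (run after data.py).
import os

import numpy as np

s2 = np.load(os.path.join("PASTIS-R", "DATA_S2", "S2_0.npy"))
target = np.load(os.path.join("PASTIS-R", "ANNOTATIONS", "TARGET_0.npy"))
print(s2.shape, s2.dtype)          # (T, 10, 32, 32) int16, with 1 <= T < 10
print(target.shape, target.dtype)  # (3, 32, 32) uint8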
110 changes: 110 additions & 0 deletions tests/datasets/test_pastis.py
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import PASTIS


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestPASTIS:
    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1a", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1d", "mode": "instance"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTIS:
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)

        md5 = "9b11ae132623a0d13f7f0775d2003703"
        monkeypatch.setattr(PASTIS, "md5", md5)
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        monkeypatch.setattr(PASTIS, "url", url)
        root = str(tmp_path)
        folds = request.param["folds"]
        bands = request.param["bands"]
        mode = request.param["mode"]
        transforms = nn.Identity()
        return PASTIS(
            root, folds, bands, mode, transforms, download=True, checksum=True
        )

    def test_getitem_semantic(self, dataset: PASTIS) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)

    def test_getitem_instance(self, dataset: PASTIS) -> None:
        dataset.mode = "instance"
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert isinstance(x["boxes"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)

    def test_len(self, dataset: PASTIS) -> None:
        assert len(dataset) == 2

    def test_add(self, dataset: PASTIS) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 4

    def test_already_extracted(self, dataset: PASTIS) -> None:
        PASTIS(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        root = str(tmp_path)
        shutil.copy(url, root)
        PASTIS(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            PASTIS(str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            PASTIS(root=str(tmp_path), checksum=True)

    def test_invalid_fold(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(folds=(6,))

    def test_invalid_mode(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(mode="invalid")

    def test_plot(self, dataset: PASTIS) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["mask"].clone()
        if dataset.mode == "instance":
            x["prediction_labels"] = x["label"].clone()
        dataset.plot(x)
        plt.close()
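The tests above pin down the constructor signature; a downstream usage sketch (not an official torchgeo example — the root path is hypothetical and the fold values are copied from the fixture) could look like this. Because the number of time steps T differs between patches, batch_size=1 sidesteps the need for a custom collate_fn.

# Downstream usage sketch based on the constructor exercised in the tests above.
from torch.utils.data import DataLoader

from torchgeo.datasets import PASTIS

ds = PASTIS(root="data/pastis", folds=(0, 1), bands="s2", mode="semantic")
loader = DataLoader(ds, batch_size=1, shuffle=True)
for batch in loader:
    # With batch_size=1, "image" stacks to roughly (1, T, C, H, W); exact
    # layouts for "image" and "mask" depend on the dataset implementation.
    images, masks = batch["image"], batch["mask"]
    break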
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -78,6 +78,7 @@
from .nlcd import NLCD
from .openbuildings import OpenBuildings
from .oscd import OSCD
from .pastis import PASTIS
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
@@ -194,6 +195,7 @@
"MillionAID",
"NASAMarineDebris",
"OSCD",
"PASTIS",
"PatternNet",
"Potsdam2D",
"RESISC45",