-
Notifications
You must be signed in to change notification settings - Fork 379
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* draft * add dataset to __init__ * reorganize datasets and datamodules * fix mypy errors * draft * add dataset to __init__ * reorganize datasets and datamodules * fix mypy errors * refactor * Adding docs * Adding plotting, cleaning up some stuff * Black and isort * Fix the datamodule import * Pyupgrade * Fixing some docstrings * Flake8 * Isort * Fix docstrings in datamodules * Fixing fns and docstring * Trying to fix the docs * Trying to fix docs * Adding tests * Black * newline * Made the test dataset larger * Remove the datamodules * Update docs/api/non_geo_datasets.csv Co-authored-by: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Updating cmap * Describe the different band combinations * Merging datasets * Handle the instance segmentation case in plotting * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Made some code prettier * Adding instance plotting --------- Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
f0cacd5
commit 711a576
Showing
7 changed files
with
614 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -272,6 +272,11 @@ OSCD | |
|
||
.. autoclass:: OSCD | ||
|
||
PASTIS | ||
^^^^^^ | ||
|
||
.. autoclass:: PASTIS | ||
|
||
PatternNet | ||
^^^^^^^^^^ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import hashlib | ||
import os | ||
import shutil | ||
from typing import Union | ||
|
||
import fiona | ||
import numpy as np | ||
|
||
SIZE = 32 | ||
NUM_SAMPLES = 5 | ||
MAX_NUM_TIME_STEPS = 10 | ||
np.random.seed(0) | ||
|
||
FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] | ||
|
||
filenames: FILENAME_HIERARCHY = { | ||
"DATA_S2": ["S2"], | ||
"DATA_S1A": ["S1A"], | ||
"DATA_S1D": ["S1D"], | ||
"ANNOTATIONS": ["TARGET"], | ||
"INSTANCE_ANNOTATIONS": ["INSTANCES"], | ||
} | ||
|
||
|
||
def create_file(path: str) -> None: | ||
for i in range(NUM_SAMPLES): | ||
new_path = f"{path}_{i}.npy" | ||
fn = os.path.basename(new_path) | ||
t = np.random.randint(1, MAX_NUM_TIME_STEPS) | ||
if fn.startswith("S2"): | ||
data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16) | ||
elif fn.startswith("S1A"): | ||
data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16) | ||
elif fn.startswith("S1D"): | ||
data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16) | ||
elif fn.startswith("TARGET"): | ||
data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8) | ||
elif fn.startswith("INSTANCES"): | ||
data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64) | ||
np.save(new_path, data) | ||
|
||
|
||
def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: | ||
if isinstance(hierarchy, dict): | ||
# Recursive case | ||
for key, value in hierarchy.items(): | ||
path = os.path.join(directory, key) | ||
os.makedirs(path, exist_ok=True) | ||
create_directory(path, value) | ||
else: | ||
# Base case | ||
for value in hierarchy: | ||
path = os.path.join(directory, value) | ||
create_file(path) | ||
|
||
|
||
if __name__ == "__main__": | ||
create_directory("PASTIS-R", filenames) | ||
|
||
schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}} | ||
with fiona.open( | ||
os.path.join("PASTIS-R", "metadata.geojson"), | ||
"w", | ||
"GeoJSON", | ||
crs="EPSG:4326", | ||
schema=schema, | ||
) as f: | ||
for i in range(NUM_SAMPLES): | ||
f.write( | ||
{ | ||
"geometry": { | ||
"type": "Polygon", | ||
"coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], | ||
}, | ||
"id": str(i), | ||
"properties": {"Fold": i % 5, "ID_PATCH": i}, | ||
} | ||
) | ||
|
||
filename = "PASTIS-R.zip" | ||
shutil.make_archive(filename.replace(".zip", ""), "zip", ".", "PASTIS-R") | ||
|
||
# Compute checksums | ||
with open(filename, "rb") as f: | ||
md5 = hashlib.md5(f.read()).hexdigest() | ||
print(f"{filename}: {md5}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from _pytest.fixtures import SubRequest | ||
from pytest import MonkeyPatch | ||
from torch.utils.data import ConcatDataset | ||
|
||
import torchgeo.datasets.utils | ||
from torchgeo.datasets import PASTIS | ||
|
||
|
||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: | ||
shutil.copy(url, root) | ||
|
||
|
||
class TestPASTIS: | ||
@pytest.fixture( | ||
params=[ | ||
{"folds": (0, 1), "bands": "s2", "mode": "semantic"}, | ||
{"folds": (0, 1), "bands": "s1a", "mode": "semantic"}, | ||
{"folds": (0, 1), "bands": "s1d", "mode": "instance"}, | ||
] | ||
) | ||
def dataset( | ||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest | ||
) -> PASTIS: | ||
monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url) | ||
|
||
md5 = "9b11ae132623a0d13f7f0775d2003703" | ||
monkeypatch.setattr(PASTIS, "md5", md5) | ||
url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") | ||
monkeypatch.setattr(PASTIS, "url", url) | ||
root = str(tmp_path) | ||
folds = request.param["folds"] | ||
bands = request.param["bands"] | ||
mode = request.param["mode"] | ||
transforms = nn.Identity() | ||
return PASTIS( | ||
root, folds, bands, mode, transforms, download=True, checksum=True | ||
) | ||
|
||
def test_getitem_semantic(self, dataset: PASTIS) -> None: | ||
x = dataset[0] | ||
assert isinstance(x, dict) | ||
assert isinstance(x["image"], torch.Tensor) | ||
assert isinstance(x["mask"], torch.Tensor) | ||
|
||
def test_getitem_instance(self, dataset: PASTIS) -> None: | ||
dataset.mode = "instance" | ||
x = dataset[0] | ||
assert isinstance(x, dict) | ||
assert isinstance(x["image"], torch.Tensor) | ||
assert isinstance(x["mask"], torch.Tensor) | ||
assert isinstance(x["boxes"], torch.Tensor) | ||
assert isinstance(x["label"], torch.Tensor) | ||
|
||
def test_len(self, dataset: PASTIS) -> None: | ||
assert len(dataset) == 2 | ||
|
||
def test_add(self, dataset: PASTIS) -> None: | ||
ds = dataset + dataset | ||
assert isinstance(ds, ConcatDataset) | ||
assert len(ds) == 4 | ||
|
||
def test_already_extracted(self, dataset: PASTIS) -> None: | ||
PASTIS(root=dataset.root, download=True) | ||
|
||
def test_already_downloaded(self, tmp_path: Path) -> None: | ||
url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") | ||
root = str(tmp_path) | ||
shutil.copy(url, root) | ||
PASTIS(root) | ||
|
||
def test_not_downloaded(self, tmp_path: Path) -> None: | ||
with pytest.raises(RuntimeError, match="Dataset not found"): | ||
PASTIS(str(tmp_path)) | ||
|
||
def test_corrupted(self, tmp_path: Path) -> None: | ||
with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f: | ||
f.write("bad") | ||
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): | ||
PASTIS(root=str(tmp_path), checksum=True) | ||
|
||
def test_invalid_fold(self) -> None: | ||
with pytest.raises(AssertionError): | ||
PASTIS(folds=(6,)) | ||
|
||
def test_invalid_mode(self) -> None: | ||
with pytest.raises(AssertionError): | ||
PASTIS(mode="invalid") | ||
|
||
def test_plot(self, dataset: PASTIS) -> None: | ||
x = dataset[0].copy() | ||
dataset.plot(x, suptitle="Test") | ||
plt.close() | ||
dataset.plot(x, show_titles=False) | ||
plt.close() | ||
x["prediction"] = x["mask"].clone() | ||
if dataset.mode == "instance": | ||
x["prediction_labels"] = x["label"].clone() | ||
dataset.plot(x) | ||
plt.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.