forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
L7 Irish: add new dataset/datamodule (microsoft#1197)
* Landsat 7 Irish: add new dataset/datamodule adding [Landsat 7 Cloud Cover Assessment Validation Data](https://landsat.usgs.gov/landsat-7-cloud-cover-assessment-validation-data) from USGS. The files are .tif format. The naming pattern is slightly different from the Landsat class. Thus, the customized RasterDataset class is used. * # Changes to be committed: # deleted: .DS_Store # deleted: torchgeo/.DS_Store # # Changes not staged for commit: # modified: .gitignore # modified: torchgeo/datasets/l7irish.py * # Changes to be committed: # modified: .gitignore # modified: torchgeo/datasets/l7irish.py * add: test_l7irish.py and data.py modify: .gitignore and l7irish.py * # Changes to be committed: # modified: .gitignore # new file: tests/data/l7irish/data.py # modified: torchgeo/datasets/l7irish.py * modified: tests/datasets/test_l7irish.py, torchgeo/datasets/__init__.py, torchgeo/datasets/l7irish.py * added data.py, austral.tar.gz, test_l7irish.py modified l7irish.py * remove comments in test_l7irish.py * resolve black and flake8 issues * Fixed _getitem Added l7irish to __init__ Added L7Irish details to geo_datasets.csv Will work on datamodules/l7irish.py * Added L7 Irish datamodule * fix flake8 space error * fix black test error * chmod +x for data.py * Update docs/api/datamodules.rst Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update docs/api/datasets.rst Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update docs/api/geo_datasets.csv Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. 
Stewart <ajstewart426@gmail.com> * Resolved minor issues in l7irish.py Added mask assert in test_l7irish.py * Improved _getitem and plot functions * Added new artificial data with 5 scenes Updated test_l7irish.py * remove comments in l7irish.py * resolve black, flake8, and isort errors * add l7irish.yaml and refine test_segmentation.py * modified l7irish.yaml * revert a change in .gitignore * add function test_rgb_bands_absent_plot() * resolve black test issue * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update l7irish.py and create new test data * update l7irish.py for style tests * remove old test data * Update tests/data/l7irish/data.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * update data.py and l7irish.py * update md5s, citations, masks, and thermal bands * update mask mapping * update formatting * update mask path * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/l7irish.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update tests/data/l7irish/data.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update docs/api/geo_datasets.csv Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update tests/conf/l7irish.yaml Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * resolve issues from comments * Update L7 Irish link Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * update mask data generation and review changes * Fix checksums --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
5d750c0
commit e6da03e
Showing
63 changed files
with
620 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,11 @@ iNaturalist | |
|
||
.. autoclass:: INaturalist | ||
|
||
L7 Irish | ||
^^^^^^^^ | ||
|
||
.. autoclass:: L7Irish | ||
|
||
L8 Biome | ||
^^^^^^^^ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
experiment: | ||
task: "l7irish" | ||
module: | ||
loss: "ce" | ||
model: "unet" | ||
backbone: "resnet18" | ||
weights: null | ||
learning_rate: 1e-3 | ||
learning_rate_schedule_patience: 6 | ||
verbose: false | ||
in_channels: 9 | ||
num_classes: 5 | ||
num_filters: 1 | ||
ignore_index: 0 | ||
datamodule: | ||
root: "tests/data/l7irish" | ||
download: true | ||
batch_size: 1 | ||
patch_size: 32 | ||
length: 5 | ||
num_workers: 0 |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import hashlib | ||
import os | ||
import shutil | ||
from typing import Dict, List, Union | ||
|
||
import numpy as np | ||
import rasterio | ||
from rasterio import Affine | ||
from rasterio.crs import CRS | ||
|
||
# Width and height, in pixels, of the generated rasters (doubled for B80).
SIZE = 36

# Seed NumPy's global RNG so the random pixel values are reproducible.
np.random.seed(0)

# A directory tree: dicts map directory names to sub-trees, lists hold file names.
FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]]

bands = [
    "B10.TIF",
    "B20.TIF",
    "B30.TIF",
    "B40.TIF",
    "B50.TIF",
    "B61.TIF",
    "B62.TIF",
    "B70.TIF",
    "B80.TIF",
]

filenames: FILENAME_HIERARCHY = {
    "austral": {"p226_r98": [], "p227_r98": [], "p231_r93_2": []},
    "boreal": {"p2_r27": [], "p143_r21_3": []},
}
prefixes = [
    "L71226098_09820011112",
    "L71227098_09820011103",
    "L71231093_09320010507",
    "L71002027_02720010604",
    "L71143021_02120010803",
]

# Bands whose file names use the "L72" scene prefix instead of "L71".
_L72_BANDS = ("B62.TIF", "B70.TIF", "B80.TIF")

# Populate each patch directory with one file per band plus a single mask file.
for land_type, patches in filenames.items():
    for patch, files in patches.items():
        path, row = patch.split("_")[:2]
        # Zero-padded WRS path/row, e.g. "p2"/"r27" -> "002027".
        key = path[1:].zfill(3) + row[1:].zfill(3)
        for scene in prefixes:
            if key not in scene:
                continue
            for band in bands:
                stem = scene.replace("L71", "L72") if band in _L72_BANDS else scene
                files.append(f"{stem}_{band}")

        files.append(f"L7_{path}_{row}_newmask2015.TIF")
|
||
|
||
def create_file(path: str) -> None:
    """Write a single-band ``uint8`` GeoTIFF of synthetic data to ``path``.

    What is written depends on the file name suffix:

    * ``B80.TIF``: a ``2 * SIZE`` square raster with a 15-unit pixel transform.
    * ``_newmask2015.TIF``: a ``SIZE`` square raster whose pixels are drawn
      from the values ``0, 64, 128, 191, 255``.
    * anything else: a ``SIZE`` square raster with a 30-unit pixel transform.

    Image pixels are drawn uniformly from the full ``uint8`` range. Previously
    they came from ``np.random.randn(...).astype("uint8")``; casting negative
    floats to an unsigned integer is undefined, so the result was
    platform-dependent and almost entirely 0/1.

    NOTE(review): this changes the generated pixel values, so the fixture
    tarballs and any checksums derived from them must be regenerated together.

    Args:
        path: Destination path of the GeoTIFF to create.
    """
    dtype = "uint8"
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "width": SIZE,
        "height": SIZE,
        "count": 1,
        "crs": CRS.from_epsg(32719),
        "transform": Affine(30.0, 0.0, 462884.99999999994, 0.0, -30.0, 4071915.0),
    }

    if path.endswith("B80.TIF"):
        # Half the pixel size and twice the pixel count of the other bands.
        profile["transform"] = Affine(
            15.0, 0.0, 462892.49999999994, 0.0, -15.0, 4071907.5
        )
        profile["width"] = profile["height"] = SIZE * 2

    # Generate the array at the raster's own size instead of always SIZE x SIZE,
    # so the B80 write does not rely on implicit resampling.
    shape = (profile["height"], profile["width"])

    if path.endswith("_newmask2015.TIF"):
        Z = np.random.choice(np.array([0, 64, 128, 191, 255], dtype=dtype), size=shape)
    else:
        # Upper bound is exclusive, so this covers 0..255.
        Z = np.random.randint(0, 256, size=shape, dtype=dtype)

    with rasterio.open(path, "w", **profile) as src:
        src.write(Z, 1)
|
||
|
||
def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Recursively create the directories and files described by ``hierarchy``.

    Args:
        directory: Directory under which the entries are created.
        hierarchy: Either a dict mapping sub-directory names to nested
            hierarchies, or a list of file names to generate in ``directory``.
    """
    if not isinstance(hierarchy, dict):
        # Leaf: every entry is a file name to generate in this directory.
        for name in hierarchy:
            create_file(os.path.join(directory, name))
        return

    # Branch: make each sub-directory, then descend into it.
    for name, children in hierarchy.items():
        subdir = os.path.join(directory, name)
        os.makedirs(subdir, exist_ok=True)
        create_directory(subdir, children)
|
||
|
||
if __name__ == "__main__":
    # Generate the whole fixture tree in the current working directory.
    create_directory(".", filenames)

    for directory in ["austral", "boreal"]:
        # Pack the directory into "<directory>.tar.gz" next to it.
        shutil.make_archive(directory, "gztar", ".", directory)

        # Print the archive's MD5 digest.
        with open(f"{directory}.tar.gz", "rb") as tarball:
            digest = hashlib.md5(tarball.read()).hexdigest()
        print(directory, digest)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import glob | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from rasterio.crs import CRS | ||
|
||
import torchgeo.datasets.utils | ||
from torchgeo.datasets import BoundingBox, IntersectionDataset, L7Irish, UnionDataset | ||
|
||
|
||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Copy a local file instead of downloading it.

    Args:
        url: Path of the local file to copy.
        root: Destination passed to :func:`shutil.copy`.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """
    shutil.copy(src=url, dst=root)
|
||
|
||
class TestL7Irish:
    """Tests for the ``L7Irish`` dataset, run against the local fixture archives."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish:
        """Return an ``L7Irish`` that fetches its archives from ``tests/data``."""
        # Replace the downloader with the local file-copy stub.
        monkeypatch.setattr(torchgeo.datasets.l7irish, "download_url", download_url)

        # Point the class at the fixture archives and their checksums.
        monkeypatch.setattr(
            L7Irish, "url", os.path.join("tests", "data", "l7irish", "{}.tar.gz")
        )
        monkeypatch.setattr(
            L7Irish,
            "md5s",
            {
                "austral": "c06147330141517f7eee55ea931c4787",
                "boreal": "4b598e55f0d6d33da3672190ebf96268",
            },
        )

        return L7Irish(
            str(tmp_path), transforms=nn.Identity(), download=True, checksum=True
        )

    def test_getitem(self, dataset: L7Irish) -> None:
        """Indexing by the full bounds returns a dict with crs, image and mask."""
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["image"], torch.Tensor)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_and(self, dataset: L7Irish) -> None:
        """``&`` produces an ``IntersectionDataset``."""
        assert isinstance(dataset & dataset, IntersectionDataset)

    def test_or(self, dataset: L7Irish) -> None:
        """``|`` produces a ``UnionDataset``."""
        assert isinstance(dataset | dataset, UnionDataset)

    def test_plot(self, dataset: L7Irish) -> None:
        """Plotting a sample with a suptitle runs without raising."""
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_already_extracted(self, dataset: L7Irish) -> None:
        """Re-instantiating on a root the fixture already populated succeeds."""
        L7Irish(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        """Instantiating on a root that only holds the archives succeeds."""
        root = str(tmp_path)
        pattern = os.path.join("tests", "data", "l7irish", "*.tar.gz")
        for archive in glob.iglob(pattern):
            shutil.copy(archive, root)
        L7Irish(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download=True`` raises ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            L7Irish(str(tmp_path))

    def test_plot_prediction(self, dataset: L7Irish) -> None:
        """Plotting a sample that carries a ``prediction`` key runs without raising."""
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()

    def test_invalid_query(self, dataset: L7Irish) -> None:
        """A bounding box outside the index raises ``IndexError``."""
        empty_box = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[empty_box]

    def test_rgb_bands_absent_plot(self, dataset: L7Irish) -> None:
        """A band subset lacking some RGB bands raises ``ValueError``."""
        # NOTE(review): everything stays inside the context manager because the
        # statement that raises is not visible here; plt.close() only runs if
        # none of the preceding statements raise.
        with pytest.raises(
            ValueError, match="Dataset doesn't contain some of the RGB bands"
        ):
            subset = L7Irish(root=dataset.root, bands=["B1", "B2", "B5"])
            sample = subset[subset.bounds]
            subset.plot(sample, suptitle="Test")
            plt.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""L7 Irish datamodule.""" | ||
|
||
from typing import Any, Tuple, Union | ||
|
||
import torch | ||
|
||
from ..datasets import L7Irish, random_bbox_assignment | ||
from ..samplers import GridGeoSampler, RandomBatchGeoSampler | ||
from .geo import GeoDataModule | ||
|
||
|
||
class L7IrishDataModule(GeoDataModule):
    """LightningDataModule implementation for the L7 Irish dataset.

    .. versionadded:: 0.5
    """

    # NOTE(review): presumably normalization statistics read by the base
    # class; confirm against GeoDataModule.
    mean = torch.tensor(0)
    std = torch.tensor(10000)

    def __init__(
        self,
        batch_size: int = 1,
        patch_size: Union[int, Tuple[int, int]] = 32,
        length: int = 5,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new L7IrishDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.L7Irish`.
        """
        super().__init__(
            L7Irish,
            batch_size=batch_size,
            patch_size=patch_size,
            length=length,
            num_workers=num_workers,
            **kwargs,
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        The dataset is split 60/20/20 into train/val/test with a generator
        seeded at 0, then the samplers needed for ``stage`` are created.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        splits = random_bbox_assignment(
            L7Irish(**self.kwargs), [0.6, 0.2, 0.2], torch.Generator().manual_seed(0)
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ("fit", "validate"):
            # The stride argument equals the patch size.
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage == "test":
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.