Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL4EO Landsat Downstream Dataset/module CDL, NLCD #1338

Merged
merged 28 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/conf/ssl4eo_l_benchmark_cdl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
num_classes: 17
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
num_filters: 1
ignore_index: 0

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "tm_toa"
mask_product: "cdl"
batch_size: 2
num_workers: 0
20 changes: 20 additions & 0 deletions tests/conf/ssl4eo_l_benchmark_nlcd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
num_classes: 17
num_filters: 1
ignore_index: 0

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "tm_toa"
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
mask_product: "nlcd"
batch_size: 2
num_workers: 0
183 changes: 183 additions & 0 deletions tests/data/ssl4eo_benchmark_landsat/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Union

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 36

np.random.seed(0)

FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

filenames: FILENAME_HIERARCHY = {
"tm_toa": {
"0000002": {"LT05_172034_20010526": ["all_bands.tif"]},
"0000005": {"LT05_223084_20010413": ["all_bands.tif"]},
"0000007": {"LT05_172034_20020902": ["all_bands.tif"]},
},
"etm_sr": {
"0000002": {"LE07_172034_20010526": ["all_bands.tif"]},
"0000005": {"LE07_223084_20010413": ["all_bands.tif"]},
"0000007": {"LE07_172034_20020902": ["all_bands.tif"]},
},
"etm_toa": {
"0000002": {"LE07_172034_20010526": ["all_bands.tif"]},
"0000005": {"LE07_223084_20010413": ["all_bands.tif"]},
"0000007": {"LE07_172034_20020902": ["all_bands.tif"]},
},
"oli_tirs_toa": {
"0000002": {"LC08_172034_20210306": ["all_bands.tif"]},
"0000005": {"LC08_223084_20210412": ["all_bands.tif"]},
"0000007": {"LC08_172034_20020902": ["all_bands.tif"]},
},
"oli_sr": {
"0000002": {"LC08_172034_20210306": ["all_bands.tif"]},
"0000005": {"LC08_223084_20210412": ["all_bands.tif"]},
"0000007": {"LC08_172034_20020902": ["all_bands.tif"]},
},
}

num_bands = {"tm_toa": 7, "etm_sr": 6, "etm_toa": 9, "oli_tirs_toa": 11, "oli_sr": 7}
years = {"tm": 2011, "etm": 2019, "oli": 2019}


def create_image(path: str) -> None:
profile = {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"width": SIZE,
"height": SIZE,
"count": num_bands["_".join(path.split(os.sep)[1].split("_")[2:][:-1])],
"crs": CRS.from_epsg(4326),
"transform": Affine(
0.00037672803497508636,
0.0,
-109.07063613660262,
0.0,
-0.0002554026278261721,
47.49838726154881,
),
"blockysize": 1,
"tiled": False,
"compress": "lzw",
"interleave": "pixel",
}

Z = np.random.randint(low=0, high=255, size=(SIZE, SIZE))

with rasterio.open(path, "w", **profile) as src:
for i in src.indexes:
src.write(Z, i)


def create_mask(path: str) -> None:
profile = {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
0.00037672803497508636,
0.0,
-109.07063613660262,
0.0,
-0.0002554026278261721,
47.49838726154881,
),
"blockysize": 1,
"tiled": False,
"compress": "lzw",
"interleave": "band",
}

Z = np.random.randint(low=0, high=10, size=(1, SIZE, SIZE))

with rasterio.open(path, "w", **profile) as src:
src.write(Z)


def create_img_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
if isinstance(hierarchy, dict):
# Recursive case
for key, value in hierarchy.items():
if any([x in key for x in filenames.keys()]):
key = f"ssl4eo_l_{key}_benchmark"
path = os.path.join(directory, key)
os.makedirs(path, exist_ok=True)
create_img_directory(path, value)
else:
# Base case
for value in hierarchy:
path = os.path.join(directory, value)
create_image(path)


def create_mask_directory(
directory: str, hierarchy: FILENAME_HIERARCHY, mask_product: str
) -> None:
if isinstance(hierarchy, dict):
# Recursive case
for key, value in hierarchy.items():
path = os.path.join(directory, key)
os.makedirs(path, exist_ok=True)
create_mask_directory(path, value, mask_product)
else:
# Base case
for value in hierarchy:
path = os.path.join(directory, value)
year = years[path.split(os.sep)[1].split("_")[2]]
create_mask(path.replace("all_bands", f"{mask_product}_{year}"))


def create_tarballs(directories) -> None:
for directory in directories:
# Create tarballs
shutil.make_archive(directory, "gztar", ".", directory)

# Compute checksums
with open(f"{directory}.tar.gz", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(directory, md5)


if __name__ == "__main__":
# image directories
create_img_directory(".", filenames)
directories = filenames.keys()
directories = [f"ssl4eo_l_{key}_benchmark" for key in directories]
create_tarballs(directories)

# mask directory cdl
mask_keep = ["tm_toa", "etm_sr", "oli_sr"]
mask_filenames = {
f"ssl4eo_l_{key.split('_')[0]}_cdl": val
for key, val in filenames.items()
if key in mask_keep
}
create_mask_directory(".", mask_filenames, "cdl")
directories = mask_filenames.keys()
create_tarballs(directories)

# mask directory nlcd
mask_filenames = {
f"ssl4eo_l_{key.split('_')[0]}_nlcd": val
for key, val in filenames.items()
if key in mask_keep
}
create_mask_directory(".", mask_filenames, "nlcd")
directories = mask_filenames.keys()
create_tarballs(directories)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
135 changes: 135 additions & 0 deletions tests/datasets/test_ssl4eo_benchmark_landsat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import SSL4EOLBenchmark


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


class TestSSL4EOLBenchmark:
@pytest.fixture(
params=product(
["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"],
["cdl", "nlcd"],
["train", "val", "test"],
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> SSL4EOLBenchmark:
monkeypatch.setattr(
torchgeo.datasets.ssl4eo_benchmark_landsat, "download_url", download_url
)
root = str(tmp_path)

url = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "{}.tar.gz")
monkeypatch.setattr(SSL4EOLBenchmark, "url", url)

input_sensor, mask_product, split = request.param
monkeypatch.setattr(
SSL4EOLBenchmark, "split_percentages", [1 / 3, 1 / 3, 1 / 3]
)

img_md5s = {
"tm_toa": "27f0562206baec86c5fdd1d7f069ef91",
"etm_toa": "0350f83c8462a64ffd192d8ebe070842",
"etm_sr": "277e1657b89e141fa3085fd01053162d",
"oli_tirs_toa": "53350e7ee0616df47859d28a29e170da",
"oli_sr": "8235bcce500657b9e0cfcb3af6bb1480",
}
monkeypatch.setattr(SSL4EOLBenchmark, "img_md5s", img_md5s)

mask_md5s = {
"tm": {
"cdl": "762104b3fc41afe1ef63f5ea80940d4b",
"nlcd": "57391b79a33ccd482471b377ae2de7f1",
},
"etm": {
"cdl": "8285e0d051081a9379cd150c7669971e",
"nlcd": "916f4a433df6c8abca15b45b60d005d3",
},
"oli": {
"cdl": "729a7b75b8749c8a7f26e5ece164e73f",
"nlcd": "e237adcee8b43d4eca86a6d169ae2761",
},
}
monkeypatch.setattr(SSL4EOLBenchmark, "mask_md5s", mask_md5s)

transforms = nn.Identity()
return SSL4EOLBenchmark(
root=root,
input_sensor=input_sensor,
mask_product=mask_product,
split=split,
transforms=transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: SSL4EOLBenchmark) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(split="foo")

def test_invalid_input_sensor(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(split="foo")

def test_invalid_mask_product(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(split="foo")
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

def test_add(self, dataset: SSL4EOLBenchmark) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)

def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None:
SSL4EOLBenchmark(
root=dataset.root,
input_sensor=dataset.input_sensor,
mask_product=dataset.mask_product,
download=True,
)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "*.tar.gz")
root = str(tmp_path)
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
SSL4EOLBenchmark(root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SSL4EOLBenchmark(str(tmp_path))

def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
sample = dataset[0]
dataset.plot(sample, suptitle="Test")
plt.close()
dataset.plot(sample, show_titles=False)
plt.close()
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample)
plt.close()
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
"vaihingen2d",
],
)
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule
from .ssl4eo_benchmark_landsat import SSL4EOLBenchmarkDataModule
from .sustainbench_crop_yield import SustainBenchCropYieldDataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
Expand Down Expand Up @@ -65,6 +66,7 @@
"SKIPPDDataModule",
"So2SatDataModule",
"SpaceNet1DataModule",
"SSL4EOLBenchmarkDataModule",
"SSL4EOLDataModule",
"SSL4EOS12DataModule",
"SustainBenchCropYieldDataModule",
Expand Down
Loading