Skip to content

Commit

Permalink
Add National Land Cover Database (NLCD) dataset (microsoft#1244)
Browse files Browse the repository at this point in the history
* working nlcd dataset version

* citation and correct ordinal color map

* add unit tests

* requested changes

* fix docs

* unnecessary space

* typos and for loop label conversion

* suggested plot changes

* use ListedColormap

* return fig statement

* docs about background class

* forgot print

* run pyupgrade

* found my bug
  • Loading branch information
nilsleh authored Apr 18, 2023
1 parent 92574e9 commit 84c6a0f
Show file tree
Hide file tree
Showing 10 changed files with 500 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ NAIP

.. autoclass:: NAIP

NLCD
^^^^

.. autoclass:: NLCD

Open Buildings
^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ Dataset,Type,Source,Size (px),Resolution (m)
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"8,900x8,900",30
`NAIP`_,Imagery,Aerial,"6,100x7,600",1
`NLCD`_,Masks,Landsat,-,30
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,-
`Sentinel`_,Imagery,Sentinel,"10,000x10,000",10
87 changes: 87 additions & 0 deletions tests/data/nlcd/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine

# Side length, in pixels, of the square test raster (used for both height and width).
SIZE = 32

# Fixed seed so that regenerating the test data yields the same random values.
np.random.seed(0)

# Name template shared by the per-year directory, raster file and zip archive;
# "{}" is filled in with the year.
# NOTE(review): this name shadows the builtin ``dir`` within this module.
dir = "nlcd_{}_land_cover_l48_20210604"

# Years for which a fake archive is generated.
years = [2011, 2019]

# WKT definition of the Albers Conical Equal Area CRS assigned to the generated rasters.
wkt = """
PROJCS["Albers Conical Equal Area",
GEOGCS["WGS 84",
DATUM["WGS_1984",
SPHEROID["WGS 84",6378137,298.257223563,
AUTHORITY["EPSG","7030"]],
AUTHORITY["EPSG","6326"]],
PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],
UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],
AUTHORITY["EPSG","4326"]],
PROJECTION["Albers_Conic_Equal_Area"],
PARAMETER["latitude_of_center",23],
PARAMETER["longitude_of_center",-96],
PARAMETER["standard_parallel_1",29.5],
PARAMETER["standard_parallel_2",45.5],
PARAMETER["false_easting",0],
PARAMETER["false_northing",0],
UNIT["meters",1],
AXIS["Easting",EAST],
AXIS["Northing",NORTH]]
"""


def create_file(path: str, dtype: str):
    """Write a single-band raster of random land cover codes to ``path``.

    The raster is ``SIZE`` x ``SIZE`` pixels, LZW-compressed, and georeferenced
    with the module-level ``wkt`` CRS.

    Args:
        path: destination of the raster file
        dtype: data type string stored in the raster profile
    """
    # Pixel size of 30 CRS units, with a negative y step (rows run north to south).
    geo_transform = Affine(30.0, 0.0, -2493045.0, 0.0, -30.0, 3310005.0)
    raster_profile = dict(
        driver="GTiff",
        dtype=dtype,
        count=1,
        crs=CRS.from_wkt(wkt),
        transform=geo_transform,
        height=SIZE,
        width=SIZE,
        compress="lzw",
        predictor=2,
    )

    # Codes the fake mask is sampled from; every pixel is drawn independently.
    class_codes = [0, 11, 12, 21, 22, 23, 24, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
    mask = np.random.choice(class_codes, size=(SIZE, SIZE))

    with rasterio.open(path, "w", **raster_profile) as dst:
        dst.write(mask, 1)


if __name__ == "__main__":
    # For every year: rebuild the data directory, write the fake raster,
    # zip the directory and print the archive's MD5 checksum.
    for year in years:
        stem = dir.format(year)
        archive_name = f"{stem}.zip"
        raster_path = os.path.join(stem, f"{stem}.img")

        # Discard output left over from a previous run before regenerating it.
        if os.path.isdir(stem):
            shutil.rmtree(stem)
        target_dir = os.path.join(os.getcwd(), stem)
        os.makedirs(target_dir)

        create_file(raster_path, dtype="int8")

        # Zip the directory so the archive contains "<stem>/<stem>.img".
        shutil.make_archive(stem, "zip", ".", stem)

        # Report the checksum of the freshly written archive.
        with open(archive_name, "rb") as archive:
            digest = hashlib.md5(archive.read()).hexdigest()
        print(f"{archive_name}: {digest}")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
103 changes: 103 additions & 0 deletions tests/datasets/test_nlcd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Copy the local file at ``url`` into ``root`` instead of downloading it.

    Any additional positional or keyword arguments are accepted and ignored.

    Args:
        url: path of the local file to copy
        root: destination directory (or file path) for the copy
    """
    source, destination = url, root
    shutil.copy(source, destination)


class TestNLCD:
    """Tests for the NLCD dataset, run against the archives in tests/data/nlcd."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD:
        """Return an NLCD dataset for 2011 and 2019 rooted in a temporary directory.

        The downloader is replaced by the local-copy ``download_url`` above, and
        the class-level ``url`` and ``md5s`` are pointed at the test archives.
        ``plt.show`` is replaced by a no-op.
        """
        monkeypatch.setattr(torchgeo.datasets.nlcd, "download_url", download_url)

        # Checksums expected for the test archives, keyed by year.
        archive_md5s = {
            2011: "99546a3b89a0dddbe4e28e661c79984e",
            2019: "a4008746f15720b8908ddd357a75fded",
        }
        monkeypatch.setattr(NLCD, "md5s", archive_md5s)

        archive_template = os.path.join(
            "tests", "data", "nlcd", "nlcd_{}_land_cover_l48_20210604.zip"
        )
        monkeypatch.setattr(NLCD, "url", archive_template)
        monkeypatch.setattr(plt, "show", lambda *args: None)

        return NLCD(
            str(tmp_path),
            transforms=nn.Identity(),
            download=True,
            checksum=True,
            years=[2011, 2019],
        )

    def test_getitem(self, dataset: NLCD) -> None:
        """Querying the full bounds yields a dict holding a CRS and a mask tensor."""
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_and(self, dataset: NLCD) -> None:
        """``&`` of two datasets produces an IntersectionDataset."""
        combined = dataset & dataset
        assert isinstance(combined, IntersectionDataset)

    def test_or(self, dataset: NLCD) -> None:
        """``|`` of two datasets produces a UnionDataset."""
        combined = dataset | dataset
        assert isinstance(combined, UnionDataset)

    def test_already_extracted(self, dataset: NLCD) -> None:
        """Constructing again on the fixture's root with download=True succeeds."""
        NLCD(root=dataset.root, download=True, years=[2019])

    def test_already_downloaded(self, tmp_path: Path) -> None:
        """An archive already present in root is usable without ``download``."""
        archive = os.path.join(
            "tests", "data", "nlcd", "nlcd_2019_land_cover_l48_20210604.zip"
        )
        target_root = str(tmp_path)
        shutil.copy(archive, target_root)
        NLCD(target_root, years=[2019])

    def test_invalid_year(self, tmp_path: Path) -> None:
        """Requesting an unsupported year raises an AssertionError."""
        expected = "NLCD data product only exists for the following years:"
        with pytest.raises(AssertionError, match=expected):
            NLCD(str(tmp_path), years=[1996])

    def test_plot(self, dataset: NLCD) -> None:
        """Plotting a sample with a suptitle runs without error."""
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: NLCD) -> None:
        """Plotting also works when the sample carries a "prediction" entry."""
        sample = dataset[dataset.bounds]
        # Reuse a copy of the mask as a stand-in prediction.
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download`` raises a RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            NLCD(str(tmp_path))

    def test_invalid_query(self, dataset: NLCD) -> None:
        """Indexing with an all-zero bounding box raises an IndexError."""
        empty_box = BoundingBox(0, 0, 0, 0, 0, 0)
        expected = "query: .* not found in index with bounds:"
        with pytest.raises(IndexError, match=expected):
            dataset[empty_box]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
from .nlcd import NLCD
from .openbuildings import OpenBuildings
from .oscd import OSCD
from .patternnet import PatternNet
Expand Down Expand Up @@ -156,6 +157,7 @@
"Landsat8",
"Landsat9",
"NAIP",
"NLCD",
"OpenBuildings",
"Sentinel",
"Sentinel1",
Expand Down
Loading

0 comments on commit 84c6a0f

Please sign in to comment.