From 84c6a0f5e6ee0515767f00b3be033ed31185a471 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Tue, 18 Apr 2023 20:35:19 +0200 Subject: [PATCH] Add National Land Cover Database (NLCD) dataset (#1244) * working nlcd dataset version * citation and correct ordinal color map * add unit tests * requested changes * fix docs * unnecessary space * typos and for loop label conversion * suggested plot changes * use ListedColormap * return fig statement * docs about background class * forgot print * run pyupgrade * found my bug --- docs/api/datasets.rst | 5 + docs/api/geo_datasets.csv | 1 + tests/data/nlcd/data.py | 87 +++++ .../nlcd_2011_land_cover_l48_20210604.zip | Bin 0 -> 1813 bytes .../nlcd_2011_land_cover_l48_20210604.img | Bin 0 -> 1665 bytes .../nlcd_2019_land_cover_l48_20210604.zip | Bin 0 -> 1811 bytes .../nlcd_2019_land_cover_l48_20210604.img | Bin 0 -> 1675 bytes tests/datasets/test_nlcd.py | 103 ++++++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/nlcd.py | 302 ++++++++++++++++++ 10 files changed, 500 insertions(+) create mode 100644 tests/data/nlcd/data.py create mode 100644 tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip create mode 100644 tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img create mode 100644 tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip create mode 100644 tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img create mode 100644 tests/datasets/test_nlcd.py create mode 100644 torchgeo/datasets/nlcd.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 0dddd413564..775bbc2cd4e 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -129,6 +129,11 @@ NAIP .. autoclass:: NAIP +NLCD +^^^^ + +.. autoclass:: NLCD + Open Buildings ^^^^^^^^^^^^^^ diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index 628933eec70..54cf53b9f27 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -17,5 +17,6 @@ Dataset,Type,Source,Size (px),Resolution (m) `LandCover.ai Geo`_,"Imagery, Masks",Aerial,"4,200--9,500",0.25--0.5 `Landsat`_,Imagery,Landsat,"8,900x8,900",30 `NAIP`_,Imagery,Aerial,"6,100x7,600",1 +`NLCD`_,Masks,Landsat,-,30 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,- `Sentinel`_,Imagery,Sentinel,"10,000x10,000",10 diff --git a/tests/data/nlcd/data.py b/tests/data/nlcd/data.py new file mode 100644 index 00000000000..fa1c592ea29 --- /dev/null +++ b/tests/data/nlcd/data.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +SIZE = 32 + +np.random.seed(0) + +dir = "nlcd_{}_land_cover_l48_20210604" + +years = [2011, 2019] + +wkt = """ +PROJCS["Albers Conical Equal Area", + GEOGCS["WGS 84", + DATUM["WGS_1984", + SPHEROID["WGS 84",6378137,298.257223563, + AUTHORITY["EPSG","7030"]], + AUTHORITY["EPSG","6326"]], + PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]], + UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]], + AUTHORITY["EPSG","4326"]], + PROJECTION["Albers_Conic_Equal_Area"], + PARAMETER["latitude_of_center",23], + PARAMETER["longitude_of_center",-96], + PARAMETER["standard_parallel_1",29.5], + PARAMETER["standard_parallel_2",45.5], + PARAMETER["false_easting",0], + PARAMETER["false_northing",0], + UNIT["meters",1], + AXIS["Easting",EAST], + AXIS["Northing",NORTH]] +""" + + +def create_file(path: str, dtype: str): + """Create the testing file.""" + profile = { + "driver": "GTiff", + "dtype": dtype, + "count": 1, + "crs": CRS.from_wkt(wkt), + "transform": Affine(30.0, 0.0, -2493045.0, 0.0, -30.0, 3310005.0), + "height": SIZE, + "width": SIZE, + "compress": "lzw", + "predictor": 2, + } + + allowed_values = [0, 11, 12, 21, 22, 23, 24, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95] + + Z = np.random.choice(allowed_values, size=(SIZE, SIZE)) + + with rasterio.open(path, "w", **profile) as src: + src.write(Z, 1) + + +if __name__ == "__main__": + for year in years: + year_dir = dir.format(year) + # Remove old data + if os.path.isdir(year_dir): + shutil.rmtree(year_dir) + + os.makedirs(os.path.join(os.getcwd(), year_dir)) + + zip_filename = year_dir + ".zip" + filename = year_dir + ".img" + create_file(os.path.join(year_dir, filename), dtype="int8") + + # Compress data + shutil.make_archive(year_dir, "zip", ".", year_dir) + + # Compute checksums + with open(zip_filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{zip_filename}: {md5}") diff --git a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip new file mode 100644 index 0000000000000000000000000000000000000000..6e5ad22538fe83687804ad0924b0d7cd045bf69e GIT binary patch literal 1813 zcma)-dpHw%7{|96&80(b$5I?cNMqJq!Z_wSD|a%BVK~$%mzZlA$)Q{#rwJ{J=`;#Q z%e`!4v{r86pn-dHpqkc*&Njd2ClNx6#E_uS!UxrON^Ej3UrrSwQ-a^E{&H-~F9=Kk%iZVS{ zhJOAB=vg#A=wR(-cMkjh@Mw}j;elt3*I_>#OssQNi$F4@?y{icZ2E4dok+FCW?6^7 zihB96a#4`R*CfQgRXia9_zr*uX2qN9hh{hhq+2`0B<@8)4-;N|xBCNXlCTs;GJ_I8#J zthF5L^fByzu(7i~VP)NMojbUi#+~IM>YDY8y5co8x)S;MQFk54XfE)rK8(j_kac5} z4AU7PNdC#%r;sBxio^0cxu+?>bR^`CpC4q8kjpuAsvlAFv04&*30~6EmZO*lyieVd z(|O~~g-~9uClaq{+T&3l6e7V)e(-XerFvUa@6rf2N4IKGluzv&0WU(zBn!a;YYO)M z(`K<*nv;X0w5Ql>okY$*I*$s873%^*@)Y>5&XctiR??D?^OhTNl?`~F?Gu&@KFrSd zXnEqsJxz*{AMEf^GvXRZK0|{rMO(c+L$CL%e6IF0Y{6spfC2k2cLg0_vDK<+U{?A7 z;>2Jcv!4hJx8v&xtmMhMr*207UUiYgb~^%O8uc9Y>*+T^*yYMa*nl zE4}5=E$l?o!c3Mb3Noso!UocQw}9zUxtM4aL~kE;YozPGQId`jGX*en-t(}&QmwFC z@W|iv;Z_Y_fE@+7cRpf#c||B~*Qs@W`5J{{{a!SMQYmDKiTU^!?<^LFmkdW(O_AF( zEtGE#s}I6Cho&!`X?^(4QHL!zKqBBSAlhL46&`wFys)*DRcqPg z7GKlr@w?@czL$63B?I>%m8<7ex!PsALMy=)4>LHC=po`&ut%~RxqddTcW4sfSv`lV zui`l6x2m_6ULLfiwyD2o!@=Y9(#KaMV%ZUy={%^|=lyq1Xn9p1&11JW{w$cr`3!n2 zfEU%}RxQ&R>TR65+vbvT^HYS<7SL>Pj1=0Zy67 z6dBWZS?Xo(&(A9ie%+zd{O*n#jQ!{FQAXQ2?Zj22xkmAmJ3XS#)Y8bnXSUcp)fus2 z7@@%^M@dIZga6vsL0o;K8LMzSC1EUgUD6#31Sw1WH#@#yL;>*8_EY!GJ-{sEr^QmqH&f zrqfL*<6&WHalQ&8IaYptq>Im+(R>L3rCFl3~-gCYu_x$cn zerah-zyko-00;>XLSaN83pDE1P8Z3oDTJoFwX=tfkX^k&D1hIByuVvxglDnaQ%D8~ zTHr}=YoZl4Ezqo5CXk5SEh3b4OrR3{lZcBN5~LFhC0=aQFkRDMCUlPmg-{~@+&R$e zwN~^}zT)e`{|P*vl4K?d_$^VI_>}>1oy-@2nA@p|W4MTtdo+Pg6Y*gXc;JJ0^k`@h z2vpbeVFBI4s6l`R3>S&%nNHt}YSc}Viwq#U7$u-GoqhnCiwJnQNCLWxD0sR^28N5y zKR_%zLFA(8{g;GQZc0q9{DMupdw&z}WVy30I=xvU7MBzhmhZ!o_U_rSEx!ayE;~T% z;_|}$st-~&VZyhn0B2iv%DREWP=M;8NFG!)#ES#L7Bdqe^ZTTj@n1>yOkf-z2O}9} zb1$O!Vt`1-0eqllvMf!#q8ycqB4*i&6%6tarke13#FCQl>@j|KGWto>`40PR%xau* z1kD9Vg-SpEB|(1X`kVgaQQab>=C{JR$UNoNd-&AqF93vQ0hZGr{emwX`1WY-Pl4yW z-?o@St_c;^>c-rRMBl-Kzl2_}n6GFzlX*p6d_A_lQT=P#QC+Dl&pwt=*{9UAR|?@F zKUM@rS=w!^1pSd|U}u$Z6vwhnhQ4vqc`k5I{$brRLJLe#d8v(Obz{r>czYk~Pf4Q= zF^}I*S=5~!QSh!=)R`1;BZ`Z^k@4njc}Bj{;5(Iv=yh`NzxFhq!nFbGd4L6Z!fD63 zEJZ`{=li)jJb$Vgry8Rt(P2_f7*3n$S|p6XHgh8!A(m=Ygva*c4nz0aLnmxL=*Jzl z=|M_+Vtub-;8FfVL5sD=!l*U$%v7PG966`m)?%sjOfHS7(BN6UieN{71G%{-Yn+EO zv`bjZ!q2!o#4ypZOC?Kgr|G%XIHP2&z%zVcCX8RSbPcE+SjMjJ!(-I)c2Uo0`Ux9-4p}wa4N(F5?ll9N zPo}tnpp5HML0hrR_Ygfrtmh%aN818)G9#3(3F*wa+a|vAWiw#*!7>C-jH93$BliGR zl(qE;UOW}Mk))PI^I9Vo0loJMj#kvsc0LA@dT6k^-86lHQ#ZoC^&qIYY%Z)NrBuBY zr=|V?oDhjnTI(1R!R2r*3*&^2%7HFIniTa^sqG_L@`!&}lP$HVXn=P6Z*MuIasn8W z=6RmZG(EFR>3B9nrBg;i6?#J^nSK-4CV{RisYyXH!LpEM3%io8)+1kgbGrf_mEqL> zNQ{e2q#JI~WGDn*i&+LK_oDO)zG*$m(&>HbJ|z8xL8yPm#Jr&9^qC9baYfj*O*?6< z%tK@%MHN?;301WkMmZ9`TT3QW6{#FQ9Zd#)!8S?y&tzLtw4!}qg`&DaReTI+YHcV` zsMpWtqW7#weQV0oozfFsS&5X6PJQ|`1+wRhR5W=wy;U@Y7Vx`H1!mU|&03ihTT54I zt)XHnA%NY>?}S;w=+!uAkuV%Kw9$N1JS_3_K))G<5G%zz{(Dfq5N*`2IitEt+hI@c zBgZx*m~HONU}aVY(f76$&LK> literal 0 HcmV?d00001 diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip new file mode 100644 index 0000000000000000000000000000000000000000..0021cc869619a4534378105e652fe77210afce80 GIT binary patch literal 1811 zcmWIWW@Zs#00G^pePLh*lu!cFc{#}`@kR!Qmhm}>c`5P9`DLj^@i`_IKrtgj12Y2? z{Q#(b4zPZeu!XxPvobJrvoSEZ6Q!S^4SJcm>9rH8eI|!X9J~KH_EwOD*peu&qeT;M z3wSRQ47kD0Gog3GBi@p}Ev^f+{8j|8n!RAnaCD6kP+fR~k0oi2*t0#U#_~K|7tc2C z;wgwc;NtQuBGfG3B|K~UhQAvp+uVO|`M&bp$=`LFnhOswF_<&7FmJM4=$p}aT&B*h z)!HUg`0&H$Pun=G_Zl+1Yx*^b^}SAM$~U2Hcba^94&PpP=ost0(AP{mdXvpEwPG6* zA1BW~_p-oa>&5e8Ca+`O)W&Z7sVB=;XZAnoUt@ik&e!~Z?}Ju2kp_4snz)~}zs z?cVqJtT{Q4S?}TkJ70Urt!gL2ukXvhb3^jqt{J|qmyao3Y<<+0d|xeA%;)~ye(Tux z_YB^-354&o6*^~r_O)G9%hU7|!R0&h9{yT#O2|i7Q*rh0aGBWC!C!tKt>Sz=|H_PX zXQdK(o4=E|817ZR(*1b4I(U6}-QAMrGC3i~on2lA$G#ukRP3px@bUW22UmM9epKsn z?Nim(PY`jxT9&uzxB4`lPtyYSa5ZP1h|SEHzC|}nEN$BqQ*ULy z4bdS&-|u*cY-IlMTfQr1+S7=G`#q(u{+%@UN%ib#%LkciSAB}#hLx@RV%B*jaO*nV z$1g%RmIWzkU3|Ic%pHYUQ{7ga=*(mbyQ`_S-qXZ**00YIOB>F0D~Z}!Jr+t&%RBEJ zWRS3gquIskY2&1}{QoOgzuP$N+A~@1;HKRrEUZVSYD{Fbj_MU?^K;&z60Es9vupOW zyO)mS++G*fX(%4HXrHU!nYg0+UPim)KC*p&64^Stx4&m;_~iH%Uc6EpYd!{Xp5Dtd zzhj1?RBf=%+#5eaO5_%8i#>U)bh>}F#4+_M$IGv8GwI4c3y|9~Sv$BfbIs{vO%@R+ z41@3ZY+P5hy!2N696pT|TtdG#&&uda*&U_!>E2S=ioCES*=o6OOmC{Uzj=OgQhb3g zyMBQ=Uts^~&__+b-rj0sPZm8%?$pyzQEEJ8a>NeQ1sLCQQ&5&!&m&AmX2P1ZV%aojd`-b+ph{g7%eh)%>n$303 zzLGc7$ewn3@kuYKvj_TR&v$$DUX2$qvb5egiq~q=zk$t zsIll=$JJ{AO=A5F7I%+aXovyh2bjcCswXI+f@7Pboia@GGwzxLQi%vK{B;De2$u4|dIV-3_L>A@ h9s|RcMhBwIBWNS2UJ3AK1y)oH3`{_H6i8cuc>sXC7cc+- literal 0 HcmV?d00001 diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img new file mode 100644 index 0000000000000000000000000000000000000000..231f010d997cafa851d4088d5c3731bf2e65ede9 GIT binary patch literal 1675 zcmZvZdrVVT9LIm>-rlwph6SmXLa{vLHN=u=1Qd3G0kRQRe1Xo9wg^Dm}QwUsWwr1^xN=k++G()6QWY|*#H-^TxmI+{r6p*PfHWJ=t-0CL>S9f0oJ?({Vr#28(AfjvZj7zC4C#O~6wKmbg~ zbDh5v+oDIn0%r&5?p^G;Hx<}CaR;4&;b4eA)5V?#U^$3@?H~?L4r0J@kO9sP+W!DO z)G?%kRMK0*nt5hYj{1Vdu&1a|*rv4alkJ-_^+4{Frt?+AQ>yKQ>()PJ>r}}Gy z^I~e%1SLs5Pw3*)hIRFdOFkcrOsvVC10JJ6bB204YD-Y8uKn1wu`~QlF^3GNOhw|0 z#SO=QNV+1NzPBHDy3xI1+>ODw#_N%Czi(dm*&-}ILy@|}bqmiYRi0pONeNlqzv9?h zzOvsh)=O0Xw8kNnYKbdb9Gorkj-WkQmC3o!I#n51>s zhFrm0CR#ic>XU&{&~?bL?xu4L^FSbf(+?p@Lc_AuZsp|4?hr^5RqNU^ zj$Rk*CZ*yMz23{h3!>O*3^g$^&BUrZ#pN}L7G_g_9=jCK1#zU(M{io<%*yR~L=|kDrCftRU?FNlP@ZL?<(<%*gD2L;XVj{9=~=5-XKJ~1S)p1@o+A2-$1(Z z&v}mc0F$b_eIn=yM~*s7hby|1J*|9!pPLCz-5tdIPsF0hpvKUl6DVyVGcx$aq+fY;b7$yj4JJ9agP}M9~L=%-1nt@+BUK9~plO+(U zTu#q&nAdv+@X)+cEDteeP6JdgVijNko zbFvP-Cb&aRb!}GcOJ9NVelzEYL=mHjOwP(S1X+gE;q2?<7 None: + shutil.copy(url, root) + + +class TestNLCD: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD: + monkeypatch.setattr(torchgeo.datasets.nlcd, "download_url", download_url) + + md5s = { + 2011: "99546a3b89a0dddbe4e28e661c79984e", + 2019: "a4008746f15720b8908ddd357a75fded", + } + monkeypatch.setattr(NLCD, "md5s", md5s) + + url = os.path.join( + "tests", "data", "nlcd", "nlcd_{}_land_cover_l48_20210604.zip" + ) + monkeypatch.setattr(NLCD, "url", url) + monkeypatch.setattr(plt, "show", lambda *args: None) + root = str(tmp_path) + transforms = nn.Identity() + return NLCD( + root, + transforms=transforms, + download=True, + checksum=True, + years=[2011, 2019], + ) + + def test_getitem(self, dataset: NLCD) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["mask"], torch.Tensor) + + def test_and(self, dataset: NLCD) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: NLCD) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_already_extracted(self, dataset: NLCD) -> None: + NLCD(root=dataset.root, download=True, years=[2019]) + + def test_already_downloaded(self, tmp_path: Path) -> None: + pathname = os.path.join( + "tests", "data", "nlcd", "nlcd_2019_land_cover_l48_20210604.zip" + ) + root = str(tmp_path) + shutil.copy(pathname, root) + NLCD(root, years=[2019]) + + def test_invalid_year(self, tmp_path: Path) -> None: + with pytest.raises( + AssertionError, + match="NLCD data product only exists for the following years:", + ): + NLCD(str(tmp_path), years=[1996]) + + def test_plot(self, dataset: NLCD) -> None: + query = dataset.bounds + x = dataset[query] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: NLCD) -> None: + query = dataset.bounds + x = dataset[query] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + NLCD(str(tmp_path)) + + def test_invalid_query(self, dataset: NLCD) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 27ae5f01d2e..334601853d5 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -74,6 +74,7 @@ from .millionaid import MillionAID from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris +from .nlcd import NLCD from .openbuildings import OpenBuildings from .oscd import OSCD from .patternnet import PatternNet @@ -156,6 +157,7 @@ "Landsat8", "Landsat9", "NAIP", + "NLCD", "OpenBuildings", "Sentinel", "Sentinel1", diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py new file mode 100644 index 00000000000..8926cd9f885 --- /dev/null +++ b/torchgeo/datasets/nlcd.py @@ -0,0 +1,302 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NLCD dataset.""" + +import os +from typing import Any, Callable, Optional + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap +from rasterio.crs import CRS + +from .geo import RasterDataset +from .utils import BoundingBox, download_url, extract_archive + + +class NLCD(RasterDataset): + """National Land Cover Database (NLCD) dataset. + + The `NLCD dataset + `_ + is a land cover product that covers the United States and Puerto Rico. The current + implementation supports maps for the continental United States only. The product is + a joint effort between the United States Geological Survey + (`USGS `_) and the Multi-Resolution Land Characteristics + Consortium (`MRLC `_) which released the first product + in 2001 with new updates every five years since then. + + The dataset contains the following 17 classes: + + 0. Background + #. Open Water + #. Perennial Ice/Snow + #. Developed, Open Space + #. Developed, Low Intensity + #. Developed, Medium Intensity + #. Developed, High Intensity + #. Barren Land (Rock/Sand/Clay) + #. Deciduous Forest + #. Evergreen Forest + #. Mixed Forest + #. Shrub/Scrub + #. Grassland/Herbaceous + #. Pasture/Hay + #. Cultivated Crops + #. Woody Wetlands + #. Emergent Herbaceous Wetlands + + Detailed descriptions of the classes can be found + `here `__. + + Dataset format: + + * single channel .img file with integer class labels + + If you use this dataset in your research, please use the corresponding citation: + + * 2001: https://doi.org/10.5066/P9MZGHLF + * 2006: https://doi.org/10.5066/P9HBR9V3 + * 2011: https://doi.org/10.5066/P97S2IID + * 2016: https://doi.org/10.5066/P96HHBIE + * 2019: https://doi.org/10.5066/P9KZCM54 + + .. versionadded:: 0.5 + """ # noqa: E501 + + filename_glob = "nlcd_*_land_cover_l48_20210604.img" + filename_regex = ( + r"nlcd_(?P\d{4})_land_cover_l48_(?P\d{8})\.img" + ) + zipfile_glob = "nlcd_*_land_cover_l48_20210604.zip" + date_format = "%Y" + is_image = False + + url = "https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip" + + md5s = { + 2001: "538166a4d783204764e3df3b221fc4cd", + 2006: "67454e7874a00294adb9442374d0c309", + 2011: "ea524c835d173658eeb6fa3c8e6b917b", + 2016: "452726f6e3bd3f70d8ca2476723d238a", + 2019: "82851c3f8105763b01c83b4a9e6f3961", + } + + ordinal_label_map = { + 0: 0, + 11: 1, + 12: 2, + 21: 3, + 22: 4, + 23: 5, + 24: 6, + 31: 7, + 41: 8, + 42: 9, + 43: 10, + 52: 11, + 71: 12, + 81: 13, + 82: 14, + 90: 15, + 95: 16, + } + + cmap = { + 0: (0, 0, 0, 255), + 1: (70, 107, 159, 255), + 2: (209, 222, 248, 255), + 3: (222, 197, 197, 255), + 4: (217, 146, 130, 255), + 5: (235, 0, 0, 255), + 6: (171, 0, 0, 255), + 7: (179, 172, 159, 255), + 8: (104, 171, 95, 255), + 9: (28, 95, 44, 255), + 10: (181, 197, 143, 255), + 11: (204, 184, 121, 255), + 12: (223, 223, 194, 255), + 13: (220, 217, 57, 255), + 14: (171, 108, 40, 255), + 15: (184, 217, 235, 255), + 16: (108, 159, 184, 255), + } + + def __init__( + self, + root: str = "data", + crs: Optional[CRS] = None, + res: Optional[float] = None, + years: list[int] = [2019], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + cache: bool = True, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + years: list of years for which to use nlcd layer + transforms: a function/transform that takes an input sample + and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 after downloading files (may be slow) + + Raises: + FileNotFoundError: if no files are found in ``root`` + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + AssertionError: if ``year`` is invalid + """ + assert set(years).issubset(self.md5s.keys()), ( + "NLCD data product only exists for the following years: " + f"{list(self.md5s.keys())}." + ) + self.years = years + self.root = root + self.download = download + self.checksum = checksum + + self._verify() + + super().__init__(root, crs, res, transforms=transforms, cache=cache) + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Retrieve mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + sample = super().__getitem__(query) + + mask = sample["mask"] + for k, v in self.ordinal_label_map.items(): + mask[mask == k] = v + + sample["mask"] = mask + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + # Check if the extracted files already exist + exists = [] + for year in self.years: + filename_year = self.filename_glob.replace("*", str(year)) + dirname_year = filename_year.split(".")[0] + pathname = os.path.join(self.root, dirname_year, filename_year) + if os.path.exists(pathname): + exists.append(True) + else: + exists.append(False) + + if all(exists): + return + + # Check if the zip files have already been downloaded + exists = [] + for year in self.years: + pathname = os.path.join( + self.root, self.zipfile_glob.replace("*", str(year)) + ) + if os.path.exists(pathname): + exists.append(True) + self._extract() + else: + exists.append(False) + + if all(exists): + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automatically download the dataset." + ) + + # Download the dataset + self._download() + self._extract() + + def _download(self) -> None: + """Download the dataset.""" + for year in self.years: + download_url( + self.url.format(year), + self.root, + md5=self.md5s[year] if self.checksum else None, + ) + + def _extract(self) -> None: + """Extract the dataset.""" + for year in self.years: + zipfile_name = self.zipfile_glob.replace("*", str(year)) + pathname = os.path.join(self.root, zipfile_name) + extract_archive(pathname, self.root) + + def plot( + self, + sample: dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + mask = sample["mask"].squeeze().numpy() + ncols = 1 + + plt_cmap = ListedColormap( + np.stack([np.array(val) / 255 for val in self.cmap.values()], axis=0) + ) + + showing_predictions = "prediction" in sample + if showing_predictions: + pred = sample["prediction"].squeeze().numpy() + ncols = 2 + + fig, axs = plt.subplots( + nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False + ) + + axs[0, 0].imshow(mask, cmap=plt_cmap) + axs[0, 0].axis("off") + + if show_titles: + axs[0, 0].set_title("Mask") + + if showing_predictions: + axs[0, 1].imshow(pred, cmap=plt_cmap) + axs[0, 1].axis("off") + if show_titles: + axs[0, 1].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig