From 0620f6baeb42cf806f5eab91fd575a5a13a45fd8 Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 20 Jan 2022 06:49:52 +0000 Subject: [PATCH 01/65] Adding dataset from MOSAIKS paper --- torchgeo/datasets/__init__.py | 2 ++ torchgeo/datasets/mosaiks.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 torchgeo/datasets/mosaiks.py diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 4d9ddbcf97f..458cd6d5d47 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -56,6 +56,7 @@ ) from .levircd import LEVIRCDPlus from .loveda import LoveDA +from .mosaiks import MOSAIKS from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris from .nwpu import VHR10 @@ -127,6 +128,7 @@ "LandCoverAI", "LEVIRCDPlus", "LoveDA", + "MOSAIKS", "NASAMarineDebris", "OSCD", "PatternNet", diff --git a/torchgeo/datasets/mosaiks.py b/torchgeo/datasets/mosaiks.py new file mode 100644 index 00000000000..81732c7be5d --- /dev/null +++ b/torchgeo/datasets/mosaiks.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from .utils import download_url + + +class MOSAIKS: + url_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + + label_urls = { + "housing": url_prefix + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", + "income": url_prefix + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", + "roads": url_prefix + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", + "nightligths": url_prefix + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", + "population": url_prefix + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", + "elevation": url_prefix + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", + "treecover": url_prefix + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", + } + + def __init__( + self, + root: str = "data", + ) -> None: + """Initialize a new MOSAIKS dataset instance. 
+ """ + + self.root = root + + self._verify() + + + def _verify(self) -> None: + self._download() + + def _download(self) -> None: + for f_name in self.label_urls: + download_url( + self.label_urls[f_name], + self.root, + filename=f_name + ".csv" + ) From 09b4faf797ae1927b1d12fbb720ee9cf451af370 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 21 Jan 2022 08:05:42 +0000 Subject: [PATCH 02/65] Name change --- torchgeo/datasets/__init__.py | 4 ++-- torchgeo/datasets/{mosaiks.py => usavars.py} | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) rename torchgeo/datasets/{mosaiks.py => usavars.py} (99%) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 458cd6d5d47..8e3b69a688f 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -56,7 +56,6 @@ ) from .levircd import LEVIRCDPlus from .loveda import LoveDA -from .mosaiks import MOSAIKS from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris from .nwpu import VHR10 @@ -70,6 +69,7 @@ from .so2sat import So2Sat from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 from .ucmerced import UCMerced +from .usavars import USAVars from .utils import ( BoundingBox, concat_samples, @@ -128,7 +128,6 @@ "LandCoverAI", "LEVIRCDPlus", "LoveDA", - "MOSAIKS", "NASAMarineDebris", "OSCD", "PatternNet", @@ -145,6 +144,7 @@ "SpaceNet7", "TropicalCycloneWindEstimation", "UCMerced", + "USAVars", "Vaihingen2D", "VHR10", "XView2", diff --git a/torchgeo/datasets/mosaiks.py b/torchgeo/datasets/usavars.py similarity index 99% rename from torchgeo/datasets/mosaiks.py rename to torchgeo/datasets/usavars.py index 81732c7be5d..69930927d7a 100644 --- a/torchgeo/datasets/mosaiks.py +++ b/torchgeo/datasets/usavars.py @@ -4,7 +4,7 @@ from .utils import download_url -class MOSAIKS: +class USAVars: url_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" label_urls = { From fa869cc76cb4f3a82601d62c1259a5b3e495598b Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 24 Jan 2022 02:59:37 +0000 Subject: [PATCH 03/65] implementing NAIPTileIndex in USAVars --- torchgeo/datasets/usavars.py | 75 +++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 69930927d7a..8557e592e7a 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -1,6 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os +import pickle +import rtree +import shapely + from .utils import download_url @@ -17,17 +22,24 @@ class USAVars: "treecover": url_prefix + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", } + NAIP_BLOB_ROOT = 'https://naipblobs.blob.core.windows.net/naip/' + NAIP_INDEX_BLOB_ROOT = "https://naipblobs.blob.core.windows.net/naip-index/rtree/" + INDEX_FNS = ["tile_index.dat", "tile_index.idx", "tiles.p"] + + def __init__( self, root: str = "data", ) -> None: - """Initialize a new MOSAIKS dataset instance. + """Initialize a new USAVars dataset instance. 
""" self.root = root self._verify() + self.tile_rtree = rtree.index.Index(self.root + "/tile_index") + self.tile_index = pickle.load(open(self.root + "/tiles.p", "rb")) def _verify(self) -> None: self._download() @@ -37,5 +49,64 @@ def _download(self) -> None: download_url( self.label_urls[f_name], self.root, - filename=f_name + ".csv" + filename=f_name + ".csv", ) + + for fn in self.INDEX_FNS: + download_url( + self.NAIP_INDEX_BLOB_ROOT + fn, + self.root, + filename=fn, + ) + + def lookup_point(self, lat, lon): + '''Given a lat/lon coordinate pair, return the list of NAIP tiles that *contain* that point. + + Args: + lat (float): Latitude in EPSG:4326 + lon (float): Longitude in EPSG:4326 + Returns: + intersected_files (list): A list of URLs of NAIP tiles that *contain* the given (`lat`, `lon`) point + + Raises: + IndexError: Raised if no tile within the index contains the given (`lat`, `lon`) point + ''' + + point = shapely.geometry.Point(float(lon), float(lat)) + geom = shapely.geometry.mapping(point) + + return self.lookup_geom(geom) + + def lookup_geom(self, geom): + '''Given a GeoJSON geometry, return the list of NAIP tiles that *contain* that feature. + + Args: + geom (dict): A GeoJSON geometry in EPSG:4326 + Returns: + intersected_files (list): A list of URLs of NAIP tiles that *contain* the given `geom` + + Raises: + IndexError: Raised if no tile within the index fully contains the given `geom` + ''' + shape = shapely.geometry.shape(geom) + intersected_indices = list(self.tile_rtree.intersection(shape.bounds)) + + intersected_files = [] + tile_intersection = False + + for idx in intersected_indices: + intersected_file = self.tile_index[idx][0] + intersected_geom = self.tile_index[idx][1] + if intersected_geom.contains(shape): + tile_intersection = True + intersected_files.append(self.NAIP_BLOB_ROOT + intersected_file) + + if not tile_intersection and len(intersected_indices) > 0: + raise IndexError("There are overlaps with tile index, but no tile contains the shape") + elif len(intersected_files) <= 0: + raise IndexError("No tile intersections") + else: + return intersected_files + + + From 501216b7da1df1cd389332eeed277ae694084338 Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 24 Jan 2022 03:53:28 +0000 Subject: [PATCH 04/65] lookup_point works --- torchgeo/datasets/usavars.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 8557e592e7a..1d37704d706 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -8,7 +8,6 @@ from .utils import download_url - class USAVars: url_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" @@ -56,7 +55,6 @@ def _download(self) -> None: download_url( self.NAIP_INDEX_BLOB_ROOT + fn, self.root, - filename=fn, ) def lookup_point(self, lat, lon): @@ -107,6 +105,3 @@ def lookup_geom(self, geom): raise IndexError("No tile intersections") else: return intersected_files - - - From 48f6c9b29de633ec9bee122f5b06ab89d2ea61bb Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 22 Feb 2022 05:29:52 +0000 Subject: [PATCH 05/65] usavars: adding extract + verify --- torchgeo/datasets/usavars.py | 126 +++++++++++++++-------------------- 1 file changed, 55 insertions(+), 71 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 1d37704d706..5e80682ed03 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -1,47 +1,76 @@ # Copyright (c) Microsoft 
Corporation. All rights reserved. # Licensed under the MIT License. +import glob import os import pickle -import rtree import shapely -from .utils import download_url +from .utils import ( + download_url, + extract_archive, +) class USAVars: - url_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + csv_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" - label_urls = { - "housing": url_prefix + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", - "income": url_prefix + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", - "roads": url_prefix + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", - "nightligths": url_prefix + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", - "population": url_prefix + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", - "elevation": url_prefix + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", - "treecover": url_prefix + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", - } + data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" + dirname = "usavars" + zipfile = dirname + ".zip" - NAIP_BLOB_ROOT = 'https://naipblobs.blob.core.windows.net/naip/' - NAIP_INDEX_BLOB_ROOT = "https://naipblobs.blob.core.windows.net/naip-index/rtree/" - INDEX_FNS = ["tile_index.dat", "tile_index.idx", "tiles.p"] + label_urls = { + "housing": csv_prefix + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", + "income": csv_prefix + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", + "roads": csv_prefix + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", + "nightligths": csv_prefix + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", + "population": csv_prefix + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", + "elevation": csv_prefix + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", + "treecover": csv_prefix + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", + } def __init__( self, root: str = "data", + download: bool = False, + checksum: bool = False, ) -> None: """Initialize a new USAVars dataset instance. """ self.root = root + self.download = download + self.checksum = checksum self._verify() - self.tile_rtree = rtree.index.Index(self.root + "/tile_index") - self.tile_index = pickle.load(open(self.root + "/tiles.p", "rb")) - def _verify(self) -> None: + """Verify the integrity of the dataset. + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + + # Check if the extracted files already exist + pathname = os.path.join(self.root, self.dirname) + if glob.glob(pathname): + return + + # Check if the zip files have already been downloaded + pathname = os.path.join(self.root, self.zipfile) + if glob.glob(pathname): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automaticaly download the dataset." 
+ ) + self._download() + self._extract() def _download(self) -> None: for f_name in self.label_urls: @@ -50,58 +79,13 @@ def _download(self) -> None: self.root, filename=f_name + ".csv", ) - - for fn in self.INDEX_FNS: - download_url( - self.NAIP_INDEX_BLOB_ROOT + fn, + download_url( + self.data_url, self.root, - ) + filename=self.zipfile + ) - def lookup_point(self, lat, lon): - '''Given a lat/lon coordinate pair, return the list of NAIP tiles that *contain* that point. - - Args: - lat (float): Latitude in EPSG:4326 - lon (float): Longitude in EPSG:4326 - Returns: - intersected_files (list): A list of URLs of NAIP tiles that *contain* the given (`lat`, `lon`) point - - Raises: - IndexError: Raised if no tile within the index contains the given (`lat`, `lon`) point - ''' - - point = shapely.geometry.Point(float(lon), float(lat)) - geom = shapely.geometry.mapping(point) - - return self.lookup_geom(geom) - - def lookup_geom(self, geom): - '''Given a GeoJSON geometry, return the list of NAIP tiles that *contain* that feature. - - Args: - geom (dict): A GeoJSON geometry in EPSG:4326 - Returns: - intersected_files (list): A list of URLs of NAIP tiles that *contain* the given `geom` - - Raises: - IndexError: Raised if no tile within the index fully contains the given `geom` - ''' - shape = shapely.geometry.shape(geom) - intersected_indices = list(self.tile_rtree.intersection(shape.bounds)) - - intersected_files = [] - tile_intersection = False - - for idx in intersected_indices: - intersected_file = self.tile_index[idx][0] - intersected_geom = self.tile_index[idx][1] - if intersected_geom.contains(shape): - tile_intersection = True - intersected_files.append(self.NAIP_BLOB_ROOT + intersected_file) - - if not tile_intersection and len(intersected_indices) > 0: - raise IndexError("There are overlaps with tile index, but no tile contains the shape") - elif len(intersected_files) <= 0: - raise IndexError("No tile intersections") - else: - return intersected_files + def _extract(self) -> None: + src = os.path.join(self.root, self.zipfile) + dst = os.path.join(self.root, self.dirname) + extract_archive(src, dst) From be354d7f3b2482819643ba874c08221336553742 Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 22 Feb 2022 05:37:43 +0000 Subject: [PATCH 06/65] USAVars: add md5 --- torchgeo/datasets/usavars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 5e80682ed03..b9846862690 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -18,6 +18,7 @@ class USAVars: dirname = "usavars" zipfile = dirname + ".zip" + md5 = "677e89fd20e5dd0fe4d29b61827c2456" label_urls = { "housing": csv_prefix + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", @@ -83,6 +84,7 @@ def _download(self) -> None: self.data_url, self.root, filename=self.zipfile + md5=self.md5 if self.checksum else None, ) def _extract(self) -> None: From 8b8ca7914fb2be6e755ad0a9ecede32155317c59 Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 22 Feb 2022 06:39:29 +0000 Subject: [PATCH 07/65] initial _load_files function --- torchgeo/datasets/usavars.py | 38 +++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b9846862690..f3f609cda13 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -3,8 +3,9 @@ import glob import os -import pickle -import shapely +import pandas as pd + +from typing import Any, 
Dict, List from .utils import ( download_url, @@ -45,6 +46,37 @@ def __init__( self._verify() + self.files = self._load_files() + + def __len__(self) -> int: + """Return the number of data points in the dataset. + Returns: + length of the dataset + """ + return len(self.files) + + def _load_files(self) -> List[Dict[str, Any]]: + file_path = os.path.join(self.root, self.dirname, "uar") + files = os.listdir(file_path) + + files = files[:10] # TODO: remove this, keeping temporarily because this func is very slow + + # csvs = self.label_urls.keys() # only uar for now + csvs = ["treecover", "elevation", "population"] + labels_ds = [(lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs] + samples = [] + for f in files: + img_path = os.path.join(file_path, f) + samp = {"image": img_path} + + id_ = f[5:-4] + + for lab, ds in labels_ds: + samp[lab] = ds[ds["ID"] == id_][lab].values[0] + + samples.append(samp) + return samples + def _verify(self) -> None: """Verify the integrity of the dataset. Raises: @@ -83,7 +115,7 @@ def _download(self) -> None: download_url( self.data_url, self.root, - filename=self.zipfile + filename=self.zipfile, md5=self.md5 if self.checksum else None, ) From b57b7ee54ac55374aad792526cbeed7ddbab96ee Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 22 Feb 2022 07:23:54 +0000 Subject: [PATCH 08/65] adding plotting --- torchgeo/datasets/usavars.py | 74 ++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index f3f609cda13..b356abf7d5a 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -3,16 +3,26 @@ import glob import os -import pandas as pd +from typing import Any, Dict, List, Optional, Union -from typing import Any, Dict, List +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from matplotlib.figure import Figure +from torch import Tensor +from .geo import VisionDataset from .utils import ( download_url, extract_archive, ) -class USAVars: +class USAVars(VisionDataset): + # TODO: complete this + """ + """ csv_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" @@ -48,8 +58,22 @@ def __init__( self.files = self._load_files() + def __getitem__(self, index: int) -> Dict[str, Union[Tensor, float]]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + sample = self.files[index] + sample["image"] = self._load_image(sample["image"]) + return sample + def __len__(self) -> int: """Return the number of data points in the dataset. + Returns: length of the dataset """ @@ -77,6 +101,12 @@ def _load_files(self) -> List[Dict[str, Any]]: samples.append(samp) return samples + def _load_image(self, path: str) -> Tensor: + with rasterio.open(path) as f: + array: "np.typing.NDArray[np.int_]" = f.read() + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + return tensor + def _verify(self) -> None: """Verify the integrity of the dataset. 
Raises: @@ -123,3 +153,41 @@ def _extract(self) -> None: src = os.path.join(self.root, self.zipfile) dst = os.path.join(self.root, self.dirname) extract_archive(src, dst) + + + def plot( + self, + sample: Dict[str, Tensor], + show_labels: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_labels: flag indicating whether to show labels above panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + + image = sample["image"][:3].numpy() # get RGB inds + image = np.moveaxis(image, 0, 2) + + fig, axs = plt.subplots(figsize=(10, 10)) + axs.imshow(image) + axs.axis("off") + + if show_labels: + labels = [(lab, val) for lab, val in sample.items() if lab != "image"] + label_string = "" + for l, v in labels: + label_string += f"{l}={v} " + + axs.set_title(label_string) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig From ec84d89b5f3f00308b0ec44ffe313be1edb03658 Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 22 Feb 2022 07:25:29 +0000 Subject: [PATCH 09/65] formatting --- torchgeo/datasets/usavars.py | 65 ++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b356abf7d5a..b3953765631 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -14,15 +14,12 @@ from torch import Tensor from .geo import VisionDataset -from .utils import ( - download_url, - extract_archive, -) +from .utils import download_url, extract_archive + class USAVars(VisionDataset): # TODO: complete this - """ - """ + """ """ csv_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" @@ -32,23 +29,26 @@ class USAVars(VisionDataset): md5 = "677e89fd20e5dd0fe4d29b61827c2456" label_urls = { - "housing": csv_prefix + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", - "income": csv_prefix + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", - "roads": csv_prefix + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", - "nightligths": csv_prefix + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", - "population": csv_prefix + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", - "elevation": csv_prefix + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", - "treecover": csv_prefix + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", + "housing": csv_prefix + + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", + "income": csv_prefix + + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", + "roads": csv_prefix + + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", + "nightligths": csv_prefix + + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", + "population": csv_prefix + + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", + "elevation": csv_prefix + + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", + "treecover": csv_prefix + + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", } def __init__( - self, 
- root: str = "data", - download: bool = False, - checksum: bool = False, + self, root: str = "data", download: bool = False, checksum: bool = False ) -> None: - """Initialize a new USAVars dataset instance. - """ + """Initialize a new USAVars dataset instance.""" self.root = root self.download = download @@ -83,11 +83,15 @@ def _load_files(self) -> List[Dict[str, Any]]: file_path = os.path.join(self.root, self.dirname, "uar") files = os.listdir(file_path) - files = files[:10] # TODO: remove this, keeping temporarily because this func is very slow + files = files[ + :10 + ] # TODO: remove this, keeping temporarily because this func is very slow # csvs = self.label_urls.keys() # only uar for now csvs = ["treecover", "elevation", "population"] - labels_ds = [(lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs] + labels_ds = [ + (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs + ] samples = [] for f in files: img_path = os.path.join(file_path, f) @@ -137,16 +141,12 @@ def _verify(self) -> None: def _download(self) -> None: for f_name in self.label_urls: - download_url( - self.label_urls[f_name], - self.root, - filename=f_name + ".csv", - ) + download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv") download_url( - self.data_url, - self.root, - filename=self.zipfile, - md5=self.md5 if self.checksum else None, + self.data_url, + self.root, + filename=self.zipfile, + md5=self.md5 if self.checksum else None, ) def _extract(self) -> None: @@ -154,7 +154,6 @@ def _extract(self) -> None: dst = os.path.join(self.root, self.dirname) extract_archive(src, dst) - def plot( self, sample: Dict[str, Tensor], @@ -172,9 +171,9 @@ def plot( a matplotlib Figure with the rendered sample """ - image = sample["image"][:3].numpy() # get RGB inds + image = sample["image"][:3].numpy() # get RGB inds image = np.moveaxis(image, 0, 2) - + fig, axs = plt.subplots(figsize=(10, 10)) axs.imshow(image) axs.axis("off") From b78ee807668665a068cc6a8763526c167852e4fa Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:38:35 +0000 Subject: [PATCH 10/65] add description --- torchgeo/datasets/usavars.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b3953765631..7b8f42ed532 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -18,9 +18,25 @@ class USAVars(VisionDataset): - # TODO: complete this - """ """ - csv_prefix = "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + """USAVars dataset. + + Dataset format: + * images are 4-channel tifs + * labels are singular float values + + Dataset labels: + - tree cover + - elevation + - population density + - nighttime lights + - income per houshold + - road length + - housing price + + .. 
versionadded:: 0.3 + """ + csv_prefix = "https://files.codeocean.com/files/verified/" + + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" From 0311f4e329f3912de9b0d2fafa197883eb20144a Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:40:27 +0000 Subject: [PATCH 11/65] black fix --- torchgeo/datasets/usavars.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 7b8f42ed532..732fecf89f0 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -19,7 +19,7 @@ class USAVars(VisionDataset): """USAVars dataset. - + Dataset format: * images are 4-channel tifs * labels are singular float values @@ -32,11 +32,12 @@ class USAVars(VisionDataset): - income per houshold - road length - housing price - + .. versionadded:: 0.3 """ + csv_prefix = "https://files.codeocean.com/files/verified/" - + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + +"fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" From ebbf9d09448cc2021edc3cf56d0ee40b63e47c3d Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:47:22 +0000 Subject: [PATCH 12/65] flake8 fix --- torchgeo/datasets/usavars.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 732fecf89f0..edc504da3e4 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -38,6 +38,7 @@ class USAVars(VisionDataset): csv_prefix = "https://files.codeocean.com/files/verified/" +"fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + csv_postfix = "_CONTUS_16_640_POP_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" @@ -46,20 +47,17 @@ class USAVars(VisionDataset): md5 = "677e89fd20e5dd0fe4d29b61827c2456" label_urls = { - "housing": csv_prefix - + "housing/outcomes_sampled_housing_CONTUS_16_640_POP_100000_0.csv?download", - "income": csv_prefix - + "income/outcomes_sampled_income_CONTUS_16_640_POP_100000_0.csv?download", - "roads": csv_prefix - + "roads/outcomes_sampled_roads_CONTUS_16_640_POP_100000_0.csv?download", + "housing": csv_prefix + "housing/outcomes_sampled_housing" + csv_postfix, + "income": csv_prefix + "income/outcomes_sampled_income" + csv_postfix, + "roads": csv_prefix + "roads/outcomes_sampled_roads" + csv_postfix, "nightligths": csv_prefix - + "nightlights/outcomes_sampled_nightlights_CONTUS_16_640_POP_100000_0.csv?download", + + "nightlights/outcomes_sampled_nightlights" + + csv_postfix, "population": csv_prefix - + "population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download", - "elevation": csv_prefix - + "elevation/outcomes_sampled_elevation_CONTUS_16_640_UAR_100000_0.csv?download", - "treecover": csv_prefix - + "treecover/outcomes_sampled_treecover_CONTUS_16_640_UAR_100000_0.csv?download", + + "population/outcomes_sampled_population" + + csv_postfix, + "elevation": csv_prefix + "elevation/outcomes_sampled_elevation" + csv_postfix, + "treecover": csv_prefix + "treecover/outcomes_sampled_treecover" + csv_postfix, } def __init__( @@ -198,8 +196,8 @@ def plot( if show_labels: labels = [(lab, val) for lab, val in sample.items() if lab != "image"] label_string = "" - for l, v in 
labels: - label_string += f"{l}={v} " + for lab, val in labels: + label_string += f"{lab}={val} " axs.set_title(label_string) From b0d34c95cf98e7f43a84d0fc656db9bb1dba252d Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:54:24 +0000 Subject: [PATCH 13/65] pydocstyle fix --- torchgeo/datasets/usavars.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index edc504da3e4..4cd4c2c7b13 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +"""USAVars dataset.""" + import glob import os from typing import Any, Dict, List, Optional, Union @@ -64,7 +66,6 @@ def __init__( self, root: str = "data", download: bool = False, checksum: bool = False ) -> None: """Initialize a new USAVars dataset instance.""" - self.root = root self.download = download self.checksum = checksum @@ -128,10 +129,10 @@ def _load_image(self, path: str) -> Tensor: def _verify(self) -> None: """Verify the integrity of the dataset. + Raises: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ - # Check if the extracted files already exist pathname = os.path.join(self.root, self.dirname) if glob.glob(pathname): @@ -185,7 +186,6 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"][:3].numpy() # get RGB inds image = np.moveaxis(image, 0, 2) From 062c9f8978605ce6ae383894f74644e09e7e6ee8 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:56:45 +0000 Subject: [PATCH 14/65] mypy fix --- torchgeo/datasets/usavars.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 4cd4c2c7b13..82bbbdd172b 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -38,8 +38,8 @@ class USAVars(VisionDataset): .. versionadded:: 0.3 """ - csv_prefix = "https://files.codeocean.com/files/verified/" - +"fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + csv_prefix = ("https://files.codeocean.com/files/verified/" + + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/") csv_postfix = "_CONTUS_16_640_POP_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" From 7bffb9c4ac2f40d8073e89e9061115c1f0f6a054 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 16:59:40 +0000 Subject: [PATCH 15/65] add DS to docs --- docs/api/datasets.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 9d5da998add..abc8465c625 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -222,6 +222,11 @@ UC Merced .. autoclass:: UCMerced +USAVars +^^^^^^^ + +.. autoclass:: USAVars + Vaihingen ^^^^^^^^^ From 168f9890376717187fc1e1e62821860ecb0e0633 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 17:04:58 +0000 Subject: [PATCH 16/65] black fix --- torchgeo/datasets/usavars.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 82bbbdd172b..1696dc3c154 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -38,8 +38,10 @@ class USAVars(VisionDataset): .. 
versionadded:: 0.3 """ - csv_prefix = ("https://files.codeocean.com/files/verified/" + - "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/") + csv_prefix = ( + "https://files.codeocean.com/files/verified/" + + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + ) csv_postfix = "_CONTUS_16_640_POP_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" From 7e940e1ba2b438389930e4eccc6e9ba43feb156b Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 18:28:33 +0000 Subject: [PATCH 17/65] fake dataset --- tests/data/usavars/usavars.zip | Bin 0 -> 774 bytes tests/data/usavars/usavars/uar/tile_0,0.tif | Bin 0 -> 422 bytes tests/data/usavars/usavars/uar/tile_0,1.tif | Bin 0 -> 422 bytes 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/data/usavars/usavars.zip create mode 100644 tests/data/usavars/usavars/uar/tile_0,0.tif create mode 100644 tests/data/usavars/usavars/uar/tile_0,1.tif diff --git a/tests/data/usavars/usavars.zip b/tests/data/usavars/usavars.zip new file mode 100644 index 0000000000000000000000000000000000000000..457700ccf1998363e3f44cc9c76d8e74fa8893a8 GIT binary patch literal 774 zcmWIWW@Zs#0D-A};UQoKlwbkUrHMuQ0Z$S(cQ|3@?in>_TK(ZZ61CQxHo@Vca~s zp+zG^W{Eq`NtwbGQlbje9F5dDPR>bK5qrI1#`>A;Yu?M9xFaUVW~T6$x8wSDV`B}5 zOw$F+df)o+7&(+lonTOS(BR~z(ZayU9wKFKz{d1*yH}sXsvooSCMf&JB#Bwy2u_&1 z&_ww3v9iDkHph>ZnXv`3nek1im~X_)-w}2?h4J#NQ+0iMV#m0CaV=Eb#})dsawS)r zSeY7@L*&TVaA+s;atK>XEAa-l zRu;ZlSBT)Z4DT^JGdpY$)W{*CG?6e#z?ldU^9HH{^o&F;&P+s;atK>XEAa-l zRu;ZlSBT)Z4DT^JGdpY$)W{*CG?6e#z?ldU^9HH{^o&F;&P Date: Wed, 23 Feb 2022 19:03:32 +0000 Subject: [PATCH 18/65] add transforms arg --- torchgeo/datasets/usavars.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 1696dc3c154..b9859d8aeec 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -65,10 +65,14 @@ class USAVars(VisionDataset): } def __init__( - self, root: str = "data", download: bool = False, checksum: bool = False + self, root: str = "data", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False ) -> None: """Initialize a new USAVars dataset instance.""" self.root = root + self.transforms = transforms self.download = download self.checksum = checksum @@ -87,6 +91,10 @@ def __getitem__(self, index: int) -> Dict[str, Union[Tensor, float]]: """ sample = self.files[index] sample["image"] = self._load_image(sample["image"]) + + if self.transforms is not None: + sample = self.transforms(sample) + return sample def __len__(self) -> int: From 0a880b50583b2957ea85f8208530389d16440fd4 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 19:18:21 +0000 Subject: [PATCH 19/65] initial tests --- tests/datasets/test_usavars.py | 100 +++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/datasets/test_usavars.py diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py new file mode 100644 index 00000000000..52fad97e270 --- /dev/null +++ b/tests/datasets/test_usavars.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +from pathlib import Path +from typing import Generator + +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from matplotlib import pyplot as plt +from torch.utils.data import ConcatDataset + +from torchgeo.datasets import USAVars + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestUSAVars: + def dataset( + self, + monkeypatch: Generator[MonkeyPatch, None, None], + tmp_path: Path, + request: SubRequest, + ) -> USAVars: + + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.usavars, "download_url", download_url + ) + + md5 = "b504580a00bdc27097d5421dec50481b" + monkeypatch.setattr(USAVars, "md5", md5) # type: ignore[attr-defined] + + data_url = os.path.join("tests", "data", "usavars", "usavars.zip") + monkeypatch.setattr(USAVars, "data_url", data_url) # type: ignore[attr-defined] + + label_urls = { + "elevation": os.path.join( + "tests", + "data", + "usavars", + "elevation.csv", + ), + "population": os.path.join( + "tests", + "data", + "usavars", + "population.csv", + ), + "treecover": os.path.join( + "tests", + "data", + "usavars", + "treecover.csv", + ), + } + monkeypatch.setattr(USAVars, "label_urls", label_urls) # type: ignore[attr-defined] + + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + + return USAVars( + root, transforms=transformsdownload=True, checksum=True + ) + + def test_getitem(self, dataset: USAVars) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert x["image"].ndim == 4 + assert len(x.keys()) == 4 # image, elevation, population, treecover + assert x["image"].shape[0] == 4 # R, G, B, Inf + + def test_len(self, dataset: USAVars) -> None: + assert len(dataset) == 2 + + def test_add(self, dataset: USAVars) -> None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + + def test_already_extracted(self, dataset: USAVars) -> None: + USAVars(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + pathname = os.path.join("tests", "data", "usavars", "usavars.zip") + root = str(tmp_path) + shutil.copy(pathname, root) + USAVars(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + USAVars(str(tmp_path)) + + def test_plot(self, dataset: USAVars) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() From 61739b7a321efce70ae302920cc7fbb168be9f34 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 19:21:52 +0000 Subject: [PATCH 20/65] fix black flake8 isort --- tests/datasets/test_usavars.py | 39 ++++++++++++---------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 52fad97e270..d1104e2026f 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -14,8 +14,10 @@ from matplotlib import pyplot as plt from torch.utils.data import ConcatDataset +import torchgeo.datasets.utils from torchgeo.datasets import USAVars + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -39,41 +41,28 @@ def dataset( monkeypatch.setattr(USAVars, "data_url", data_url) # type: ignore[attr-defined] label_urls = { - "elevation": os.path.join( - "tests", - "data", - "usavars", - "elevation.csv", - ), - 
"population": os.path.join( - "tests", - "data", - "usavars", - "population.csv", - ), - "treecover": os.path.join( - "tests", - "data", - "usavars", - "treecover.csv", - ), + "elevation": os.path.join("tests", "data", "usavars", "elevation.csv"), + "population": os.path.join("tests", "data", "usavars", "population.csv"), + "treecover": os.path.join("tests", "data", "usavars", "treecover.csv"), } - monkeypatch.setattr(USAVars, "label_urls", label_urls) # type: ignore[attr-defined] + monkeypatch.setattr( # type: ignore[attr-defined] + USAVars, + "label_urls", + label_urls, + ) root = str(tmp_path) transforms = nn.Identity() # type: ignore[attr-defined] - return USAVars( - root, transforms=transformsdownload=True, checksum=True - ) + return USAVars(root, transforms=transforms, download=True, checksum=True) def test_getitem(self, dataset: USAVars) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert x["image"].ndim == 4 - assert len(x.keys()) == 4 # image, elevation, population, treecover - assert x["image"].shape[0] == 4 # R, G, B, Inf + assert len(x.keys()) == 4 # image, elevation, population, treecover + assert x["image"].shape[0] == 4 # R, G, B, Inf def test_len(self, dataset: USAVars) -> None: assert len(dataset) == 2 @@ -81,7 +70,7 @@ def test_len(self, dataset: USAVars) -> None: def test_add(self, dataset: USAVars) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - + def test_already_extracted(self, dataset: USAVars) -> None: USAVars(root=dataset.root, download=True) From b5e2512891a7c09dab5f8de77dd6ef7bc8a8acf4 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 19:25:23 +0000 Subject: [PATCH 21/65] fix black flake8 --- torchgeo/datasets/usavars.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b9859d8aeec..939ae591d2d 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -65,10 +65,11 @@ class USAVars(VisionDataset): } def __init__( - self, root: str = "data", + self, + root: str = "data", transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, download: bool = False, - checksum: bool = False + checksum: bool = False, ) -> None: """Initialize a new USAVars dataset instance.""" self.root = root From 74aaab48d94d4fd17c0cdca1e72cabeca99623fc Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 19:27:29 +0000 Subject: [PATCH 22/65] fix black --- tests/datasets/test_usavars.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index d1104e2026f..74d2f7afc80 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -46,9 +46,7 @@ def dataset( "treecover": os.path.join("tests", "data", "usavars", "treecover.csv"), } monkeypatch.setattr( # type: ignore[attr-defined] - USAVars, - "label_urls", - label_urls, + USAVars, "label_urls", label_urls ) root = str(tmp_path) From a47e2fb7965262d8655f96adb2a60c2656482944 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 19:51:30 +0000 Subject: [PATCH 23/65] fix mypy --- torchgeo/datasets/usavars.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py 
b/torchgeo/datasets/usavars.py index 939ae591d2d..e2ec872d47c 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -183,7 +183,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: Dict[str, Union[Tensor, float]], show_labels: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -197,7 +197,8 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"][:3].numpy() # get RGB inds + image: Tensor = torch.Tensor(sample["image"]) + image = image[:3].numpy() # get RGB inds image = np.moveaxis(image, 0, 2) fig, axs = plt.subplots(figsize=(10, 10)) From e081de864c136bca062ae4107b5457f6edfdf38b Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 20:15:41 +0000 Subject: [PATCH 24/65] test fixes --- tests/datasets/test_usavars.py | 6 +++++- torchgeo/datasets/usavars.py | 20 ++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 74d2f7afc80..3b53cb72124 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -23,6 +23,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestUSAVars: + @pytest.fixture() def dataset( self, monkeypatch: Generator[MonkeyPatch, None, None], @@ -58,7 +59,7 @@ def test_getitem(self, dataset: USAVars) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert x["image"].ndim == 4 + assert x["image"].ndim == 3 assert len(x.keys()) == 4 # image, elevation, population, treecover assert x["image"].shape[0] == 4 # R, G, B, Inf @@ -76,6 +77,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join("tests", "data", "usavars", "usavars.zip") root = str(tmp_path) shutil.copy(pathname, root) + for csv in ["elevation.csv", "population.csv", "treecover.csv"]: + shutil.copy(os.path.join("tests", "data", "usavars", csv), root) + USAVars(root) def test_not_downloaded(self, tmp_path: Path) -> None: diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index e2ec872d47c..9c6bcf5e711 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np @@ -81,7 +81,7 @@ def __init__( self.files = self._load_files() - def __getitem__(self, index: int) -> Dict[str, Union[Tensor, float]]: + def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. Args: @@ -91,12 +91,17 @@ def __getitem__(self, index: int) -> Dict[str, Union[Tensor, float]]: data and label at that index """ sample = self.files[index] - sample["image"] = self._load_image(sample["image"]) + tensor_sample = {} + tensor_sample["image"] = self._load_image(sample["image"]) + + keys = [key for key in sample.keys() if key != "image"] + for key in keys: + tensor_sample[key] = Tensor([sample[key]]) if self.transforms is not None: - sample = self.transforms(sample) + tensor_sample = self.transforms(tensor_sample) - return sample + return tensor_sample def __len__(self) -> int: """Return the number of data points in the dataset. 
@@ -183,7 +188,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Union[Tensor, float]], + sample: Dict[str, Tensor], show_labels: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -197,8 +202,7 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image: Tensor = torch.Tensor(sample["image"]) - image = image[:3].numpy() # get RGB inds + image = sample["image"][:3].numpy() # get RGB inds image = np.moveaxis(image, 0, 2) fig, axs = plt.subplots(figsize=(10, 10)) From 73a87f8cfcae563093ae0333449a1e8998861413 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 20:42:03 +0000 Subject: [PATCH 25/65] testing something --- tests/datasets/test_usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 3b53cb72124..26dea4a7298 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -74,7 +74,7 @@ def test_already_extracted(self, dataset: USAVars) -> None: USAVars(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "usavars", "usavars.zip") + pathname = os.path.join("tests", "data", "usavars", "savars.zip") root = str(tmp_path) shutil.copy(pathname, root) for csv in ["elevation.csv", "population.csv", "treecover.csv"]: From efb8360e7dd35fd02e84324dd8ac6353390be9e9 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 20:48:22 +0000 Subject: [PATCH 26/65] it finds zip but not csv --- tests/datasets/test_usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 26dea4a7298..3b53cb72124 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -74,7 +74,7 @@ def test_already_extracted(self, dataset: USAVars) -> None: USAVars(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "usavars", "savars.zip") + pathname = os.path.join("tests", "data", "usavars", "usavars.zip") root = str(tmp_path) shutil.copy(pathname, root) for csv in ["elevation.csv", "population.csv", "treecover.csv"]: From dea370395d03ab0ebf2382ef5cb158243b5ccac1 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 20:49:54 +0000 Subject: [PATCH 27/65] fake csv files didn't get added --- tests/data/usavars/elevation.csv | 3 +++ tests/data/usavars/population.csv | 3 +++ tests/data/usavars/treecover.csv | 3 +++ 3 files changed, 9 insertions(+) create mode 100644 tests/data/usavars/elevation.csv create mode 100644 tests/data/usavars/population.csv create mode 100644 tests/data/usavars/treecover.csv diff --git a/tests/data/usavars/elevation.csv b/tests/data/usavars/elevation.csv new file mode 100644 index 00000000000..4c0a6a31a19 --- /dev/null +++ b/tests/data/usavars/elevation.csv @@ -0,0 +1,3 @@ +,Unnamed: 0,ID,lon,lat,elevation +0,1,"0,0",0.0,0.0,0.0 +1,2,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/population.csv b/tests/data/usavars/population.csv new file mode 100644 index 00000000000..2443630f570 --- /dev/null +++ b/tests/data/usavars/population.csv @@ -0,0 +1,3 @@ +,Unnamed: 0,V1,V1.1,ID,lon,lat,population +0,1,1,1,"0,0",0.0,0.0,0.0 +1,2,2,2,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/treecover.csv b/tests/data/usavars/treecover.csv new file mode 100644 index 00000000000..9ffc19a091e --- /dev/null +++ b/tests/data/usavars/treecover.csv @@ -0,0 +1,3 @@ +,Unnamed: 
0,ID,lon,lat,treecover +0,1,"0,0",0.0,0.0,0.0 +1,2,"0,1",0.1,0.1,1.0 From 81afd9aef0a037d73c611487406a705811ee88e4 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 21:03:03 +0000 Subject: [PATCH 28/65] pandas docs fix --- torchgeo/datasets/usavars.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 9c6bcf5e711..53ec78f7f50 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -79,6 +79,13 @@ def __init__( self._verify() + try: + import pandas as pd # noqa: F401 + except ImportError: + raise ImportError( + "pandas is not installed and is required to use this dataset" + ) + self.files = self._load_files() def __getitem__(self, index: int) -> Dict[str, Tensor]: @@ -213,6 +220,8 @@ def plot( labels = [(lab, val) for lab, val in sample.items() if lab != "image"] label_string = "" for lab, val in labels: + print(val[0].item()) + label_string += f"{lab}={val} " axs.set_title(label_string) From bd75a008041eb4c89f740840ff723fee275ab40f Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 21:04:37 +0000 Subject: [PATCH 29/65] forgot to take out here --- torchgeo/datasets/usavars.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 53ec78f7f50..281b3006669 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -9,7 +9,6 @@ import matplotlib.pyplot as plt import numpy as np -import pandas as pd import rasterio import torch from matplotlib.figure import Figure From bd90b5e0b3cadfd80e2984685195886ed8c68e90 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 21:07:05 +0000 Subject: [PATCH 30/65] need to add in functions --- torchgeo/datasets/usavars.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 281b3006669..b0a0b33bc90 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -118,6 +118,8 @@ def __len__(self) -> int: return len(self.files) def _load_files(self) -> List[Dict[str, Any]]: + import pandas as pd + file_path = os.path.join(self.root, self.dirname, "uar") files = os.listdir(file_path) @@ -219,7 +221,7 @@ def plot( labels = [(lab, val) for lab, val in sample.items() if lab != "image"] label_string = "" for lab, val in labels: - print(val[0].item()) + print(round(val[0].item(), 2)) label_string += f"{lab}={val} " From d1deebfafd718697a7d6559896ca2038546e97e3 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 21:13:09 +0000 Subject: [PATCH 31/65] round plot labels --- torchgeo/datasets/usavars.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b0a0b33bc90..39754c3dc9e 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -221,10 +221,7 @@ def plot( labels = [(lab, val) for lab, val in sample.items() if lab != "image"] label_string = "" for lab, val in labels: - print(round(val[0].item(), 2)) - - label_string += f"{lab}={val} " - + label_string += f"{lab}={round(val[0].item(), 2)} " axs.set_title(label_string) if suptitle is not None: From 0e6202a99f45746f575d18fc818bb7fc0a89d369 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Wed, 23 Feb 2022 21:34:08 +0000 Subject: [PATCH 32/65] Small edits --- torchgeo/datasets/usavars.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git 
a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 39754c3dc9e..7b47f4513b2 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -37,30 +37,26 @@ class USAVars(VisionDataset): .. versionadded:: 0.3 """ - csv_prefix = ( + url_prefix = ( "https://files.codeocean.com/files/verified/" - + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/" + + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications" ) - csv_postfix = "_CONTUS_16_640_POP_100000_0.csv?download" + pop_csv_suffix = "CONTUS_16_640_POP_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" - zipfile = dirname + ".zip" + zipfile = "uar.zip" md5 = "677e89fd20e5dd0fe4d29b61827c2456" label_urls = { - "housing": csv_prefix + "housing/outcomes_sampled_housing" + csv_postfix, - "income": csv_prefix + "income/outcomes_sampled_income" + csv_postfix, - "roads": csv_prefix + "roads/outcomes_sampled_roads" + csv_postfix, - "nightligths": csv_prefix - + "nightlights/outcomes_sampled_nightlights" - + csv_postfix, - "population": csv_prefix - + "population/outcomes_sampled_population" - + csv_postfix, - "elevation": csv_prefix + "elevation/outcomes_sampled_elevation" + csv_postfix, - "treecover": csv_prefix + "treecover/outcomes_sampled_treecover" + csv_postfix, + "housing": f"{url_prefix}/housing/outcomes_sampled_housing_{pop_csv_suffix}", + "income": f"{url_prefix}/income/outcomes_sampled_income_{pop_csv_suffix}", + "roads": f"{url_prefix}/roads/outcomes_sampled_roads_{pop_csv_suffix}", + "nightligths": f"{url_prefix}/nightlights/outcomes_sampled_nightlights_{pop_csv_suffix}", + "population": f"{url_prefix}/population/outcomes_sampled_population_{pop_csv_suffix}", + "elevation": f"{url_prefix}/elevation/outcomes_sampled_elevation_{pop_csv_suffix}", + "treecover": f"{url_prefix}/treecover/outcomes_sampled_treecover_{pop_csv_suffix}", } def __init__( @@ -182,6 +178,7 @@ def _verify(self) -> None: def _download(self) -> None: for f_name in self.label_urls: download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv") + download_url( self.data_url, self.root, From d6bc60f22237710ca43c5f1c2d961eb13acf9ede Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 22:45:50 +0000 Subject: [PATCH 33/65] remove Unnamed column --- tests/data/usavars/elevation.csv | 6 +++--- tests/data/usavars/population.csv | 6 +++--- tests/data/usavars/treecover.csv | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/data/usavars/elevation.csv b/tests/data/usavars/elevation.csv index 4c0a6a31a19..9e10a3141e9 100644 --- a/tests/data/usavars/elevation.csv +++ b/tests/data/usavars/elevation.csv @@ -1,3 +1,3 @@ -,Unnamed: 0,ID,lon,lat,elevation -0,1,"0,0",0.0,0.0,0.0 -1,2,"0,1",0.1,0.1,1.0 +,ID,lon,lat,elevation +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/population.csv b/tests/data/usavars/population.csv index 2443630f570..99a9d9425c6 100644 --- a/tests/data/usavars/population.csv +++ b/tests/data/usavars/population.csv @@ -1,3 +1,3 @@ -,Unnamed: 0,V1,V1.1,ID,lon,lat,population -0,1,1,1,"0,0",0.0,0.0,0.0 -1,2,2,2,"0,1",0.1,0.1,1.0 +,V1,V1.1,ID,lon,lat,population +0,1,1,"0,0",0.0,0.0,0.0 +1,2,2,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/treecover.csv b/tests/data/usavars/treecover.csv index 9ffc19a091e..eb3d5030c81 100644 --- a/tests/data/usavars/treecover.csv +++ b/tests/data/usavars/treecover.csv @@ -1,3 +1,3 @@ -,Unnamed: 0,ID,lon,lat,treecover -0,1,"0,0",0.0,0.0,0.0 
-1,2,"0,1",0.1,0.1,1.0 +,ID,lon,lat,treecover +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 From bf77ea32d8f62f7fd306c37464713d4b9f6f0a15 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 22:49:18 +0000 Subject: [PATCH 34/65] zipfile change --- torchgeo/datasets/usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 7b47f4513b2..8be1db73b1f 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -45,7 +45,7 @@ class USAVars(VisionDataset): data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" - zipfile = "uar.zip" + zipfile = "usavars.zip" md5 = "677e89fd20e5dd0fe4d29b61827c2456" From 9e63d67ff28c308f56bb384bbbb2a8e22c71c7dc Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 22:57:03 +0000 Subject: [PATCH 35/65] i think this solves codecov? --- tests/datasets/test_usavars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 3b53cb72124..95f0005efcb 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -17,6 +17,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import USAVars +pytest.importorskip("pandas", minversion="0.19.1") + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) From 18a1904683744b8ac4fa522d6d9b4b1235f770c0 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 23:03:15 +0000 Subject: [PATCH 36/65] there needs to be a test --- tests/datasets/test_usavars.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 95f0005efcb..3f5df3dc2f6 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -88,6 +88,18 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): USAVars(str(tmp_path)) + @pytest.fixture(params=["pandas"]) + def test_mock_missing_module( + self, dataset: USAVars, mock_missing_module: str + ) -> None: + package = mock_missing_module + if package == "pandas": + with pytest.raises( + ImportError, + match=f"{package} is not installed and is required to use this dataset", + ): + USAVars(dataset.root) + def test_plot(self, dataset: USAVars) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() From 2820635d0b89cffc02a4a5dba2bfc8bbc96d93a1 Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 23 Feb 2022 23:14:52 +0000 Subject: [PATCH 37/65] codecov --- tests/datasets/test_usavars.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 3f5df3dc2f6..de9587ff25b 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
+import builtins import os import shutil from pathlib import Path -from typing import Generator +from typing import Any, Generator import pytest import torch @@ -89,6 +90,22 @@ def test_not_downloaded(self, tmp_path: Path) -> None: USAVars(str(tmp_path)) @pytest.fixture(params=["pandas"]) + def mock_missing_module( + self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest + ) -> str: + import_orig = builtins.__import__ + package = str(request.param) + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == package: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr( # type: ignore[attr-defined] + builtins, "__import__", mocked_import + ) + return package + def test_mock_missing_module( self, dataset: USAVars, mock_missing_module: str ) -> None: From dd9b3b716939240972d2eebbafbed8961672ffae Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 24 Feb 2022 17:22:53 +0000 Subject: [PATCH 38/65] bring back UAR! --- torchgeo/datasets/usavars.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 8be1db73b1f..4aa0298b640 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -42,6 +42,7 @@ class USAVars(VisionDataset): + "fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications" ) pop_csv_suffix = "CONTUS_16_640_POP_100000_0.csv?download" + uar_csv_suffix = "CONTUS_16_640_UAR_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" dirname = "usavars" @@ -54,9 +55,9 @@ class USAVars(VisionDataset): "income": f"{url_prefix}/income/outcomes_sampled_income_{pop_csv_suffix}", "roads": f"{url_prefix}/roads/outcomes_sampled_roads_{pop_csv_suffix}", "nightligths": f"{url_prefix}/nightlights/outcomes_sampled_nightlights_{pop_csv_suffix}", - "population": f"{url_prefix}/population/outcomes_sampled_population_{pop_csv_suffix}", - "elevation": f"{url_prefix}/elevation/outcomes_sampled_elevation_{pop_csv_suffix}", - "treecover": f"{url_prefix}/treecover/outcomes_sampled_treecover_{pop_csv_suffix}", + "population": f"{url_prefix}/population/outcomes_sampled_population_{uar_csv_suffix}", + "elevation": f"{url_prefix}/elevation/outcomes_sampled_elevation_{uar_csv_suffix}", + "treecover": f"{url_prefix}/treecover/outcomes_sampled_treecover_{uar_csv_suffix}", } def __init__( @@ -128,6 +129,7 @@ def _load_files(self) -> List[Dict[str, Any]]: labels_ds = [ (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs ] + samples = [] for f in files: img_path = os.path.join(file_path, f) From b45a4c3de333475cd6e2b0bd1ea144b9bf8b8b41 Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 24 Feb 2022 17:30:48 +0000 Subject: [PATCH 39/65] remove intermediate directory --- tests/data/usavars/{usavars => }/uar/tile_0,0.tif | Bin tests/data/usavars/{usavars => }/uar/tile_0,1.tif | Bin torchgeo/datasets/usavars.py | 9 +++------ 3 files changed, 3 insertions(+), 6 deletions(-) rename tests/data/usavars/{usavars => }/uar/tile_0,0.tif (100%) rename tests/data/usavars/{usavars => }/uar/tile_0,1.tif (100%) diff --git a/tests/data/usavars/usavars/uar/tile_0,0.tif b/tests/data/usavars/uar/tile_0,0.tif similarity index 100% rename from tests/data/usavars/usavars/uar/tile_0,0.tif rename to tests/data/usavars/uar/tile_0,0.tif diff --git a/tests/data/usavars/usavars/uar/tile_0,1.tif b/tests/data/usavars/uar/tile_0,1.tif similarity index 100% rename from tests/data/usavars/usavars/uar/tile_0,1.tif 
rename to tests/data/usavars/uar/tile_0,1.tif diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 4aa0298b640..da764dc1a76 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -45,7 +45,6 @@ class USAVars(VisionDataset): uar_csv_suffix = "CONTUS_16_640_UAR_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" - dirname = "usavars" zipfile = "usavars.zip" md5 = "677e89fd20e5dd0fe4d29b61827c2456" @@ -117,7 +116,7 @@ def __len__(self) -> int: def _load_files(self) -> List[Dict[str, Any]]: import pandas as pd - file_path = os.path.join(self.root, self.dirname, "uar") + file_path = os.path.join(self.root, "uar") files = os.listdir(file_path) files = files[ @@ -156,7 +155,7 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted files already exist - pathname = os.path.join(self.root, self.dirname) + pathname = os.path.join(self.root, "uar") if glob.glob(pathname): return @@ -189,9 +188,7 @@ def _download(self) -> None: ) def _extract(self) -> None: - src = os.path.join(self.root, self.zipfile) - dst = os.path.join(self.root, self.dirname) - extract_archive(src, dst) + extract_archive(os.path.join(self.root, self.zipfile)) def plot( self, From c1f934530d182568fdf83a28e3f70bfcff0274ae Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 24 Feb 2022 17:36:05 +0000 Subject: [PATCH 40/65] fix flake8 --- torchgeo/datasets/usavars.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index da764dc1a76..aed725ab53e 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -53,10 +53,14 @@ class USAVars(VisionDataset): "housing": f"{url_prefix}/housing/outcomes_sampled_housing_{pop_csv_suffix}", "income": f"{url_prefix}/income/outcomes_sampled_income_{pop_csv_suffix}", "roads": f"{url_prefix}/roads/outcomes_sampled_roads_{pop_csv_suffix}", - "nightligths": f"{url_prefix}/nightlights/outcomes_sampled_nightlights_{pop_csv_suffix}", - "population": f"{url_prefix}/population/outcomes_sampled_population_{uar_csv_suffix}", - "elevation": f"{url_prefix}/elevation/outcomes_sampled_elevation_{uar_csv_suffix}", - "treecover": f"{url_prefix}/treecover/outcomes_sampled_treecover_{uar_csv_suffix}", + "nightligths": f"{url_prefix}/nightlights/" + + f"outcomes_sampled_nightlights_{pop_csv_suffix}", + "population": f"{url_prefix}/population/" + + f"outcomes_sampled_population_{uar_csv_suffix}", + "elevation": f"{url_prefix}/elevation/" + + f"outcomes_sampled_elevation_{uar_csv_suffix}", + "treecover": f"{url_prefix}/treecover/" + + f"outcomes_sampled_treecover_{uar_csv_suffix}", } def __init__( From 705ffcadedf77b6e09f8268f90a14da9d584083f Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 24 Feb 2022 17:51:23 +0000 Subject: [PATCH 41/65] No more iteration in load_files --- torchgeo/datasets/usavars.py | 50 ++++++++++++------------------------ 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index aed725ab53e..01b2f77ec07 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -87,6 +87,12 @@ def __init__( self.files = self._load_files() + # csvs = self.label_urls.keys() # only uar for now + csvs = ["treecover", "elevation", "population"] + self.label_dfs = [ + (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs + ] + def 
__getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. @@ -96,18 +102,18 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: Returns: data and label at that index """ - sample = self.files[index] - tensor_sample = {} - tensor_sample["image"] = self._load_image(sample["image"]) + tif_file = self.files[index] + id_ = tif_file[5:-4] - keys = [key for key in sample.keys() if key != "image"] - for key in keys: - tensor_sample[key] = Tensor([sample[key]]) + sample = {} + for lab, ds in self.label_dfs: + sample[lab] = Tensor([ds[ds["ID"] == id_][lab].values[0]]) + sample["image"] = self._load_image(os.path.join(self.root, "uar", tif_file)) if self.transforms is not None: - tensor_sample = self.transforms(tensor_sample) + sample = self.transforms(sample) - return tensor_sample + return sample def __len__(self) -> int: """Return the number of data points in the dataset. @@ -117,34 +123,10 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self) -> List[Dict[str, Any]]: - import pandas as pd - + def _load_files(self) -> List[str]: file_path = os.path.join(self.root, "uar") files = os.listdir(file_path) - - files = files[ - :10 - ] # TODO: remove this, keeping temporarily because this func is very slow - - # csvs = self.label_urls.keys() # only uar for now - csvs = ["treecover", "elevation", "population"] - labels_ds = [ - (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs - ] - - samples = [] - for f in files: - img_path = os.path.join(file_path, f) - samp = {"image": img_path} - - id_ = f[5:-4] - - for lab, ds in labels_ds: - samp[lab] = ds[ds["ID"] == id_][lab].values[0] - - samples.append(samp) - return samples + return files def _load_image(self, path: str) -> Tensor: with rasterio.open(path) as f: From 8a72999dfe268a3adda09e9aa253acd03f54fbed Mon Sep 17 00:00:00 2001 From: iejMac Date: Thu, 24 Feb 2022 17:52:40 +0000 Subject: [PATCH 42/65] dont' use Any --- torchgeo/datasets/usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 01b2f77ec07..2774df537a9 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np From d28a6e4a211bdb994eecc50fc03912e8d9471267 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 03:44:41 +0000 Subject: [PATCH 43/65] check if all csv files exist --- torchgeo/datasets/usavars.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 2774df537a9..1fae57c3f6b 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -142,12 +142,13 @@ def _verify(self) -> None: """ # Check if the extracted files already exist pathname = os.path.join(self.root, "uar") - if glob.glob(pathname): + csv_pathname = os.path.join(self.root, "*.csv") + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: return # Check if the zip files have already been downloaded pathname = os.path.join(self.root, self.zipfile) - if glob.glob(pathname): + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: self._extract() return From faf6f21b0838895058f050c8b30ee5111ca2849c Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 03:46:48 +0000 Subject: [PATCH 44/65] Add docstring to init --- 
torchgeo/datasets/usavars.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 1fae57c3f6b..f288e05ce04 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -70,7 +70,19 @@ def __init__( download: bool = False, checksum: bool = False, ) -> None: - """Initialize a new USAVars dataset instance.""" + """Initialize a new USAVars dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ self.root = root self.transforms = transforms self.download = download From 5e75097c997d20785cd1194f40d382cfae8e7071 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 03:57:46 +0000 Subject: [PATCH 45/65] use index col = ID --- torchgeo/datasets/usavars.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index f288e05ce04..3e27c196a5c 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -102,7 +102,8 @@ def __init__( # csvs = self.label_urls.keys() # only uar for now csvs = ["treecover", "elevation", "population"] self.label_dfs = [ - (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"))) for lab in csvs + (lab, pd.read_csv(os.path.join(self.root, lab + ".csv"), index_col="ID")) + for lab in csvs ] def __getitem__(self, index: int) -> Dict[str, Tensor]: @@ -119,7 +120,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: sample = {} for lab, ds in self.label_dfs: - sample[lab] = Tensor([ds[ds["ID"] == id_][lab].values[0]]) + sample[lab] = Tensor([ds.loc[id_][lab]]) sample["image"] = self._load_image(os.path.join(self.root, "uar", tif_file)) if self.transforms is not None: From 5626062c7dccb512bb268d860e1170498715753c Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 04:15:56 +0000 Subject: [PATCH 46/65] labels in as list --- torchgeo/datasets/usavars.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 3e27c196a5c..dba2d1ef6e1 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Sequence import matplotlib.pyplot as plt import numpy as np @@ -63,9 +63,13 @@ class USAVars(VisionDataset): + f"outcomes_sampled_treecover_{uar_csv_suffix}", } + # ALL_LABELS = label_urls.keys() + ALL_LABELS = ["treecover", "elevation", "population"] + def __init__( self, root: str = "data", + labels: Sequence[str] = ALL_LABELS, transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, @@ -84,6 +88,7 @@ def __init__( don't match """ self.root = root + self.labels = labels self.transforms = transforms self.download = download self.checksum = checksum @@ -99,12 +104,15 @@ def __init__( self.files = self._load_files() - # csvs = self.label_urls.keys() # only uar for now - csvs = ["treecover", "elevation", "population"] - self.label_dfs = [ - (lab, 
pd.read_csv(os.path.join(self.root, lab + ".csv"), index_col="ID")) - for lab in csvs - ] + self.label_dfs = dict( + [ + ( + lab, + pd.read_csv(os.path.join(self.root, lab + ".csv"), index_col="ID"), + ) + for lab in self.labels + ] + ) def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. @@ -118,10 +126,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: tif_file = self.files[index] id_ = tif_file[5:-4] - sample = {} - for lab, ds in self.label_dfs: - sample[lab] = Tensor([ds.loc[id_][lab]]) - sample["image"] = self._load_image(os.path.join(self.root, "uar", tif_file)) + sample = { + "labels": Tensor( + [self.label_dfs[lab].loc[id_][lab] for lab in self.labels] + ), + "image": self._load_image(os.path.join(self.root, "uar", tif_file)), + } if self.transforms is not None: sample = self.transforms(sample) From f57f99986c83250df64d01dd6c2bf7b0c722ba3f Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 04:18:15 +0000 Subject: [PATCH 47/65] adjust to only 3 labels + adjust tests --- tests/datasets/test_usavars.py | 2 +- torchgeo/datasets/usavars.py | 7 +++++-- usavars.py | 7 +++++++ 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 usavars.py diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index de9587ff25b..0abe82d007f 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -63,7 +63,7 @@ def test_getitem(self, dataset: USAVars) -> None: assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert x["image"].ndim == 3 - assert len(x.keys()) == 4 # image, elevation, population, treecover + assert len(x.keys()) == 2 # image, elevation, population, treecover assert x["image"].shape[0] == 4 # R, G, B, Inf def test_len(self, dataset: USAVars) -> None: diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index dba2d1ef6e1..1236163fc4d 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -166,12 +166,15 @@ def _verify(self) -> None: # Check if the extracted files already exist pathname = os.path.join(self.root, "uar") csv_pathname = os.path.join(self.root, "*.csv") - if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: + + # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: return # Check if the zip files have already been downloaded pathname = os.path.join(self.root, self.zipfile) - if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: + # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: self._extract() return diff --git a/usavars.py b/usavars.py new file mode 100644 index 00000000000..495eeb58e31 --- /dev/null +++ b/usavars.py @@ -0,0 +1,7 @@ +import os +import pandas as pd + +from torchgeo.datasets import USAVars +ds = USAVars(download=True) + +print(ds[0]) From 9f24afb140729898e12148dbebcbbd2e1028e844 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 04:23:47 +0000 Subject: [PATCH 48/65] citation --- torchgeo/datasets/usavars.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 1236163fc4d..f65fe03bf75 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -34,6 +34,10 @@ class USAVars(VisionDataset): - road length - housing price + If you use this dataset in your research, please cite the following paper: + + * 
https://doi.org/10.1038/s41467-021-24638-z + .. versionadded:: 0.3 """ From e97939ae5a8a832688d5c9e37c70eae734b3168c Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 04:26:39 +0000 Subject: [PATCH 49/65] remove testing file --- usavars.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 usavars.py diff --git a/usavars.py b/usavars.py deleted file mode 100644 index 495eeb58e31..00000000000 --- a/usavars.py +++ /dev/null @@ -1,7 +0,0 @@ -import os -import pandas as pd - -from torchgeo.datasets import USAVars -ds = USAVars(download=True) - -print(ds[0]) From 2ba7e6bd8f131662d607a842cd2528c0ad1181de Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 04:48:44 +0000 Subject: [PATCH 50/65] no need to rename zipfile --- torchgeo/datasets/usavars.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index f65fe03bf75..efdd27a1daf 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -49,7 +49,7 @@ class USAVars(VisionDataset): uar_csv_suffix = "CONTUS_16_640_UAR_100000_0.csv?download" data_url = "https://mosaiks.blob.core.windows.net/datasets/uar.zip" - zipfile = "usavars.zip" + dirname = "uar" md5 = "677e89fd20e5dd0fe4d29b61827c2456" @@ -176,7 +176,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.zipfile) + pathname = os.path.join(self.root, self.dirname+".zip") # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: self._extract() @@ -200,12 +200,11 @@ def _download(self) -> None: download_url( self.data_url, self.root, - filename=self.zipfile, md5=self.md5 if self.checksum else None, ) def _extract(self) -> None: - extract_archive(os.path.join(self.root, self.zipfile)) + extract_archive(os.path.join(self.root, self.dirname+".zip")) def plot( self, From fca13f6aa5c2299d7332c69f6c076a33b8d9d52d Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 05:24:51 +0000 Subject: [PATCH 51/65] adding data.py to test data + adjusting tests --- tests/data/usavars/data.py | 67 ++++++++++++++++++++++++++++ tests/data/usavars/population.csv | 6 +-- tests/data/usavars/uar.zip | Bin 0 -> 811 bytes tests/data/usavars/uar/tile_0,0.tif | Bin 422 -> 449 bytes tests/data/usavars/uar/tile_0,1.tif | Bin 422 -> 449 bytes tests/data/usavars/usavars.zip | Bin 774 -> 0 bytes tests/datasets/test_usavars.py | 4 +- 7 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 tests/data/usavars/data.py create mode 100644 tests/data/usavars/uar.zip delete mode 100644 tests/data/usavars/usavars.zip diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py new file mode 100644 index 00000000000..da9dc03c8fb --- /dev/null +++ b/tests/data/usavars/data.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import glob +import hashlib +import os +import shutil + +import numpy as np +import pandas as pd +import rasterio + +data_dir = "uar" +labels = ["elevation", "population", "treecover"] +SIZE = 3 + +def create_file(path: str, dtype: str, num_channels: int) -> None: + profile = {} + profile["driver"] = "GTiff" + profile["dtype"] = dtype + profile["count"] = num_channels + profile["crs"] = "epsg:4326" + profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile["height"] = SIZE + profile["width"] = SIZE + profile["compress"] = "lzw" + profile["predictor"] = 2 + + Z = np.random.randint( + np.iinfo(profile["dtype"]).max, size=(4, SIZE, SIZE), dtype=profile["dtype"] + ) + src = rasterio.open(path, "w", **profile) + src.write(Z) + +# Remove old data +filename = f"{data_dir}.zip" +csvs = glob.glob("*.csv") + +for csv in csvs: + os.remove(csv) +if os.path.exists(filename): + os.remove(filename) +if os.path.exists(data_dir): + shutil.rmtree(data_dir) + +# Create tifs: +os.makedirs(data_dir) +create_file(os.path.join(data_dir, "tile_0,0.tif"), np.uint8, 4) +create_file(os.path.join(data_dir, "tile_0,1.tif"), np.uint8, 4) + +# Create labels: +columns = [["ID", "lon", "lat", lab] for lab in labels] +fake_vals = [["0,0", 0.0, 0.0, 0.0], ["0,1", 0.1, 0.1, 1.0]] +for lab, cols in zip(labels, columns): + df = pd.DataFrame(fake_vals, columns=cols) + df.to_csv(lab+".csv") + +# Compress data +shutil.make_archive(data_dir, "zip", ".", data_dir) + +# Compute checksums +filename = f"{data_dir}.zip" +with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(repr(filename) + ": " + repr(md5) + ",") diff --git a/tests/data/usavars/population.csv b/tests/data/usavars/population.csv index 99a9d9425c6..5c53019d3fb 100644 --- a/tests/data/usavars/population.csv +++ b/tests/data/usavars/population.csv @@ -1,3 +1,3 @@ -,V1,V1.1,ID,lon,lat,population -0,1,1,"0,0",0.0,0.0,0.0 -1,2,2,"0,1",0.1,0.1,1.0 +,ID,lon,lat,population +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/uar.zip b/tests/data/usavars/uar.zip new file mode 100644 index 0000000000000000000000000000000000000000..adb4ace0c59f408a1f608a89e28cbb5d86d3f818 GIT binary patch literal 811 zcmWIWW@Zs#0D%u$ks)9PlwbkUrHMuQ0ZT)o~wQZ${b{5U=V<+D9Ox8jW^IS z)GNtM`~2lf(u0JA1c9W4gp{-dmQ6DnX9#wlWl7o0@G@P&E@X~q$IS&fY6@a0DU6$! 
zGjay5u<+=Ym}L1V!1R`b4znnGVzf$e(}Wub1FG*c8pa#*-{{GAxO3&u;R9?!%vv$` z&z)JokaC;xru41K1HWE1yy<=GlMw&4k%yUyPa~zpiQ`8?zn3t#z$?YHjm@^FIkwE5 z$f|XWanXEEgQA?AboN%!8TXf;Ik|D6tjc-bWs(wBMc&I!>S-1+7|Bj+b+KsM&8{6W z_KEOjDl#!*@!=1ryW~eUd83xE6-}Gy@@sV=bc+jS=aXp-MG4^;i}U0 z*Uhon7dAI5>HtHCkx7IZcR~XO7Z3-n#0L7kx-l%GO~d6P6HYs z3Ka+G5reWpW{N}AXam_YP;robwn*j#A+dvjY>*l6fS9kD2kfqIKz36L4+9%W45)d1 zJ2Qgvun*erGi zkV6@PwgbHg5d#J(Bf~QRMvjf`Kv@PhFuzZklVd|WNDnudW?<;ZEtw{}q{soN#9>)T zgk!jSu!4n24MT(3jKyDXX}o%;r_j{+QBboYThNxztu3H--7A4z%?ns9a=kv-wFW&q H7{USo1FSJg literal 422 zcmZvYJqp4=5QS%x7*hn%BH9QpDuRND*jO5{38;lXM^LcQ!bZeOg>+s;atK>XEAa-l zRu;ZlSBT)Z4DT^JGdpY$)W{*CG?6e#z?ldU^9HH{^o&F;&P-n#0L7kx-l%GO~d6P6HYs z3Ka+G5reWpW{N}AXam_YP;robwn*j#A+dvjY>*l6fS9kD2kfqIKz36L4+9%W45)d1 zJ2Qgvun*erGi zkV6@PwgbHg5d#J(Bf~QRMvjf`Kv@PhFuzZklVd|WNDnudW?<;ZEtw{}q{soN#9>)T zgk!jSu!4n24MT&%<%6atOCGdUO3ll1zV-9eVfocpa~&HOMx+_6V|c&ylFo)Qw?5g% HNg^x&I9f9C literal 422 zcmZvYJqp4=5QS%x7*hn%BH9QpDuRND*jO5{38;lXM^LcQ!bZeOg>+s;atK>XEAa-l zRu;ZlSBT)Z4DT^JGdpY$)W{*CG?6e#z?ldU^9HH{^o&F;&P$S(cQ|3@?in>_TK(ZZ61CQxHo@Vca~s zp+zG^W{Eq`NtwbGQlbje9F5dDPR>bK5qrI1#`>A;Yu?M9xFaUVW~T6$x8wSDV`B}5 zOw$F+df)o+7&(+lonTOS(BR~z(ZayU9wKFKz{d1*yH}sXsvooSCMf&JB#Bwy2u_&1 z&_ww3v9iDkHph>ZnXv`3nek1im~X_)-w}2?h4J#NQ+0iMV#m0CaV=Eb#})dsawS)r zSeY7@L*&TVaA None: USAVars(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "usavars", "usavars.zip") + pathname = os.path.join("tests", "data", "usavars", "uar.zip") root = str(tmp_path) shutil.copy(pathname, root) for csv in ["elevation.csv", "population.csv", "treecover.csv"]: From 07f942fdd82088a40a4c59c88e0f01a4656820a8 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 05:30:56 +0000 Subject: [PATCH 52/65] formatting fixes --- tests/data/usavars/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py index da9dc03c8fb..35765275093 100644 --- a/tests/data/usavars/data.py +++ b/tests/data/usavars/data.py @@ -16,6 +16,7 @@ labels = ["elevation", "population", "treecover"] SIZE = 3 + def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} profile["driver"] = "GTiff" @@ -34,6 +35,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: src = rasterio.open(path, "w", **profile) src.write(Z) + # Remove old data filename = f"{data_dir}.zip" csvs = glob.glob("*.csv") @@ -55,7 +57,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: fake_vals = [["0,0", 0.0, 0.0, 0.0], ["0,1", 0.1, 0.1, 1.0]] for lab, cols in zip(labels, columns): df = pd.DataFrame(fake_vals, columns=cols) - df.to_csv(lab+".csv") + df.to_csv(lab + ".csv") # Compress data shutil.make_archive(data_dir, "zip", ".", data_dir) From 681321e69c41b18e604b06d3dd7947c62fb2f2b8 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 05:32:00 +0000 Subject: [PATCH 53/65] style fix --- torchgeo/datasets/usavars.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index efdd27a1daf..b67b5c2c96e 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -176,7 +176,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.dirname+".zip") + pathname = os.path.join(self.root, self.dirname + ".zip") # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: if 
glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: self._extract() @@ -197,14 +197,10 @@ def _download(self) -> None: for f_name in self.label_urls: download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv") - download_url( - self.data_url, - self.root, - md5=self.md5 if self.checksum else None, - ) + download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None) def _extract(self) -> None: - extract_archive(os.path.join(self.root, self.dirname+".zip")) + extract_archive(os.path.join(self.root, self.dirname + ".zip")) def plot( self, From da8040b1df5520ad9498be97e3d74302419cb308 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 25 Feb 2022 17:49:12 +0000 Subject: [PATCH 54/65] docstring --- torchgeo/datasets/usavars.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b67b5c2c96e..affa75de622 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -21,6 +21,10 @@ class USAVars(VisionDataset): """USAVars dataset. + The USAVars dataset is a compilation of regression tasks + on NAIP imagery at lon, lat points sourced from: + https://doi.org/10.1038/s41467-021-24638-z + Dataset format: * images are 4-channel tifs * labels are singular float values From f30b50ed8a27b1bef6a02748741baac4c059bcc7 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 25 Feb 2022 17:58:44 +0000 Subject: [PATCH 55/65] Docstring --- torchgeo/datasets/usavars.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index affa75de622..8f15a1a3736 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -21,22 +21,22 @@ class USAVars(VisionDataset): """USAVars dataset. - The USAVars dataset is a compilation of regression tasks - on NAIP imagery at lon, lat points sourced from: - https://doi.org/10.1038/s41467-021-24638-z + The USAVars dataset is reproduction of the dataset used in the paper `"A + generalizable and accessible approach to machine learning with global satellite + imagery`_. Specifically, this dataset + includes 1 sq km. crops of NAIP imagery resampled to 4m/px cenetered on ~100k points + that are sampled randomly from the contiguous states in the USA. Each point contains + three continous valued labels (taken from the dataset released in the paper): tree + cover percentage, elevation, and population density. Dataset format: - * images are 4-channel tifs + * images are 4-channel GeoTIFFs * labels are singular float values Dataset labels: - - tree cover - - elevation - - population density - - nighttime lights - - income per houshold - - road length - - housing price + * tree cover + * elevation + * population density If you use this dataset in your research, please cite the following paper: From eb81d7c5b639ec9825769af3d181e75ae80f8cc6 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 25 Feb 2022 18:02:03 +0000 Subject: [PATCH 56/65] Fixing docstring --- torchgeo/datasets/usavars.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 8f15a1a3736..d5c9c2c5fae 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -21,9 +21,9 @@ class USAVars(VisionDataset): """USAVars dataset. 
- The USAVars dataset is reproduction of the dataset used in the paper `"A + The USAVars dataset is reproduction of the dataset used in the paper "`A generalizable and accessible approach to machine learning with global satellite - imagery`_. Specifically, this dataset + imagery `_". Specifically, this dataset includes 1 sq km. crops of NAIP imagery resampled to 4m/px cenetered on ~100k points that are sampled randomly from the contiguous states in the USA. Each point contains three continous valued labels (taken from the dataset released in the paper): tree From fb49f5dc9df918ece5813ea1d266740f303700b0 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 26 Feb 2022 20:42:31 +0000 Subject: [PATCH 57/65] docstring for labels --- torchgeo/datasets/usavars.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 8f15a1a3736..918e4bdc4d7 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -86,6 +86,7 @@ def __init__( Args: root: root directory where dataset can be found + labels: list of labels to include transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory From 8243b30ce60202453737a2c97f6cfbf875d73a62 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 26 Feb 2022 20:56:22 +0000 Subject: [PATCH 58/65] Adding all csv files to data + checking for all 7 labels instead of just 3 --- tests/data/usavars/data.py | 10 +++++++++- tests/data/usavars/housing.csv | 3 +++ tests/data/usavars/income.csv | 3 +++ tests/data/usavars/nightlights.csv | 3 +++ tests/data/usavars/roads.csv | 3 +++ tests/data/usavars/uar.zip | Bin 811 -> 810 bytes tests/data/usavars/uar/tile_0,0.tif | Bin 449 -> 449 bytes tests/data/usavars/uar/tile_0,1.tif | Bin 449 -> 449 bytes tests/datasets/test_usavars.py | 15 ++++++++++++++- torchgeo/datasets/usavars.py | 6 ++---- 10 files changed, 37 insertions(+), 6 deletions(-) create mode 100644 tests/data/usavars/housing.csv create mode 100644 tests/data/usavars/income.csv create mode 100644 tests/data/usavars/nightlights.csv create mode 100644 tests/data/usavars/roads.csv diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py index 35765275093..c392887b04e 100644 --- a/tests/data/usavars/data.py +++ b/tests/data/usavars/data.py @@ -13,7 +13,15 @@ import rasterio data_dir = "uar" -labels = ["elevation", "population", "treecover"] +labels = [ + "elevation", + "population", + "treecover", + "income", + "nightlights", + "housing", + "roads", +] SIZE = 3 diff --git a/tests/data/usavars/housing.csv b/tests/data/usavars/housing.csv new file mode 100644 index 00000000000..a99a21131ac --- /dev/null +++ b/tests/data/usavars/housing.csv @@ -0,0 +1,3 @@ +,ID,lon,lat,housing +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/income.csv b/tests/data/usavars/income.csv new file mode 100644 index 00000000000..3ac6b3d2f05 --- /dev/null +++ b/tests/data/usavars/income.csv @@ -0,0 +1,3 @@ +,ID,lon,lat,income +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/nightlights.csv b/tests/data/usavars/nightlights.csv new file mode 100644 index 00000000000..9a55b7a8838 --- /dev/null +++ b/tests/data/usavars/nightlights.csv @@ -0,0 +1,3 @@ +,ID,lon,lat,nightlights +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/roads.csv b/tests/data/usavars/roads.csv new file mode 100644 index 00000000000..352db63ec1b --- /dev/null 
+++ b/tests/data/usavars/roads.csv @@ -0,0 +1,3 @@ +,ID,lon,lat,roads +0,"0,0",0.0,0.0,0.0 +1,"0,1",0.1,0.1,1.0 diff --git a/tests/data/usavars/uar.zip b/tests/data/usavars/uar.zip index adb4ace0c59f408a1f608a89e28cbb5d86d3f818..a6e122f80abd977f0e4260db951ead72b78e226a 100644 GIT binary patch delta 224 zcmZ3@wu+56z?+#xgaHI(mqktFmFEOfAt`2$i@!`XePkaIo1<+Ub?C|T-CXls17uW)#6l~c$>EHTZ3+s$X#1Zs zO}chn^qKuHtBY?J%2>N!d(ys3JWt*1oNe&7A8Ee3D%X4b|EQ}2G=q^zgc6Bf%pLH=t*Ay delta 225 zcmZ3*wwjGMz?+#xgaHITXhlxsmFEOfAxqevtA3qm`p7{e#xA33(-X(&rC0XoJdd{b zmz@@0YkGZSVWP{dgH~Hy7PWD;S9yJ&JL olRm;plh-nt%EQe50(A31Mg|4}ut<8 diff --git a/tests/data/usavars/uar/tile_0,0.tif b/tests/data/usavars/uar/tile_0,0.tif index 96dd113af0e9968fd59beea404d8b71d2b3511fb..372474660d384149d7467ae9a92c048423e2492a 100644 GIT binary patch delta 50 zcmV-20L}lw1Hl8ZmjNmW2#+I@Z88HdEfWvL>tROW9tQ|&li@AU+hKv7%yy&Md}lX} I6pzQW1lD^LkN^Mx delta 50 zcmV-20L}lw1Hl8ZmjNm^n4|UDDC+JmAcBGM5h-|U5j_q@hERR3>JYkvpanQ>O7K2~ IQs%)_1U9D@3;+NC diff --git a/tests/data/usavars/uar/tile_0,1.tif b/tests/data/usavars/uar/tile_0,1.tif index 674962c0a5dad3c171add76e854490ee8718a729..c9703a268501f281d54abdc061bb05e9ea728053 100644 GIT binary patch delta 50 zcmV-20L}lw1Hl8ZmjNm|(S~+1qykQDtrY+UYsFZUQlfy&Fba_d&L None: pathname = os.path.join("tests", "data", "usavars", "uar.zip") root = str(tmp_path) shutil.copy(pathname, root) - for csv in ["elevation.csv", "population.csv", "treecover.csv"]: + csvs = [ + "elevation.csv", + "population.csv", + "treecover.csv", + "income.csv", + "nightlights.csv", + "roads.csv", + "housing.csv", + ] + for csv in csvs: shutil.copy(os.path.join("tests", "data", "usavars", csv), root) USAVars(root) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 58a96f651f1..ce46d37378c 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -176,14 +176,12 @@ def _verify(self) -> None: pathname = os.path.join(self.root, "uar") csv_pathname = os.path.join(self.root, "*.csv") - # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: - if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: return # Check if the zip files have already been downloaded pathname = os.path.join(self.root, self.dirname + ".zip") - # if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: - if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 3: + if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7: self._extract() return From d5186834d6fe934166ebef3aaa5af4fac80af577 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 08:27:11 +0000 Subject: [PATCH 59/65] docstrings --- torchgeo/datasets/usavars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index ce46d37378c..911e2d80d93 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -197,12 +197,14 @@ def _verify(self) -> None: self._extract() def _download(self) -> None: + """Download the dataset.""" for f_name in self.label_urls: download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv") download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None) def _extract(self) -> None: + """Extract the dataset.""" extract_archive(os.path.join(self.root, self.dirname + ".zip")) def plot( From 72459ca001909523dc9ef73cbadb937280d16833 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 08:35:17 +0000 Subject: [PATCH 60/65] docstring --- 
torchgeo/datasets/usavars.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 911e2d80d93..0a85063ab41 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -156,11 +156,20 @@ def __len__(self) -> int: return len(self.files) def _load_files(self) -> List[str]: + """Loads file names.""" file_path = os.path.join(self.root, "uar") files = os.listdir(file_path) return files def _load_image(self, path: str) -> Tensor: + """Load a single image + + Args: + path: path to the image + + Returns: + the image + """ with rasterio.open(path) as f: array: "np.typing.NDArray[np.int_]" = f.read() tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] From fdf8ee789158ff75d50a85ded56fefc60e631807 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 08:44:35 +0000 Subject: [PATCH 61/65] ensure labels are valid --- torchgeo/datasets/usavars.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 0a85063ab41..4b933191916 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -71,8 +71,7 @@ class USAVars(VisionDataset): + f"outcomes_sampled_treecover_{uar_csv_suffix}", } - # ALL_LABELS = label_urls.keys() - ALL_LABELS = ["treecover", "elevation", "population"] + ALL_LABELS = label_urls.keys() def __init__( self, @@ -97,6 +96,10 @@ def __init__( don't match """ self.root = root + + for lab in labels: + assert lab in self.ALL_LABLES + self.labels = labels self.transforms = transforms self.download = download From 1d0d0ae0fbaa0557c760b2a409db0a4531efeb08 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 08:48:30 +0000 Subject: [PATCH 62/65] pydocstyle fix --- torchgeo/datasets/usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 4b933191916..1f7a5a6cca6 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -165,7 +165,7 @@ def _load_files(self) -> List[str]: return files def _load_image(self, path: str) -> Tensor: - """Load a single image + """Load a single image. 
Args: path: path to the image From 9199844c83d3f807ec6a2fa090b28e487fbe4181 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 09:26:34 +0000 Subject: [PATCH 63/65] cast to list --- torchgeo/datasets/usavars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 1f7a5a6cca6..641afd2ff2a 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -71,7 +71,7 @@ class USAVars(VisionDataset): + f"outcomes_sampled_treecover_{uar_csv_suffix}", } - ALL_LABELS = label_urls.keys() + ALL_LABELS = list(label_urls.keys()) def __init__( self, From c64876c6c7aafc3e621ff0623b13595abe9fcfbe Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 27 Feb 2022 09:28:58 +0000 Subject: [PATCH 64/65] remove typos --- torchgeo/datasets/usavars.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 641afd2ff2a..bf74f77f54b 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -61,7 +61,7 @@ class USAVars(VisionDataset): "housing": f"{url_prefix}/housing/outcomes_sampled_housing_{pop_csv_suffix}", "income": f"{url_prefix}/income/outcomes_sampled_income_{pop_csv_suffix}", "roads": f"{url_prefix}/roads/outcomes_sampled_roads_{pop_csv_suffix}", - "nightligths": f"{url_prefix}/nightlights/" + "nightlights": f"{url_prefix}/nightlights/" + f"outcomes_sampled_nightlights_{pop_csv_suffix}", "population": f"{url_prefix}/population/" + f"outcomes_sampled_population_{uar_csv_suffix}", @@ -98,7 +98,7 @@ def __init__( self.root = root for lab in labels: - assert lab in self.ALL_LABLES + assert lab in self.ALL_LABELS self.labels = labels self.transforms = transforms From 0925b64bb38abba199a52f89bac8773f12b3ff2e Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 27 Feb 2022 19:56:45 +0000 Subject: [PATCH 65/65] Requested changes --- torchgeo/datasets/usavars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index bf74f77f54b..9d8118efddb 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -92,6 +92,8 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: if invalid labels are provided + ImportError: if pandas is not installed RuntimeError: if ``download=False`` and data is not found, or checksums don't match """
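
The net effect of the series above is a USAVars VisionDataset whose constructor takes root, labels, transforms, download, and checksum, downloads the seven label CSVs plus the uar.zip NAIP crops, and returns each sample as a dictionary with an "image" tensor (4-band crop) and a "labels" tensor holding the requested regression targets in the order given. The sketch below is illustrative and not part of the patch series: the constructor arguments, sample keys, and plot() call mirror the final diffs, while the DataLoader wiring, batch size, and printed shapes are ordinary PyTorch usage assumed for demonstration.

    # Usage sketch for the USAVars dataset assembled by this patch series.
    # Assumes torchgeo 0.3 with the API shown in the diffs; paths and DataLoader
    # settings are illustrative.
    from torch.utils.data import DataLoader

    from torchgeo.datasets import USAVars

    # download=True fetches uar.zip and all seven label CSVs into ./data;
    # only the three requested labels are exposed per sample.
    dataset = USAVars(
        root="data",
        labels=["treecover", "elevation", "population"],
        download=True,
    )

    sample = dataset[0]
    print(sample["image"].shape)   # (4, H, W): R, G, B, NIR bands of the NAIP crop
    print(sample["labels"])        # one value per requested label, in the order above

    # Samples are plain dicts of tensors, so the default collate function works.
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print(batch["image"].shape, batch["labels"].shape)

    # plot() renders the RGB bands with the rounded label values in the title.
    dataset.plot(sample, suptitle="USAVars sample")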