Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Style fixes
Browse files Browse the repository at this point in the history
ashnair1 committed Sep 13, 2021
1 parent c7dfc77 commit 173f08c
Showing 2 changed files with 49 additions and 22 deletions.
4 changes: 2 additions & 2 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -46,8 +46,8 @@
from .resisc45 import RESISC45
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .spacenet1 import Spacenet1
from .so2sat import So2Sat
from .spacenet1 import Spacenet1
from .utils import BoundingBox, collate_dict

__all__ = (
@@ -79,6 +79,7 @@
"NAIP",
"Sentinel",
"Sentinel2",
"Spacenet1",
# VisionDataset
"BeninSmallHolderCashews",
"COWC",
@@ -93,7 +94,6 @@
"RESISC45",
"SEN12MS",
"So2Sat",
"Spacenet1",
"TropicalCycloneWindEstimation",
"VHR10",
# Base classes
67 changes: 47 additions & 20 deletions torchgeo/datasets/spacenet1.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@


class Spacenet1(RasterDataset):
"""Spacenet 1: Building Detection v1 Dataset
"""Spacenet 1: Building Detection v1 Dataset.
`Spacenet 1 <https://spacenet.ai/spacenet-buildings-dataset-v1/>`_
is a dataset of building footprints over the city of Rio de Janeiro.
@@ -38,10 +38,9 @@ class Spacenet1(RasterDataset):
* Imagery - Raw 8 band Worldview-3 (GeoTIFF) & Pansharpened RGB image (GeoTIFF)
* Labels - GeoJSON
If you are using data from SpaceNet in a paper, please use the following citation:
If you are using data from SpaceNet in a paper, please cite the following paper:
* Van Etten, A., Lindenbaum, D., & Bacastow, T.M. (2018). SpaceNet: A Remote
Sensing Dataset and Challenge Series. ArXiv, abs/1807.01232.
* https://arxiv.org/abs/1807.01232
.. note::
@@ -63,14 +62,30 @@ def __init__(
self,
root: str,
crs: Optional[CRS] = None,
res: float = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
api_key: str = None,
api_key: Optional[str] = None,
checksum: bool = False,
) -> None:
"""Initialise a new Spacenet 1 Dataset instance.
Args:
root: root directory where dataset can be found
crs: coordinate reference system (CRS) for the dataset (defaults to None)
res: resolution of the dataset (defaults to None)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version.
download: if True, download dataset and store it in the root directory.
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, verify the MD5 of the dataset archive before extracting it
Raises:
RuntimeError: if ``download=False`` but dataset is missing
"""

self.root = root
self.transforms = transforms
self.checksum = checksum

if not self._check_integrity():
if download:
@@ -88,7 +103,16 @@ def __init__(
)
self.files = self._load_files(os.path.join(root, self.foldername))

def _load_files(self, root) -> List[Dict[str, str]]:
def _load_files(self, root: str) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
root: root dir of dataset
Returns:
list of dicts containing paths for each triple of rgb,
8band and label
"""
files = []
images = glob.glob(os.path.join(root, "*", self.filename_glob))
images = sorted(images)
@@ -100,7 +124,7 @@ def _load_files(self, root) -> List[Dict[str, str]]:
files.append({"rgb": imgpath, "8band": rawpath, "label": lbl_path})
return files

def _load_image(self, path: str) -> Tensor:
def _load_image(self, path: str) -> Tuple[Tensor, Affine]:
"""Load a single image.
Args:
@@ -166,8 +190,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
files = self.files[index]
rgb, tfm = self._load_image(files["rgb"])
raw, _ = self._load_image(files["8band"])
out_shape = tuple(rgb.shape[1:])
label = self._load_label(files["label"], tfm, out_shape)
h, w = rgb.shape[1:]
label = self._load_label(files["label"], tfm, (h, w))

sample = {"rgb": rgb, "8band": raw, "label": label}

@@ -182,16 +206,15 @@ def _check_integrity(self) -> bool:
True if the dataset directories are found, else False
"""
stacpath = os.path.join(self.root, self.foldername, "collection.json")
# If dataset folder does not exist. Check for uncorrupted archive
if not bool(os.path.exists(stacpath)):
# If dataset folder does not exist, check for uncorrupted archive
if not os.path.exists(stacpath):
archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
if not bool(os.path.exists(archive_path)) or not check_integrity(
archive_path, self.md5
):
return False
print("Archive found. Extracting...")
extract_archive(archive_path)
return True
if os.path.exists(archive_path):
print("Archive found")
if self.checksum and not check_integrity(archive_path, self.md5):
return False
print("Extracting...")
extract_archive(archive_path)

def _download(self, api_key: Optional[str] = None) -> None:
"""Download the dataset and extract it.
@@ -209,7 +232,11 @@ def _download(self, api_key: Optional[str] = None) -> None:

download_radiant_mlhub(self.dataset_id, self.root, api_key)
archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
if check_integrity(archive_path, self.md5):
if (
self.checksum
and check_integrity(archive_path, self.md5)
or not self.checksum
):
extract_archive(archive_path)
else:
raise RuntimeError("Dataset corrupted")

0 comments on commit 173f08c

Please sign in to comment.