diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index d3c52a4947c..5cb41ba682a 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -143,6 +143,7 @@ SpaceNet
.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
+.. autoclass:: SpaceNet4
Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz b/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz
new file mode 100644
index 00000000000..1c382ff25dd
Binary files /dev/null and b/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz differ
diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py
index 4c236df915d..b769f438570 100644
--- a/tests/datasets/test_spacenet.py
+++ b/tests/datasets/test_spacenet.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
-from torchgeo.datasets import SpaceNet1, SpaceNet2
+from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
from torchgeo.transforms import Identity
TEST_DATA_DIR = "tests/data/spacenet"
@@ -141,3 +141,67 @@ def test_collection_checksum(self, dataset: SpaceNet2) -> None:
dataset.collection_md5_dict["sn2_AOI_2_Vegas"] = "randommd5hash123"
with pytest.raises(RuntimeError, match="Collection sn2_AOI_2_Vegas corrupted"):
SpaceNet2(root=dataset.root, download=True, checksum=True)
+
+
+class TestSpaceNet4:
+ @pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
+ def dataset(
+ self,
+ request: SubRequest,
+ monkeypatch: Generator[MonkeyPatch, None, None],
+ tmp_path: Path,
+ ) -> SpaceNet4:
+ radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+ monkeypatch.setattr( # type: ignore[attr-defined]
+ radiant_mlhub.Collection, "fetch", fetch_collection
+ )
+ test_md5 = {
+ "sn4_AOI_6_Atlanta": "ea37c2d87e2c3a1d8b2a7c2230080d46",
+ }
+
+ test_angles = ["nadir", "off-nadir", "very-off-nadir"]
+
+ monkeypatch.setattr( # type: ignore[attr-defined]
+ SpaceNet4, "collection_md5_dict", test_md5
+ )
+ root = str(tmp_path)
+ transforms = Identity()
+ return SpaceNet4(
+ root,
+ image=request.param,
+ angles=test_angles,
+ transforms=transforms,
+ download=True,
+ api_key="",
+ )
+
+ def test_getitem(self, dataset: SpaceNet4) -> None:
+        # Get image-label pair with empty label to
+        # ensure coverage
+ x = dataset[2]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+ if dataset.image == "PS-RGBNIR":
+ assert x["image"].shape[0] == 4
+ elif dataset.image == "MS":
+ assert x["image"].shape[0] == 8
+ else:
+ assert x["image"].shape[0] == 1
+
+ def test_len(self, dataset: SpaceNet4) -> None:
+ assert len(dataset) == 4
+
+ def test_already_downloaded(self, dataset: SpaceNet4) -> None:
+ SpaceNet4(root=dataset.root, download=True)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found"):
+ SpaceNet4(str(tmp_path))
+
+ def test_collection_checksum(self, dataset: SpaceNet4) -> None:
+ dataset.collection_md5_dict["sn4_AOI_6_Atlanta"] = "randommd5hash123"
+ with pytest.raises(
+ RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
+ ):
+ SpaceNet4(root=dataset.root, download=True, checksum=True)
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index cad16315aad..98f9ba46f52 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -56,7 +56,7 @@
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
-from .spacenet import SpaceNet, SpaceNet1, SpaceNet2
+from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4
from .ucmerced import UCMerced
from .utils import BoundingBox, collate_dict
from .zuericrop import ZueriCrop
@@ -109,6 +109,7 @@
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
+ "SpaceNet4",
"TropicalCycloneWindEstimation",
"UCMerced",
"VHR10",
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
index 26e27a66498..855f7f8da65 100644
--- a/torchgeo/datasets/spacenet.py
+++ b/torchgeo/datasets/spacenet.py
@@ -14,6 +14,7 @@
import rasterio as rio
import torch
from affine import Affine
+from fiona.errors import FionaValueError
from rasterio.features import rasterize
from torch import Tensor
@@ -154,8 +155,11 @@ def _load_mask(self, path: str, tfm: Affine, shape: Tuple[int, int]) -> Tensor:
Returns:
Tensor: label tensor
"""
- with fiona.open(path) as src:
- labels = [feature["geometry"] for feature in src]
+ try:
+ with fiona.open(path) as src:
+ labels = [feature["geometry"] for feature in src]
+ except FionaValueError:
+ labels = []
if not labels:
mask_data = np.zeros(shape=shape)
@@ -490,3 +494,181 @@ def _load_files(self, root: str) -> List[Dict[str, str]]:
)
files.append({"image_path": imgpath, "label_path": lbl_path})
return files
+
+
+class SpaceNet4(SpaceNet):
+ """SpaceNet 4: Off-Nadir Buildings Dataset.
+
+    `SpaceNet 4 <https://spacenet.ai/off-nadir-building-detection/>`_ is a
+ dataset of 27 WV2 imagery captured at varying off-nadir angles and
+ associated building footprints over the city of Atlanta. The off-nadir angle
+ ranges from 7 degrees to 54 degrees.
+
+
+ Dataset features
+
+ * No. of chipped images: 28,728 (PAN/MS/PS-RGBNIR)
+ * No. of label files: 1064
+ * No. of building footprints: >120,000
+ * Area Coverage: 665 sq km
+ * Chip size: 225 x 225 (MS), 900 x 900 (PAN/PS-RGBNIR)
+
+ Dataset format
+
+    * Imagery - Worldview-2 GeoTIFFs
+ * PAN.tif (Panchromatic)
+ * MS.tif (Multispectral)
+ * PS-RGBNIR (Pansharpened RGBNIR)
+ * Labels - GeoJSON
+ * labels.geojson
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://arxiv.org/abs/1903.12239
+
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+    * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
+ imagery and labels from the Radiant Earth MLHub
+
+ """
+
+ dataset_id = "spacenet4"
+ collection_md5_dict = {
+ "sn4_AOI_6_Atlanta": "c597d639cba5257927a97e3eff07b753",
+ }
+
+ imagery = {
+ "MS": "MS.tif",
+ "PAN": "PAN.tif",
+ "PS-RGBNIR": "PS-RGBNIR.tif",
+ }
+ chip_size = {
+ "MS": (225, 225),
+ "PAN": (900, 900),
+ "PS-RGBNIR": (900, 900),
+ }
+ label_glob = "labels.geojson"
+
+ angle_catalog_map = {
+ "nadir": [
+ "1030010003D22F00",
+ "10300100023BC100",
+ "1030010003993E00",
+ "1030010003CAF100",
+ "1030010002B7D800",
+ "10300100039AB000",
+ "1030010002649200",
+ "1030010003C92000",
+ "1030010003127500",
+ "103001000352C200",
+ "103001000307D800",
+ ],
+ "off-nadir": [
+ "1030010003472200",
+ "1030010003315300",
+ "10300100036D5200",
+ "103001000392F600",
+ "1030010003697400",
+ "1030010003895500",
+ "1030010003832800",
+ ],
+ "very-off-nadir": [
+ "10300100035D1B00",
+ "1030010003CCD700",
+ "1030010003713C00",
+ "10300100033C5200",
+ "1030010003492700",
+ "10300100039E6200",
+ "1030010003BDDC00",
+ "1030010003CD4300",
+ "1030010003193D00",
+ ],
+ }
+
+ def __init__(
+ self,
+ root: str,
+ image: str = "PS-RGBNIR",
+ angles: List[str] = [],
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ download: bool = False,
+ api_key: Optional[str] = None,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new SpaceNet 4 Dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ image: image selection which must be in ["MS", "PAN", "PS-RGBNIR"]
+ angles: angle selection which must be in ["nadir", "off-nadir",
+ "very-off-nadir"]
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory.
+ api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ RuntimeError: if ``download=False`` but dataset is missing
+ """
+ collections = ["sn4_AOI_6_Atlanta"]
+ assert image in {"MS", "PAN", "PS-RGBNIR"}
+ self.angles = angles
+ if self.angles:
+ for angle in self.angles:
+ assert angle in self.angle_catalog_map.keys()
+ super().__init__(
+ root, image, collections, transforms, download, api_key, checksum
+ )
+
+ def _load_files(self, root: str) -> List[Dict[str, str]]:
+ """Return the paths of the files in the dataset.
+
+ Args:
+ root: root dir of dataset
+
+ Returns:
+ list of dicts containing paths for each pair of image and label
+ """
+ files = []
+ nadir = []
+ offnadir = []
+ veryoffnadir = []
+ images = glob.glob(os.path.join(root, self.collections[0], "*", self.filename))
+ images = sorted(images)
+
+ catalog_id_pattern = re.compile(r"(_[A-Z0-9])\w+$")
+ for imgpath in images:
+ imgdir = os.path.basename(os.path.dirname(imgpath))
+ match = catalog_id_pattern.search(imgdir)
+ assert match is not None, "Invalid image directory"
+ catalog_id = match.group()[1:]
+
+ lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
+
+ lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
+ assert os.path.exists(lbl_path)
+
+ _file = {"image_path": imgpath, "label_path": lbl_path}
+ if catalog_id in self.angle_catalog_map["very-off-nadir"]:
+ veryoffnadir.append(_file)
+ elif catalog_id in self.angle_catalog_map["off-nadir"]:
+ offnadir.append(_file)
+ elif catalog_id in self.angle_catalog_map["nadir"]:
+ nadir.append(_file)
+
+ angle_file_map = {
+ "nadir": nadir,
+ "off-nadir": offnadir,
+ "very-off-nadir": veryoffnadir,
+ }
+
+ if not self.angles:
+ files.extend(nadir + offnadir + veryoffnadir)
+ else:
+ for angle in self.angles:
+ files.extend(angle_file_map[angle])
+ return files