class TestSpaceNet4:
    """Tests for the SpaceNet 4 (Off-Nadir Buildings) dataset."""

    @pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> SpaceNet4:
        """Build a SpaceNet4 instance against the local test fixture archive."""
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        # Redirect Radiant MLHub downloads to the local tar.gz test fixture.
        monkeypatch.setattr(  # type: ignore[attr-defined]
            radiant_mlhub.Collection, "fetch", fetch_collection
        )
        # MD5 of the small fixture archive, not the real collection.
        test_md5 = {
            "sn4_AOI_6_Atlanta": "ea37c2d87e2c3a1d8b2a7c2230080d46",
        }

        test_angles = ["nadir", "off-nadir", "very-off-nadir"]

        monkeypatch.setattr(  # type: ignore[attr-defined]
            SpaceNet4, "collection_md5_dict", test_md5
        )
        root = str(tmp_path)
        transforms = Identity()
        return SpaceNet4(
            root,
            image=request.param,
            angles=test_angles,
            transforms=transforms,
            download=True,
            api_key="",
        )

    def test_getitem(self, dataset: SpaceNet4) -> None:
        # Get image-label pair with empty label to ensure coverage
        x = dataset[2]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        # Channel count depends on the selected image product.
        if dataset.image == "PS-RGBNIR":
            assert x["image"].shape[0] == 4
        elif dataset.image == "MS":
            assert x["image"].shape[0] == 8
        else:
            assert x["image"].shape[0] == 1

    def test_len(self, dataset: SpaceNet4) -> None:
        assert len(dataset) == 4

    def test_already_downloaded(self, dataset: SpaceNet4) -> None:
        # Re-instantiating over an existing download must not raise.
        SpaceNet4(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet4(str(tmp_path))

    def test_collection_checksum(self, dataset: SpaceNet4) -> None:
        # Corrupt the expected MD5 so checksum verification must fail.
        dataset.collection_md5_dict["sn4_AOI_6_Atlanta"] = "randommd5hash123"
        with pytest.raises(
            RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
        ):
            SpaceNet4(root=dataset.root, download=True, checksum=True)
class SpaceNet4(SpaceNet):
    """SpaceNet 4: Off-Nadir Buildings Dataset.

    `SpaceNet 4 <https://spacenet.ai/off-nadir-building-detection/>`_ is a
    dataset of 27 WV-2 imagery captured at varying off-nadir angles and
    associated building footprints over the city of Atlanta. The off-nadir
    angle ranges from 7 degrees to 54 degrees.

    Dataset features

    * No. of chipped images: 28,728 (PAN/MS/PS-RGBNIR)
    * No. of label files: 1064
    * No. of building footprints: >120,000
    * Area Coverage: 665 sq km
    * Chip size: 225 x 225 (MS), 900 x 900 (PAN/PS-RGBNIR)

    Dataset format

    * Imagery - WorldView-2 GeoTIFFs

        * PAN.tif (Panchromatic)
        * MS.tif (Multispectral)
        * PS-RGBNIR (Pansharpened RGBNIR)

    * Labels - GeoJSON

        * labels.geojson

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/1903.12239

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub

    """

    dataset_id = "spacenet4"
    # Expected MD5 of each downloaded collection archive (checked when checksum=True).
    collection_md5_dict = {
        "sn4_AOI_6_Atlanta": "c597d639cba5257927a97e3eff07b753",
    }

    # Filename of each image product inside a chip directory.
    imagery = {
        "MS": "MS.tif",
        "PAN": "PAN.tif",
        "PS-RGBNIR": "PS-RGBNIR.tif",
    }
    # Chip dimensions per image product (MS is lower resolution).
    chip_size = {
        "MS": (225, 225),
        "PAN": (900, 900),
        "PS-RGBNIR": (900, 900),
    }
    label_glob = "labels.geojson"

    # Catalog ids of the 27 collects, bucketed by off-nadir angle
    # (nadir: 7-25, off-nadir: 26-40, very-off-nadir: 41-54 degrees
    # per the dataset description — see arXiv:1903.12239).
    angle_catalog_map = {
        "nadir": [
            "1030010003D22F00",
            "10300100023BC100",
            "1030010003993E00",
            "1030010003CAF100",
            "1030010002B7D800",
            "10300100039AB000",
            "1030010002649200",
            "1030010003C92000",
            "1030010003127500",
            "103001000352C200",
            "103001000307D800",
        ],
        "off-nadir": [
            "1030010003472200",
            "1030010003315300",
            "10300100036D5200",
            "103001000392F600",
            "1030010003697400",
            "1030010003895500",
            "1030010003832800",
        ],
        "very-off-nadir": [
            "10300100035D1B00",
            "1030010003CCD700",
            "1030010003713C00",
            "10300100033C5200",
            "1030010003492700",
            "10300100039E6200",
            "1030010003BDDC00",
            "1030010003CD4300",
            "1030010003193D00",
        ],
    }

    def __init__(
        self,
        root: str,
        image: str = "PS-RGBNIR",
        angles: Optional[List[str]] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SpaceNet 4 Dataset instance.

        Args:
            root: root directory where dataset can be found
            image: image selection which must be in ["MS", "PAN", "PS-RGBNIR"]
            angles: angle selection which must be in ["nadir", "off-nadir",
                "very-off-nadir"]; omitting it (or passing an empty list)
                selects all angles
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory.
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing
        """
        collections = ["sn4_AOI_6_Atlanta"]
        assert image in {"MS", "PAN", "PS-RGBNIR"}
        # None sentinel instead of a mutable [] default; both mean "all angles".
        self.angles = angles if angles is not None else []
        for angle in self.angles:
            assert angle in self.angle_catalog_map
        super().__init__(
            root, image, collections, transforms, download, api_key, checksum
        )

    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts containing paths for each pair of image and label
        """
        nadir: List[Dict[str, str]] = []
        offnadir: List[Dict[str, str]] = []
        veryoffnadir: List[Dict[str, str]] = []
        images = glob.glob(os.path.join(root, self.collections[0], "*", self.filename))
        images = sorted(images)

        # Chip directory names end with "_<CATALOG_ID>"; the id determines
        # which angle bucket the image belongs to.
        catalog_id_pattern = re.compile(r"(_[A-Z0-9])\w+$")
        for imgpath in images:
            imgdir = os.path.basename(os.path.dirname(imgpath))
            match = catalog_id_pattern.search(imgdir)
            assert match is not None, "Invalid image directory"
            catalog_id = match.group()[1:]

            # Labels are shared across angle buckets of the same chip and live
            # in a sibling "<prefix>-labels" directory.
            lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
            lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
            assert os.path.exists(lbl_path)

            _file = {"image_path": imgpath, "label_path": lbl_path}
            if catalog_id in self.angle_catalog_map["very-off-nadir"]:
                veryoffnadir.append(_file)
            elif catalog_id in self.angle_catalog_map["off-nadir"]:
                offnadir.append(_file)
            elif catalog_id in self.angle_catalog_map["nadir"]:
                nadir.append(_file)

        angle_file_map = {
            "nadir": nadir,
            "off-nadir": offnadir,
            "very-off-nadir": veryoffnadir,
        }

        files: List[Dict[str, str]] = []
        if not self.angles:
            # No angle filter: keep everything, ordered nadir -> very-off-nadir.
            files.extend(nadir + offnadir + veryoffnadir)
        else:
            for angle in self.angles:
                files.extend(angle_file_map[angle])
        return files