From 2b66a0ced5793d0a7f48a35e8dec00ec29de28b2 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 21 Mar 2022 14:23:42 +0400 Subject: [PATCH 1/5] Add SpaceNet3 --- docs/api/datasets.rst | 1 + tests/data/spacenet/sn3_AOI_3_Paris.tar.gz | Bin 0 -> 1176 bytes tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz | Bin 0 -> 1193 bytes tests/datasets/test_spacenet.py | 78 +++- torchgeo/datasets/__init__.py | 11 +- torchgeo/datasets/spacenet.py | 420 ++++++++++++------ 6 files changed, 364 insertions(+), 146 deletions(-) create mode 100644 tests/data/spacenet/sn3_AOI_3_Paris.tar.gz create mode 100644 tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index e18d187865e..f4e90812e07 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -233,6 +233,7 @@ SpaceNet .. autoclass:: SpaceNet .. autoclass:: SpaceNet1 .. autoclass:: SpaceNet2 +.. autoclass:: SpaceNet3 .. autoclass:: SpaceNet4 .. autoclass:: SpaceNet5 .. autoclass:: SpaceNet7 diff --git a/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..d8cb305dc35915aae2653c0bfd1a231a2b3dbc42 GIT binary patch literal 1176 zcmV;J1ZVpniwFP!000001MQkkZyQAz$7k&*1UEv-fx>}GD{}&+X5ZKL0XYq72%>Ei zqk@E@)y5t>i|k!%c1;zP^w6TZ7v3(MQB@TVNVI$bIKhn%02h!@4*di@A!gQVd%baD zH_mRG=J~D0^YYBjuIJhRjI&Q%TQ-)jtQoSgZW7CV3&nsy(RG#OqOOYJyay5$Nz??D zbwq+FOM;AeHIAE!(sn(Q@H{fz?Q52OaKCZ(6N(%0KdYB7%LZ}Gs%v<}wCu3mSTCR7 zSiN9a&AO~k`mj8)|0P*a6!3o%Oyz%Cxm3=1R_*X>Fb8Xza-jcpHR^vw5k&!bsd=4tt-P8#vFyiUsxI`jV6U(@o# zPRc*`JFff979z(mT;!1DuYOkU|MvDTZ*=o<^Y+)J%)d8&|Fkl82^gn;P2Q(}RnS5INif;|&z{&x_W#=<{j+mM>Hk%Fxl;s+$CxK?hcCV4KzyH7EF^(K|95@8C+%!y|3G zVtTkYV`7@CBOyPyhf^_HqsY literal 0 HcmV?d00001 diff --git a/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..0daea2f5392966c25ce41bac36e2ad1103581498 GIT binary patch literal 1193 zcmV;a1XlYWiwFP!000001MOMQZ{tK1o+MQXp+emQyBt8W!VPt^w#SZ>97?)HR29&z z8U#qFvfN~n8YOmQJFKdudtedS6XL{q4`{`WUH%A;To4xyh*J*;ap1%O0dE|qaeg%E zHi@_0zR#M@uQxOHJb&{!udYS)wOgCIs=vK&IG){V+#s`t2+%Z@=aQyM;k*kXDYBFm zRc=T`iE6S$1a;xm*t85FHMgK>EY5l)ly8dn{ zg!D>i><*NlgZy_$rC_bUA-#BSi6r zxIgwmtn*~c16My{16T6ygYT7xKdg51N&EK4!e5Dh-Y6Dt9v%F&dhJFrJTKwTv%t%7 z5?ct`k~}|JjwK(r`O_}?^)Fw0KKZE4`(N;LLhd)buV4Mf;qw(AzW3qU2OA|Jr+hlJ z{^0$`2d>1xbrPoK|3vbWGw~lj0g(Su0Qn#0|IunUpR{kt|3LrOwk93`llZT~$KC+{ zB}qg6M}aQ?m&fuy*0Gq#e=Ed)o<{MXuNLNiDOB!q{zuZ|{1NGK{%VNx?}s>F4LE=P zRXpZ%^ZZ~2-~T%sH^(0VllY&dL;P2$g#3>Jr_+D_#-2w1zXftvAR1x4D{sj9~CMX z|A_)K@_#4G)R_ITYY*b@3%vL%drY}tz4+EY5@ipZf&d=qzo;r z@04;%S~wIAVb*GADL(Y0RA}8jyjKF|b>&$|hqGm*YHn>lt{w6C27M|0x z%7(|fQ<_$lIS#AnZj-^^MsvS6c&N_}`e37kG!Ux6Jm;X_2k)2`D|rrobg(JgwgcM$ z>-qZ$7nL(gK1(%8;`E{#l~)8w&8U(_rJRz_=V>-iWnp(EWJ_}$TRGRWRkgsD3cGND z?^3&|!_?D$XA!f*l>h&ooBqqF|B+xu{-3%2tEm5x0QLV_)&G1h*U^84s@aaFD=N+V z`macNnTGTqY&-f7wsOzb*Z;s4^&f#iAP@)y0)apv5C{YUfj}S-2m}IwIQRG;o=i0} H08jt`@2PC2 literal 0 HcmV?d00001 diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index 35ce3cecf16..fda7af88717 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -14,7 +14,14 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 +from torchgeo.datasets import ( + SpaceNet1, + SpaceNet2, + SpaceNet3, + SpaceNet4, + SpaceNet5, + SpaceNet7, +) TEST_DATA_DIR = "tests/data/spacenet" @@ -142,6 +149,75 @@ def test_plot(self, dataset: SpaceNet2) -> None: plt.close() +class TestSpaceNet3: + @pytest.fixture( + params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True]) + ) + def dataset( + self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> SpaceNet3: + radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + test_md5 = { + "sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27", + "sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf", + } + + monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5) + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[no-untyped-call] + return SpaceNet3( + root, + image=request.param[0], + speed_mask=request.param[1], + collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"], + transforms=transforms, + download=True, + api_key="", + ) + + def test_getitem(self, dataset: SpaceNet3) -> None: + # Iterate over all elements to maximize coverage + samples = [dataset[i] for i in range(len(dataset))] + x = samples[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + if dataset.image == "PS-RGB": + assert x["image"].shape[0] == 3 + elif dataset.image in ["MS", "PS-MS"]: + assert x["image"].shape[0] == 8 + else: + assert x["image"].shape[0] == 1 + + def test_len(self, dataset: SpaceNet3) -> None: + assert len(dataset) == 4 + + def test_already_downloaded(self, dataset: SpaceNet3) -> None: + SpaceNet3(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + SpaceNet3(str(tmp_path)) + + def test_collection_checksum(self, dataset: SpaceNet3) -> None: + dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123" + with pytest.raises( + RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted" + ): + SpaceNet3(root=dataset.root, download=True, checksum=True) + + def test_plot(self, dataset: SpaceNet3) -> None: + x = dataset[0].copy() + x["prediction"] = x["mask"] + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + dataset.plot({"image": x["image"]}) + plt.close() + + class TestSpaceNet4: @pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"]) def dataset( diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 8d28060b4cf..0bbdb94b01b 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -75,7 +75,15 @@ from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel2 from .so2sat import So2Sat -from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 +from .spacenet import ( + SpaceNet, + SpaceNet1, + SpaceNet2, + SpaceNet3, + SpaceNet4, + SpaceNet5, + SpaceNet7, +) from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -155,6 +163,7 @@ "SpaceNet", "SpaceNet1", "SpaceNet2", + "SpaceNet3", "SpaceNet4", "SpaceNet5", "SpaceNet7", diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 5cf7d9fe5de..84a4ab84bf1 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -248,7 +248,7 @@ def _check_integrity(self) -> List[str]: to_be_downloaded = [] for collection in missing_collections: - archive_path = os.path.join(self.root, collection + ".tar.gz") + archive_path = os.path.join(self.root, f"{collection}.tar.gz") if os.path.exists(archive_path): print(f"Found {collection} archive") if ( @@ -554,6 +554,270 @@ def __init__( ) +class SpaceNet3(SpaceNet): + r"""SpaceNet 3: Road Network Detection. + + `SpaceNet 3 `_ + is a dataset of road networks over the cities of Vegas, Paris, Shanghai. + and Khartoum. + + Collection features: + + +------------+---------------------+------------+---------------------------+ + | AOI | Area (km\ :sup:`2`\)| # Images | # Road Network Labels (km)| + +============+=====================+============+===========================+ + | Vegas | 216 | 1353 | 3685 | + +------------+---------------------+------------+---------------------------+ + | Paris | 1030 | 257 | 425 | + +------------+---------------------+------------+---------------------------+ + | Shanghai | 1000 | 1016 | 3537 | + +------------+---------------------+------------+---------------------------+ + | Khartoum | 765 | 283 | 1030 | + +------------+---------------------+------------+---------------------------+ + + Imagery features: + + .. list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + :stub-columns: 1 + + * - + - PAN + - MS + - PS-MS + - PS-RGB + * - GSD (m) + - 0.31 + - 1.24 + - 0.30 + - 0.30 + * - Chip size (px) + - 1300 x 1300 + - 325 x 325 + - 1300 x 1300 + - 1300 x 1300 + + Dataset format: + + * Imagery - Worldview-3 GeoTIFFs + + * PAN.tif (Panchromatic) + * MS.tif (Multispectral) + * PS-MS (Pansharpened Multispectral) + * PS-RGB (Pansharpened RGB) + + * Labels - GeoJSON + + * labels.geojson + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/1807.01232 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `radiant-mlhub `_ to download the + imagery and labels from the Radiant Earth MLHub + + .. versionadded:: 0.3 + """ + + dataset_id = "spacenet3" + collection_md5_dict = { + "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f", + "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9", + } + + imagery = { + "MS": "MS.tif", + "PAN": "PAN.tif", + "PS-MS": "PS-MS.tif", + "PS-RGB": "PS-RGB.tif", + } + chip_size = { + "MS": (325, 325), + "PAN": (1300, 1300), + "PS-MS": (1300, 1300), + "PS-RGB": (1300, 1300), + } + label_glob = "labels.geojson" + + def __init__( + self, + root: str, + image: str = "PS-RGB", + speed_mask: Optional[bool] = False, + collections: List[str] = [], + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + download: bool = False, + api_key: Optional[str] = None, + checksum: bool = False, + ) -> None: + """Initialize a new SpaceNet 3 Dataset instance. + + Args: + root: root directory where dataset can be found + image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] + speed_mask: use multi-class speed mask (created by binning roads at + 10 mph increments) as label if true, else use binary mask + collections: collection selection which must be a subset of: + [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai, + sn3_AOI_5_Khartoum] + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory. + api_key: a RadiantEarth MLHub API key to use for downloading the dataset + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` but dataset is missing + """ + assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} + self.speed_mask = speed_mask + super().__init__( + root, image, collections, transforms, download, api_key, checksum + ) + + def _load_mask( + self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] + ) -> Tensor: + """Rasterizes the dataset's labels (in geojson format). + + Args: + path: path to the label + tfm: transform of corresponding image + shape: shape of corresponding image + + Returns: + Tensor: label tensor + """ + min_speed_bin = 1 + max_speed_bin = 65 + speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) + bin_size_mph = 10.0 + speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( + [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin] + ) + + try: + with fiona.open(path) as src: + vector_crs = CRS(src.crs) + labels = [] + + for feature in src: + if raster_crs != vector_crs: + geom = transform_geom( + vector_crs.to_string(), + raster_crs.to_string(), + feature["geometry"], + ) + else: + geom = feature["geometry"] + + if self.speed_mask: + val = speed_cls_arr[ + int(feature["properties"]["inferred_speed_mph"]) - 1 + ] + else: + val = 1 + + labels.append((geom, val)) + + except FionaValueError: + labels = [] + + if not labels: + mask_data = np.zeros(shape=shape) + else: + mask_data = rasterize( + labels, + out_shape=shape, + fill=0, # nodata value + transform=tfm, + all_touched=False, + dtype=np.uint8, + ) + + mask = torch.from_numpy(mask_data).long() + return mask + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`SpaceNet.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + # image can be 1 channel or >3 channels + if sample["image"].shape[0] == 1: + image = np.rollaxis(sample["image"].numpy(), 0, 3) + else: + image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) + image = percentile_normalization(image, axis=(0, 1)) + + ncols = 1 + show_mask = "mask" in sample + show_predictions = "prediction" in sample + + if show_mask: + mask = sample["mask"].numpy() + ncols += 1 + + if show_predictions: + prediction = sample["prediction"].numpy() + ncols += 1 + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) + if not isinstance(axs, np.ndarray): + axs = [axs] + axs[0].imshow(image) + axs[0].axis("off") + if show_titles: + axs[0].set_title("Image") + + if show_mask: + if self.speed_mask: + cmap = copy.copy(plt.get_cmap("autumn_r")) + cmap.set_under(color="black") + axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none") + else: + axs[1].imshow(mask, cmap="Greys_r", interpolation="none") + axs[1].axis("off") + if show_titles: + axs[1].set_title("Label") + + if show_predictions: + if self.speed_mask: + cmap = copy.copy(plt.get_cmap("autumn_r")) + cmap.set_under(color="black") + axs[2].imshow( + prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none" + ) + else: + axs[2].imshow(prediction, cmap="Greys_r", interpolation="none") + axs[2].axis("off") + if show_titles: + axs[2].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + + class SpaceNet4(SpaceNet): """SpaceNet 4: Off-Nadir Buildings Dataset. @@ -699,7 +963,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: lbl_dir = os.path.dirname(imgpath).split("-nadir")[0] - lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob) + lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob) assert os.path.exists(lbl_path) _file = {"image_path": imgpath, "label_path": lbl_path} @@ -724,7 +988,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: return files -class SpaceNet5(SpaceNet): +class SpaceNet5(SpaceNet3): r"""SpaceNet 5: Automated Road Network Extraction and Route Travel Time Estimation. `SpaceNet 5 `_ @@ -842,148 +1106,17 @@ def __init__( Raises: RuntimeError: if ``download=False`` but dataset is missing """ - assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} - self.speed_mask = speed_mask super().__init__( - root, image, collections, transforms, download, api_key, checksum - ) - - def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] - ) -> Tensor: - """Rasterizes the dataset's labels (in geojson format). - - Args: - path: path to the label - tfm: transform of corresponding image - shape: shape of corresponding image - - Returns: - Tensor: label tensor - """ - min_speed_bin = 1 - max_speed_bin = 65 - speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) - bin_size_mph = 10.0 - speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( - [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin] + root, + image, + speed_mask, + collections, + transforms, + download, + api_key, + checksum, ) - try: - with fiona.open(path) as src: - vector_crs = CRS(src.crs) - labels = [] - - for feature in src: - if raster_crs != vector_crs: - geom = transform_geom( - vector_crs.to_string(), - raster_crs.to_string(), - feature["geometry"], - ) - else: - geom = feature["geometry"] - - if self.speed_mask: - val = speed_cls_arr[ - int(feature["properties"]["inferred_speed_mph"]) - 1 - ] - else: - val = 1 - - labels.append((geom, val)) - - except FionaValueError: - labels = [] - - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( - labels, - out_shape=shape, - fill=0, # nodata value - transform=tfm, - all_touched=False, - dtype=np.uint8, - ) - - mask = torch.from_numpy(mask_data).long() - return mask - - def plot( - self, - sample: Dict[str, Tensor], - show_titles: bool = True, - suptitle: Optional[str] = None, - ) -> Figure: - """Plot a sample from the dataset. - - Args: - sample: a sample returned by :meth:`SpaceNet.__getitem__` - show_titles: flag indicating whether to show titles above each panel - suptitle: optional string to use as a suptitle - - Returns: - a matplotlib Figure with the rendered sample - - .. versionadded:: 0.2 - """ - # image can be 1 channel or >3 channels - if sample["image"].shape[0] == 1: - image = np.rollaxis(sample["image"].numpy(), 0, 3) - else: - image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) - image = percentile_normalization(image, axis=(0, 1)) - - ncols = 1 - show_mask = "mask" in sample - show_predictions = "prediction" in sample - - if show_mask: - mask = sample["mask"].numpy() - ncols += 1 - - if show_predictions: - prediction = sample["prediction"].numpy() - ncols += 1 - - fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) - if not isinstance(axs, np.ndarray): - axs = [axs] - axs[0].imshow(image) - axs[0].axis("off") - if show_titles: - axs[0].set_title("Image") - - if show_mask: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") - axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none") - else: - axs[1].imshow(mask, cmap="Greys_r", interpolation="none") - axs[1].axis("off") - if show_titles: - axs[1].set_title("Label") - - if show_predictions: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") - axs[2].imshow( - prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none" - ) - else: - axs[2].imshow(prediction, cmap="Greys_r", interpolation="none") - axs[2].axis("off") - if show_titles: - axs[2].set_title("Prediction") - - if suptitle is not None: - plt.suptitle(suptitle) - return fig - class SpaceNet7(SpaceNet): """SpaceNet 7: Multi-Temporal Urban Development Challenge. @@ -1126,13 +1259,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: Returns: data at that index """ - sample = {} files = self.files[index] img, tfm, raster_crs = self._load_image(files["image_path"]) h, w = img.shape[1:] ch, cw = self.chip_size["img"] - sample["image"] = img[:, :ch, :cw] + sample = {"image": img[:, :ch, :cw]} if self.split == "train": mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w)) sample["mask"] = mask[:ch, :cw] From 41a9ad6a42d1962416b98459fcc0bb0f1b98056e Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 24 Mar 2022 12:25:18 +0400 Subject: [PATCH 2/5] Fixes --- torchgeo/datasets/spacenet.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 84a4ab84bf1..6be78ca11dc 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -281,7 +281,7 @@ def _download(self, collections: List[str], api_key: Optional[str] = None) -> No """ for collection in collections: download_radiant_mlhub_collection(collection, self.root, api_key) - archive_path = os.path.join(self.root, collection + ".tar.gz") + archive_path = os.path.join(self.root, f"{collection}.tar.gz") if ( not self.checksum or not check_integrity( @@ -558,7 +558,7 @@ class SpaceNet3(SpaceNet): r"""SpaceNet 3: Road Network Detection. `SpaceNet 3 `_ - is a dataset of road networks over the cities of Vegas, Paris, Shanghai. + is a dataset of road networks over the cities of Las Vegas, Paris, Shanghai, and Khartoum. Collection features: @@ -566,11 +566,11 @@ class SpaceNet3(SpaceNet): +------------+---------------------+------------+---------------------------+ | AOI | Area (km\ :sup:`2`\)| # Images | # Road Network Labels (km)| +============+=====================+============+===========================+ - | Vegas | 216 | 1353 | 3685 | + | Vegas | 216 | 854 | 3685 | +------------+---------------------+------------+---------------------------+ | Paris | 1030 | 257 | 425 | +------------+---------------------+------------+---------------------------+ - | Shanghai | 1000 | 1016 | 3537 | + | Shanghai | 1000 | 1028 | 3537 | +------------+---------------------+------------+---------------------------+ | Khartoum | 765 | 283 | 1030 | +------------+---------------------+------------+---------------------------+ @@ -627,7 +627,9 @@ class SpaceNet3(SpaceNet): dataset_id = "spacenet3" collection_md5_dict = { + "sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8", "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f", + "sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c", "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9", } From a85ed5d80323661b6f0b4aac169e3182e629cb02 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 25 Mar 2022 12:21:26 +0400 Subject: [PATCH 3/5] Replace itertools.product with zip --- tests/datasets/test_spacenet.py | 9 ++------- torchgeo/datasets/spacenet.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index fda7af88717..c9398ae6e80 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import glob -import itertools import os import shutil from pathlib import Path @@ -150,9 +149,7 @@ def test_plot(self, dataset: SpaceNet2) -> None: class TestSpaceNet3: - @pytest.fixture( - params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True]) - ) + @pytest.fixture(params=zip(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet3: @@ -282,9 +279,7 @@ def test_plot(self, dataset: SpaceNet4) -> None: class TestSpaceNet5: - @pytest.fixture( - params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True]) - ) + @pytest.fixture(params=zip(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet5: diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 6be78ca11dc..07d9f494d65 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -133,7 +133,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: images = sorted(images) for imgpath in images: lbl_path = os.path.join( - os.path.dirname(imgpath) + "-labels", self.label_glob + f"{os.path.dirname(imgpath)}-labels", self.label_glob ) files.append({"image_path": imgpath, "label_path": lbl_path}) return files @@ -183,10 +183,8 @@ def _load_mask( except FionaValueError: labels = [] - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( + mask_data = ( + rasterize( labels, out_shape=shape, fill=0, # nodata value @@ -194,6 +192,9 @@ def _load_mask( all_touched=False, dtype=np.uint8, ) + if labels + else np.zeros(shape=shape) + ) mask = torch.from_numpy(mask_data).long() @@ -731,10 +732,8 @@ def _load_mask( except FionaValueError: labels = [] - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( + mask_data = ( + rasterize( labels, out_shape=shape, fill=0, # nodata value @@ -742,6 +741,9 @@ def _load_mask( all_touched=False, dtype=np.uint8, ) + if labels + else np.zeros(shape=shape) + ) mask = torch.from_numpy(mask_data).long() return mask @@ -762,7 +764,6 @@ def plot( Returns: a matplotlib Figure with the rendered sample - .. versionadded:: 0.2 """ # image can be 1 channel or >3 channels if sample["image"].shape[0] == 1: From 9af7d622628da6d141bf20303e772b37ecf5cda5 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 25 Mar 2022 22:51:11 +0400 Subject: [PATCH 4/5] Update docstring --- torchgeo/datasets/spacenet.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 07d9f494d65..7746e3e09fd 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -183,8 +183,10 @@ def _load_mask( except FionaValueError: labels = [] - mask_data = ( - rasterize( + if not labels: + mask_data = np.zeros(shape=shape) + else: + mask_data = rasterize( labels, out_shape=shape, fill=0, # nodata value @@ -192,9 +194,6 @@ def _load_mask( all_touched=False, dtype=np.uint8, ) - if labels - else np.zeros(shape=shape) - ) mask = torch.from_numpy(mask_data).long() @@ -539,7 +538,8 @@ def __init__( image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] collections: collection selection which must be a subset of: [sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai, - sn2_AOI_5_Khartoum] + sn2_AOI_5_Khartoum]. If unspecified, all collections will be + used. transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory. @@ -668,7 +668,8 @@ def __init__( 10 mph increments) as label if true, else use binary mask collections: collection selection which must be a subset of: [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai, - sn3_AOI_5_Khartoum] + sn3_AOI_5_Khartoum]. If unspecified, all collections will be + used. transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory. @@ -732,8 +733,10 @@ def _load_mask( except FionaValueError: labels = [] - mask_data = ( - rasterize( + if not labels: + mask_data = np.zeros(shape=shape) + else: + mask_data = rasterize( labels, out_shape=shape, fill=0, # nodata value @@ -741,9 +744,6 @@ def _load_mask( all_touched=False, dtype=np.uint8, ) - if labels - else np.zeros(shape=shape) - ) mask = torch.from_numpy(mask_data).long() return mask @@ -1099,7 +1099,8 @@ def __init__( speed_mask: use multi-class speed mask (created by binning roads at 10 mph increments) as label if true, else use binary mask collections: collection selection which must be a subset of: - [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai] + [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all + collections will be used. transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory. From a3f0f0271e9a1621406ba78bbd2a830806f4ac53 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 26 Mar 2022 00:19:14 +0400 Subject: [PATCH 5/5] Remove unused options --- tests/datasets/test_spacenet.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index c9398ae6e80..7e4cf90af47 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -149,7 +149,7 @@ def test_plot(self, dataset: SpaceNet2) -> None: class TestSpaceNet3: - @pytest.fixture(params=zip(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])) + @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet3: @@ -180,9 +180,7 @@ def test_getitem(self, dataset: SpaceNet3) -> None: assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "PS-RGB": - assert x["image"].shape[0] == 3 - elif dataset.image in ["MS", "PS-MS"]: + if dataset.image == "MS": assert x["image"].shape[0] == 8 else: assert x["image"].shape[0] == 1 @@ -279,7 +277,7 @@ def test_plot(self, dataset: SpaceNet4) -> None: class TestSpaceNet5: - @pytest.fixture(params=zip(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])) + @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet5: @@ -310,9 +308,7 @@ def test_getitem(self, dataset: SpaceNet5) -> None: assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "PS-RGB": - assert x["image"].shape[0] == 3 - elif dataset.image in ["MS", "PS-MS"]: + if dataset.image == "MS": assert x["image"].shape[0] == 8 else: assert x["image"].shape[0] == 1