From 85810c0310f6e79aa36c7e3ea07e639f5d193c8f Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 24 Oct 2022 12:18:37 -0300 Subject: [PATCH 01/49] add extent_crop to BoundingBox --- torchgeo/datasets/utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 1e7f43bbf87..3eb6725e847 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -396,6 +396,32 @@ def intersects(self, other: "BoundingBox") -> bool: and self.maxt >= other.mint ) + def extent_crop( + self, + skip_bottom: float = 0.0, + skip_left: float = 0.0, + skip_top: float = 0.0, + skip_right: float = 0.0, + ) -> "BoundingBox": + """Crop BoundingBox by skipping a proportion from its sides. + + Args: + skip_bottom: proportion to skip from the bottom + skip_left: proportion to skip from the left + skip_top: proportion to skip from the top + skip_right: proportion to skip from the right + + Returns: + The cropped BoundingBox + """ + h = self.maxy - self.miny + w = self.maxx - self.minx + + miny, minx = self.miny + int(h * skip_bottom), self.minx + int(w * skip_left) + maxy, maxx = self.maxy - int(h * skip_top), self.maxx - int(w * skip_right) + + return BoundingBox(minx, maxx, miny, maxy, self.mint, self.maxt) + def disambiguate_timestamp(date_str: str, format: str) -> Tuple[float, float]: """Disambiguate partial timestamps. From 45f963bfd12b2d198bbc86f3c82ed21bd0576ede Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 24 Oct 2022 12:21:28 -0300 Subject: [PATCH 02/49] add extent_crop param to RasterDataset --- torchgeo/datasets/geo.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f21c7901a97..6b61b02a4fe 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -309,6 +309,7 @@ def __init__( bands: Optional[Sequence[str]] = None, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, cache: bool = True, + extent_crop: Optional[Tuple[float, float, float, float]] = None, ) -> None: """Initialize a new Dataset instance. @@ -322,6 +323,8 @@ def __init__( transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling + extent_crop: (skip_bottom, skip_left, skip_top, skip_right) crop + underlying raster by skipping a proportion of it from its edges Raises: FileNotFoundError: if no files are found in ``root`` @@ -364,7 +367,11 @@ def __init__( date = match.group("date") mint, maxt = disambiguate_timestamp(date, self.date_format) - coords = (minx, maxx, miny, maxy, mint, maxt) + if extent_crop: + bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt) + coords = tuple(bbox.extent_crop(*extent_crop)) + else: + coords = (minx, maxx, miny, maxy, mint, maxt) self.index.insert(i, coords, filepath) i += 1 From cf8e824af843be5c84ec30f66521fadecf9c6ee9 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 27 Oct 2022 12:59:43 -0300 Subject: [PATCH 03/49] train_test_split function --- torchgeo/datasets/utils.py | 83 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 3eb6725e847..4acc6e76686 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -11,6 +11,7 @@ import os import sys import tarfile +from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( @@ -30,10 +31,13 @@ import numpy as np import rasterio import torch +from rtree.index import Index, Property from torch import Tensor from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks +from .geo import GeoDataset + __all__ = ( "check_integrity", "download_url", @@ -422,6 +426,40 @@ def extent_crop( return BoundingBox(minx, maxx, miny, maxy, self.mint, self.maxt) + def split( + self, proportion: float, horizontal: bool = True + ) -> Tuple["BoundingBox", "BoundingBox"]: + """Split BoundingBox in two. + + Args: + proportion: split proportion + horizontal: whether the split is horizontal (True) or + vertical + + Returns: + A tuple with the resulting BoundingBoxes + """ + if horizontal: + w = self.maxx - self.minx + splitx = self.minx + int(w * proportion) + bbox1 = BoundingBox( + self.minx, splitx, self.miny, self.maxy, self.mint, self.maxt + ) + bbox2 = BoundingBox( + splitx, self.maxx, self.miny, self.maxy, self.mint, self.maxt + ) + else: + h = self.maxy - self.miny + splity = self.miny + int(h * proportion) + bbox1 = BoundingBox( + self.minx, self.maxx, self.miny, splity, self.mint, self.maxt + ) + bbox2 = BoundingBox( + self.minx, self.maxx, splity, self.maxy, self.mint, self.maxt + ) + + return bbox1, bbox2 + def disambiguate_timestamp(date_str: str, format: str) -> Tuple[float, float]: """Disambiguate partial timestamps. @@ -733,3 +771,48 @@ def percentile_normalization( (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1 ) return img_normalized + + +def train_test_split( + dataset: GeoDataset, test_size: float = 0.25, random_seed: Optional[int] = None +) -> Tuple[GeoDataset, GeoDataset]: + """Splits a dataset into train and test. + + This function will go through each BoundingBox saved in the index and split it + in a random direction by the proportion specified in test_size. + + Args: + dataset: GeoDataset to split + test_size: proportion of GeoDataset to use for test, in range [0,1] + random_seed: random seed for reproducibility + + Returns + normalized version of ``img`` + + .. versionadded:: 0.4 + """ + assert 0 < test_size < 1, "test_size must be between 0 and 1" + + np.random.seed(random_seed) + + index_train = Index(interleaved=False, properties=Property(dimension=3)) + index_test = Index(interleaved=False, properties=Property(dimension=3)) + + for i, hit in enumerate( + dataset.index.intersection(dataset.index.bounds, objects=True) + ): + box = BoundingBox(*hit.bounds) + horizontal, flip = np.random.randint(2, size=2) + if flip: + box_train, box_test = box.split(1 - test_size, horizontal) + else: + box_test, box_train = box.split(test_size, horizontal) + index_train.insert(i, tuple(box_train)) + index_test.insert(i, tuple(box_test)) + + dataset_train = deepcopy(dataset) + dataset_train.index = index_train + dataset_test = deepcopy(dataset) + dataset_test.index = index_test + + return dataset_train, dataset_test From a50c98321d5f2e80d502323524005de87c7d24e7 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 2 Nov 2022 11:18:58 -0300 Subject: [PATCH 04/49] minor changes --- torchgeo/datasets/utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 4acc6e76686..1a253f33583 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -55,6 +55,7 @@ "draw_semantic_segmentation_masks", "rgb_to_mask", "percentile_normalization", + "train_test_split", ) @@ -432,13 +433,16 @@ def split( """Split BoundingBox in two. Args: - proportion: split proportion - horizontal: whether the split is horizontal (True) or - vertical + proportion: split proportion in range [0,1] + horizontal: whether the split is horizontal or vertical Returns: A tuple with the resulting BoundingBoxes + + .. versionadded:: 0.4 """ + assert 0 < proportion < 1, "test_size must be between 0 and 1" + if horizontal: w = self.maxx - self.minx splitx = self.minx + int(w * proportion) @@ -778,8 +782,8 @@ def train_test_split( ) -> Tuple[GeoDataset, GeoDataset]: """Splits a dataset into train and test. - This function will go through each BoundingBox saved in the index and split it - in a random direction by the proportion specified in test_size. + This function will go through each BoundingBox saved in the GeoDataset's index and + split it in a random direction by the proportion specified in test_size. Args: dataset: GeoDataset to split @@ -787,13 +791,14 @@ def train_test_split( random_seed: random seed for reproducibility Returns - normalized version of ``img`` + A tuple with the resulting GeoDatasets in order (train, test) .. versionadded:: 0.4 """ assert 0 < test_size < 1, "test_size must be between 0 and 1" - np.random.seed(random_seed) + if random_seed: + np.random.seed(random_seed) index_train = Index(interleaved=False, properties=Property(dimension=3)) index_test = Index(interleaved=False, properties=Property(dimension=3)) From 85d5718077624d16b501493d5210636deb2e6023 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 2 Nov 2022 13:01:59 -0300 Subject: [PATCH 05/49] fix circular import --- torchgeo/datasets/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 1a253f33583..a4e48d19959 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, @@ -36,7 +37,8 @@ from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks -from .geo import GeoDataset +if TYPE_CHECKING: + from .geo import GeoDataset __all__ = ( "check_integrity", @@ -778,8 +780,8 @@ def percentile_normalization( def train_test_split( - dataset: GeoDataset, test_size: float = 0.25, random_seed: Optional[int] = None -) -> Tuple[GeoDataset, GeoDataset]: + dataset: "GeoDataset", test_size: float = 0.25, random_seed: Optional[int] = None +) -> Tuple["GeoDataset", "GeoDataset"]: """Splits a dataset into train and test. This function will go through each BoundingBox saved in the GeoDataset's index and From d343c3bf3fb2543c4928890a39cc0733e4d1cb22 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 16:19:27 -0300 Subject: [PATCH 06/49] remove extent_crop --- torchgeo/datasets/geo.py | 9 +-------- torchgeo/datasets/utils.py | 26 -------------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 13c503d67ff..60bf2f2b8f0 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -309,7 +309,6 @@ def __init__( bands: Optional[Sequence[str]] = None, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, cache: bool = True, - extent_crop: Optional[Tuple[float, float, float, float]] = None, ) -> None: """Initialize a new Dataset instance. @@ -323,8 +322,6 @@ def __init__( transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling - extent_crop: (skip_bottom, skip_left, skip_top, skip_right) crop - underlying raster by skipping a proportion of it from its edges Raises: FileNotFoundError: if no files are found in ``root`` @@ -367,11 +364,7 @@ def __init__( date = match.group("date") mint, maxt = disambiguate_timestamp(date, self.date_format) - if extent_crop: - bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt) - coords = tuple(bbox.extent_crop(*extent_crop)) - else: - coords = (minx, maxx, miny, maxy, mint, maxt) + coords = (minx, maxx, miny, maxy, mint, maxt) self.index.insert(i, coords, filepath) i += 1 diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index a4e48d19959..91cd0de44eb 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -403,32 +403,6 @@ def intersects(self, other: "BoundingBox") -> bool: and self.maxt >= other.mint ) - def extent_crop( - self, - skip_bottom: float = 0.0, - skip_left: float = 0.0, - skip_top: float = 0.0, - skip_right: float = 0.0, - ) -> "BoundingBox": - """Crop BoundingBox by skipping a proportion from its sides. - - Args: - skip_bottom: proportion to skip from the bottom - skip_left: proportion to skip from the left - skip_top: proportion to skip from the top - skip_right: proportion to skip from the right - - Returns: - The cropped BoundingBox - """ - h = self.maxy - self.miny - w = self.maxx - self.minx - - miny, minx = self.miny + int(h * skip_bottom), self.minx + int(w * skip_left) - maxy, maxx = self.maxy - int(h * skip_top), self.maxx - int(w * skip_right) - - return BoundingBox(minx, maxx, miny, maxy, self.mint, self.maxt) - def split( self, proportion: float, horizontal: bool = True ) -> Tuple["BoundingBox", "BoundingBox"]: From bd65c85367a600b967ef11fc81379103f0d9e5f9 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 16:26:59 -0300 Subject: [PATCH 07/49] move existing functions to new file --- torchgeo/datasets/splits.py | 88 +++++++++++++++++++++++++++++++++++++ torchgeo/datasets/utils.py | 53 ---------------------- 2 files changed, 88 insertions(+), 53 deletions(-) create mode 100644 torchgeo/datasets/splits.py diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py new file mode 100644 index 00000000000..171455ada5f --- /dev/null +++ b/torchgeo/datasets/splits.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Dataset splitting utilities.""" + +from copy import deepcopy +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +from rtree.index import Index, Property +from torch.utils.data import Subset, TensorDataset, random_split + +from ..datasets import GeoDataset, NonGeoDataset +from .utils import BoundingBox + + +def random_nongeo_split( + dataset: Union[TensorDataset, NonGeoDataset], + val_pct: float, + test_pct: Optional[float] = None, +) -> List[Subset[Any]]: + """Split a torch Dataset into train/val/test sets. + + If ``test_pct`` is not set then only train and validation splits are returned. + + Args: + dataset: dataset to be split into train/val or train/val/test subsets + val_pct: percentage of samples to be in validation set + test_pct: (Optional) percentage of samples to be in test set + + Returns: + a list of the subset datasets. Either [train, val] or [train, val, test] + """ + if test_pct is None: + val_length = round(len(dataset) * val_pct) + train_length = len(dataset) - val_length + return random_split(dataset, [train_length, val_length]) + else: + val_length = round(len(dataset) * val_pct) + test_length = round(len(dataset) * test_pct) + train_length = len(dataset) - (val_length + test_length) + return random_split(dataset, [train_length, val_length, test_length]) + + +def random_bbox_splitting( + dataset: GeoDataset, test_size: float = 0.25, random_seed: Optional[int] = None +) -> Tuple[GeoDataset, GeoDataset]: + """Splits a dataset into train and test. + + This function will go through each BoundingBox saved in the GeoDataset's index and + split it in a random direction by the proportion specified in test_size. + + Args: + dataset: GeoDataset to split + test_size: proportion of GeoDataset to use for test, in range [0,1] + random_seed: random seed for reproducibility + + Returns + A tuple with the resulting GeoDatasets in order (train, test) + + .. versionadded:: 0.4 + """ + assert 0 < test_size < 1, "test_size must be between 0 and 1" + + if random_seed: + np.random.seed(random_seed) + + index_train = Index(interleaved=False, properties=Property(dimension=3)) + index_test = Index(interleaved=False, properties=Property(dimension=3)) + + for i, hit in enumerate( + dataset.index.intersection(dataset.index.bounds, objects=True) + ): + box = BoundingBox(*hit.bounds) + horizontal, flip = np.random.randint(2, size=2) + if flip: + box_train, box_test = box.split(1 - test_size, horizontal) + else: + box_test, box_train = box.split(test_size, horizontal) + index_train.insert(i, tuple(box_train)) + index_test.insert(i, tuple(box_test)) + + dataset_train = deepcopy(dataset) + dataset_train.index = index_train + dataset_test = deepcopy(dataset) + dataset_test.index = index_test + + return dataset_train, dataset_test diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 91cd0de44eb..102ff7bbd24 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -11,11 +11,9 @@ import os import sys import tarfile -from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, @@ -32,14 +30,10 @@ import numpy as np import rasterio import torch -from rtree.index import Index, Property from torch import Tensor from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks -if TYPE_CHECKING: - from .geo import GeoDataset - __all__ = ( "check_integrity", "download_url", @@ -57,7 +51,6 @@ "draw_semantic_segmentation_masks", "rgb_to_mask", "percentile_normalization", - "train_test_split", ) @@ -751,49 +744,3 @@ def percentile_normalization( (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1 ) return img_normalized - - -def train_test_split( - dataset: "GeoDataset", test_size: float = 0.25, random_seed: Optional[int] = None -) -> Tuple["GeoDataset", "GeoDataset"]: - """Splits a dataset into train and test. - - This function will go through each BoundingBox saved in the GeoDataset's index and - split it in a random direction by the proportion specified in test_size. - - Args: - dataset: GeoDataset to split - test_size: proportion of GeoDataset to use for test, in range [0,1] - random_seed: random seed for reproducibility - - Returns - A tuple with the resulting GeoDatasets in order (train, test) - - .. versionadded:: 0.4 - """ - assert 0 < test_size < 1, "test_size must be between 0 and 1" - - if random_seed: - np.random.seed(random_seed) - - index_train = Index(interleaved=False, properties=Property(dimension=3)) - index_test = Index(interleaved=False, properties=Property(dimension=3)) - - for i, hit in enumerate( - dataset.index.intersection(dataset.index.bounds, objects=True) - ): - box = BoundingBox(*hit.bounds) - horizontal, flip = np.random.randint(2, size=2) - if flip: - box_train, box_test = box.split(1 - test_size, horizontal) - else: - box_test, box_train = box.split(test_size, horizontal) - index_train.insert(i, tuple(box_train)) - index_test.insert(i, tuple(box_test)) - - dataset_train = deepcopy(dataset) - dataset_train.index = index_train - dataset_test = deepcopy(dataset) - dataset_test.index = index_test - - return dataset_train, dataset_test From 6f694e8a21886db17fefbf1294be0cffde95a81c Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 16:39:42 -0300 Subject: [PATCH 08/49] refactor random_nongeo_split --- torchgeo/datasets/splits.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 171455ada5f..ac03d1c6d93 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,10 +4,11 @@ """Dataset splitting utilities.""" from copy import deepcopy -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np from rtree.index import Index, Property +from torch import Generator, default_generator from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import GeoDataset, NonGeoDataset @@ -16,30 +17,20 @@ def random_nongeo_split( dataset: Union[TensorDataset, NonGeoDataset], - val_pct: float, - test_pct: Optional[float] = None, + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, ) -> List[Subset[Any]]: - """Split a torch Dataset into train/val/test sets. - - If ``test_pct`` is not set then only train and validation splits are returned. + """Randomly split a NonGeoDataset into non-overlapping new NonGeoDatasets. Args: - dataset: dataset to be split into train/val or train/val/test subsets - val_pct: percentage of samples to be in validation set - test_pct: (Optional) percentage of samples to be in test set + dataset: dataset to be split + lengths: lengths or fractions of splits to be produced + generator: (optional) generator used for the random permutation Returns: - a list of the subset datasets. Either [train, val] or [train, val, test] + A list of the subset datasets. """ - if test_pct is None: - val_length = round(len(dataset) * val_pct) - train_length = len(dataset) - val_length - return random_split(dataset, [train_length, val_length]) - else: - val_length = round(len(dataset) * val_pct) - test_length = round(len(dataset) * test_pct) - train_length = len(dataset) - (val_length + test_length) - return random_split(dataset, [train_length, val_length, test_length]) + return random_split(dataset, lengths, generator) def random_bbox_splitting( From 9a8094352068c3d11e4cc95c1eef02ea9157fe3c Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 17:34:43 -0300 Subject: [PATCH 09/49] refactor random_bbox_splitting --- torchgeo/datasets/splits.py | 72 ++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index ac03d1c6d93..bec874e1241 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,11 +4,10 @@ """Dataset splitting utilities.""" from copy import deepcopy -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Union -import numpy as np from rtree.index import Index, Property -from torch import Generator, default_generator +from torch import Generator, default_generator, randint from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import GeoDataset, NonGeoDataset @@ -29,51 +28,60 @@ def random_nongeo_split( Returns: A list of the subset datasets. + + .. versionadded:: 0.4 """ return random_split(dataset, lengths, generator) def random_bbox_splitting( - dataset: GeoDataset, test_size: float = 0.25, random_seed: Optional[int] = None -) -> Tuple[GeoDataset, GeoDataset]: - """Splits a dataset into train and test. + dataset: GeoDataset, + fractions: Sequence[float], + generator: Optional[Generator] = default_generator, +) -> List[GeoDataset]: + """Randomly split a GeoDataset by splitting its index's BoundingBoxes. - This function will go through each BoundingBox saved in the GeoDataset's index and - split it in a random direction by the proportion specified in test_size. + This function will go through each BoundingBox in the GeoDataset's index and + split it in a random direction. Args: - dataset: GeoDataset to split - test_size: proportion of GeoDataset to use for test, in range [0,1] - random_seed: random seed for reproducibility + dataset: dataset to be split + fractions: fractions of splits to be produced + generator: (optional) generator used for the random permutation Returns - A tuple with the resulting GeoDatasets in order (train, test) + A list of the subset datasets. .. versionadded:: 0.4 """ - assert 0 < test_size < 1, "test_size must be between 0 and 1" + assert sum(fractions) == 1, "fractions must add up to 1" - if random_seed: - np.random.seed(random_seed) - - index_train = Index(interleaved=False, properties=Property(dimension=3)) - index_test = Index(interleaved=False, properties=Property(dimension=3)) + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions + ] for i, hit in enumerate( dataset.index.intersection(dataset.index.bounds, objects=True) ): box = BoundingBox(*hit.bounds) - horizontal, flip = np.random.randint(2, size=2) - if flip: - box_train, box_test = box.split(1 - test_size, horizontal) - else: - box_test, box_train = box.split(test_size, horizontal) - index_train.insert(i, tuple(box_train)) - index_test.insert(i, tuple(box_test)) - - dataset_train = deepcopy(dataset) - dataset_train.index = index_train - dataset_test = deepcopy(dataset) - dataset_test.index = index_test - - return dataset_train, dataset_test + fraction_left = 1.0 + + for j, frac in enumerate(fractions): + horizontal, flip = randint(0, 2, (2,), generator=generator) + + if fraction_left == frac: + new_box = box + elif flip: + box, new_box = box.split((1 - frac) / fraction_left, horizontal) + else: + new_box, box = box.split(frac / fraction_left, horizontal) + + new_indexes[j].insert(i, tuple(new_box)) + fraction_left -= frac + + def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: + new_dataset = deepcopy(dataset) + new_dataset.index = index + return new_dataset + + return [new_geodataset_like(dataset, index) for index in new_indexes] From 8c54b844980205aa3cd61268b0fb6be9dce883a8 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 17:55:56 -0300 Subject: [PATCH 10/49] add roi_split --- torchgeo/datasets/splits.py | 44 +++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index bec874e1241..a45edc629cf 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -34,6 +34,23 @@ def random_nongeo_split( return random_split(dataset, lengths, generator) +def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: + """Create a new GeoDataset like an existing one and change its index. + + Args: + dataset: dataset to copy + index: new index + + Returns: + A new GeoDataset. + + .. versionadded:: 0.4 + """ + new_dataset = deepcopy(dataset) + new_dataset.index = index + return new_dataset + + def random_bbox_splitting( dataset: GeoDataset, fractions: Sequence[float], @@ -79,9 +96,28 @@ def random_bbox_splitting( new_indexes[j].insert(i, tuple(new_box)) fraction_left -= frac - def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: - new_dataset = deepcopy(dataset) - new_dataset.index = index - return new_dataset + return [new_geodataset_like(dataset, index) for index in new_indexes] + + +def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: + """Split a GeoDataset by intersecting it with a ROI for each desired new GeoDataset. + + Args: + dataset: dataset to be split + rois: regions of interest of splits to be produced + + Returns + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in rois + ] + + for i, roi in enumerate(rois): + for j, hit in enumerate(dataset.index.intersection(tuple(roi), objects=True)): + box = BoundingBox(*hit.bounds) + new_indexes[i].insert(j, tuple(box & roi)) return [new_geodataset_like(dataset, index) for index in new_indexes] From 6698745c7f97c2ad34301e3d8cf74cb1159c4ad0 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 18:55:12 -0300 Subject: [PATCH 11/49] add random_bbox_assignment --- torchgeo/datasets/splits.py | 39 +++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index a45edc629cf..339ab33aa84 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -7,7 +7,7 @@ from typing import Any, List, Optional, Sequence, Union from rtree.index import Index, Property -from torch import Generator, default_generator, randint +from torch import Generator, default_generator, randint, randperm from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import GeoDataset, NonGeoDataset @@ -35,7 +35,7 @@ def random_nongeo_split( def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: - """Create a new GeoDataset like an existing one and change its index. + """Utility to create a new GeoDataset from an existing one with a different index. Args: dataset: dataset to copy @@ -56,7 +56,7 @@ def random_bbox_splitting( fractions: Sequence[float], generator: Optional[Generator] = default_generator, ) -> List[GeoDataset]: - """Randomly split a GeoDataset by splitting its index's BoundingBoxes. + """Split a GeoDataset randomly splitting its index's BoundingBoxes. This function will go through each BoundingBox in the GeoDataset's index and split it in a random direction. @@ -99,8 +99,39 @@ def random_bbox_splitting( return [new_geodataset_like(dataset, index) for index in new_indexes] +def random_bbox_assignment( + dataset: GeoDataset, + lengths: Sequence[int], + generator: Optional[Generator] = default_generator, +) -> List[GeoDataset]: + """Split a GeoDataset randomly assigning its index's BoundingBoxes. + + Args: + dataset: dataset to be split + lengths: lengths of splits to be produced + generator: (optional) generator used for the random permutation + + Returns + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) + hits = [hits[i] for i in randperm(sum(lengths), generator=generator)] + + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths + ] + + for i, length in enumerate(lengths): + for j in range(length): + new_indexes[i].insert(j, hits.pop().bounds) + + return [new_geodataset_like(dataset, index) for index in new_indexes] + + def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: - """Split a GeoDataset by intersecting it with a ROI for each desired new GeoDataset. + """Split a GeoDataset intersecting it with a ROI for each desired new GeoDataset. Args: dataset: dataset to be split From 4de89b4b994eebdf9ab0c00cf87144d54448a88a Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 20:49:54 -0300 Subject: [PATCH 12/49] add input checks --- torchgeo/datasets/splits.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 339ab33aa84..ec094a61661 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,6 +4,7 @@ """Dataset splitting utilities.""" from copy import deepcopy +from math import floor from typing import Any, List, Optional, Sequence, Union from rtree.index import Index, Property @@ -71,7 +72,11 @@ def random_bbox_splitting( .. versionadded:: 0.4 """ - assert sum(fractions) == 1, "fractions must add up to 1" + if sum(fractions) != 1: + raise ValueError("Sum of input fractions must equal 1.") + + if any(n <= 0 for n in fractions): + raise ValueError("All items in input fractions must be greater than 0.") new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions @@ -108,7 +113,7 @@ def random_bbox_assignment( Args: dataset: dataset to be split - lengths: lengths of splits to be produced + lengths: lengths or fractions of splits to be produced generator: (optional) generator used for the random permutation Returns @@ -117,6 +122,22 @@ def random_bbox_assignment( .. versionadded:: 0.4 """ hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) + if sum(lengths) != 1 or sum(lengths) != len(hits): + raise ValueError( + "Sum of input lengths must equal 1 or be the length of the dataset's index." + ) + + if any(n <= 0 for n in lengths): + raise ValueError("All items in input lengths must be greater than 0.") + + if sum(lengths) == 1: + lengths = [floor(frac * len(hits)) for frac in lengths] + remainder = len(hits) - sum(lengths) + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(lengths) + lengths[idx_to_add_at] += 1 + hits = [hits[i] for i in randperm(sum(lengths), generator=generator)] new_indexes = [ From 38a8b1b26789646b7bd415b9aa459210e863bfc4 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 20 Dec 2022 21:00:53 -0300 Subject: [PATCH 13/49] fix input type --- torchgeo/datasets/splits.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index ec094a61661..c924ddb08a6 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -106,7 +106,7 @@ def random_bbox_splitting( def random_bbox_assignment( dataset: GeoDataset, - lengths: Sequence[int], + lengths: Sequence[Union[int, float]], generator: Optional[Generator] = default_generator, ) -> List[GeoDataset]: """Split a GeoDataset randomly assigning its index's BoundingBoxes. @@ -132,20 +132,23 @@ def random_bbox_assignment( if sum(lengths) == 1: lengths = [floor(frac * len(hits)) for frac in lengths] - remainder = len(hits) - sum(lengths) + remainder = int(len(hits) - sum(lengths)) # add 1 to all the lengths in round-robin fashion until the remainder is 0 for i in range(remainder): idx_to_add_at = i % len(lengths) lengths[idx_to_add_at] += 1 - hits = [hits[i] for i in randperm(sum(lengths), generator=generator)] + hits = [ + hits[i] + for i in randperm(sum(lengths), generator=generator) # type: ignore[arg-type] + ] new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths ] for i, length in enumerate(lengths): - for j in range(length): + for j in range(length): # type: ignore[arg-type] new_indexes[i].insert(j, hits.pop().bounds) return [new_geodataset_like(dataset, index) for index in new_indexes] From 3ca7c5f459b799472a3b9188e7101aa62e487505 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 21 Dec 2022 12:36:44 -0300 Subject: [PATCH 14/49] minor reorder --- torchgeo/datasets/splits.py | 113 +++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index c924ddb08a6..9d02dcfec88 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -14,6 +14,13 @@ from ..datasets import GeoDataset, NonGeoDataset from .utils import BoundingBox +__all__ = ( + "random_nongeo_split", + "random_bbox_assignment", + "random_bbox_splitting", + "roi_split", +) + def random_nongeo_split( dataset: Union[TensorDataset, NonGeoDataset], @@ -35,7 +42,7 @@ def random_nongeo_split( return random_split(dataset, lengths, generator) -def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: +def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: """Utility to create a new GeoDataset from an existing one with a different index. Args: @@ -52,6 +59,56 @@ def new_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: return new_dataset +def random_bbox_assignment( + dataset: GeoDataset, + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> List[GeoDataset]: + """Split a GeoDataset randomly assigning its index's BoundingBoxes. + + Args: + dataset: dataset to be split + lengths: lengths or fractions of splits to be produced + generator: (optional) generator used for the random permutation + + Returns + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) + if sum(lengths) != 1 or sum(lengths) != len(hits): + raise ValueError( + "Sum of input lengths must equal 1 or be the length of the dataset's index." + ) + + if any(n <= 0 for n in lengths): + raise ValueError("All items in input lengths must be greater than 0.") + + if sum(lengths) == 1: + lengths = [floor(frac * len(hits)) for frac in lengths] + remainder = int(len(hits) - sum(lengths)) + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(lengths) + lengths[idx_to_add_at] += 1 + + hits = [ + hits[i] + for i in randperm(sum(lengths), generator=generator) # type: ignore[arg-type] + ] + + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths + ] + + for i, length in enumerate(lengths): + for j in range(length): # type: ignore[arg-type] + new_indexes[i].insert(j, hits.pop().bounds) + + return [_create_geodataset_like(dataset, index) for index in new_indexes] + + def random_bbox_splitting( dataset: GeoDataset, fractions: Sequence[float], @@ -101,57 +158,7 @@ def random_bbox_splitting( new_indexes[j].insert(i, tuple(new_box)) fraction_left -= frac - return [new_geodataset_like(dataset, index) for index in new_indexes] - - -def random_bbox_assignment( - dataset: GeoDataset, - lengths: Sequence[Union[int, float]], - generator: Optional[Generator] = default_generator, -) -> List[GeoDataset]: - """Split a GeoDataset randomly assigning its index's BoundingBoxes. - - Args: - dataset: dataset to be split - lengths: lengths or fractions of splits to be produced - generator: (optional) generator used for the random permutation - - Returns - A list of the subset datasets. - - .. versionadded:: 0.4 - """ - hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) - if sum(lengths) != 1 or sum(lengths) != len(hits): - raise ValueError( - "Sum of input lengths must equal 1 or be the length of the dataset's index." - ) - - if any(n <= 0 for n in lengths): - raise ValueError("All items in input lengths must be greater than 0.") - - if sum(lengths) == 1: - lengths = [floor(frac * len(hits)) for frac in lengths] - remainder = int(len(hits) - sum(lengths)) - # add 1 to all the lengths in round-robin fashion until the remainder is 0 - for i in range(remainder): - idx_to_add_at = i % len(lengths) - lengths[idx_to_add_at] += 1 - - hits = [ - hits[i] - for i in randperm(sum(lengths), generator=generator) # type: ignore[arg-type] - ] - - new_indexes = [ - Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths - ] - - for i, length in enumerate(lengths): - for j in range(length): # type: ignore[arg-type] - new_indexes[i].insert(j, hits.pop().bounds) - - return [new_geodataset_like(dataset, index) for index in new_indexes] + return [_create_geodataset_like(dataset, index) for index in new_indexes] def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: @@ -175,4 +182,4 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas box = BoundingBox(*hit.bounds) new_indexes[i].insert(j, tuple(box & roi)) - return [new_geodataset_like(dataset, index) for index in new_indexes] + return [_create_geodataset_like(dataset, index) for index in new_indexes] From c3ff1122b7b9a120bad58ce9e674f1cb47286096 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 21 Dec 2022 13:28:47 -0300 Subject: [PATCH 15/49] add tests --- tests/datasets/test_splits.py | 78 +++++++++++++++++++++++++++++++++++ torchgeo/datasets/splits.py | 2 +- 2 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 tests/datasets/test_splits.py diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py new file mode 100644 index 00000000000..43a015e224b --- /dev/null +++ b/tests/datasets/test_splits.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from math import floor + +import pytest +import torch +from torch.utils.data import TensorDataset + +from torchgeo.datasets.splits import ( # random_bbox_splitting,; roi_split, + random_bbox_assignment, + random_nongeo_split, +) +from torchgeo.datasets.utils import BoundingBox + +from .test_geo import CustomGeoDataset + + +def test_random_nongeo_split() -> None: + num_samples = 24 + x = torch.ones(num_samples, 5) + y = torch.randint(low=0, high=2, size=(num_samples,)) + ds = TensorDataset(x, y) + + # Test only train/val set split + train_ds, val_ds = random_nongeo_split(ds, lengths=[1 / 2, 1 / 2]) + assert len(train_ds) == round(num_samples / 2) + assert len(val_ds) == round(num_samples / 2) + + # Test train/val/test set split + train_ds, val_ds, test_ds = random_nongeo_split(ds, lengths=[1 / 3, 1 / 3, 1 / 3]) + assert len(train_ds) == round(num_samples / 3) + assert len(val_ds) == round(num_samples / 3) + assert len(test_ds) == round(num_samples / 3) + + +def test_random_bbox_assignment() -> None: + ds = ( + CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + ) + + num_bbox = len(ds.index.count(ds.index.bounds)) + + # Test list of lengths + train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths=[2, 1, 1]) + assert len(train_ds.index.count(train_ds.index.bounds)) == 2 + assert len(val_ds.index.count(val_ds.index.bounds)) == 1 + assert len(test_ds.index.count(test_ds.index.bounds)) == 1 + + # Test list of fractions + train_ds, val_ds, test_ds = random_bbox_assignment( + ds, lengths=[1 / 2, 1 / 4, 1 / 4] + ) + assert len(train_ds.index.count(train_ds.index.bounds)) == num_bbox / 2 + assert len(val_ds.index.count(val_ds.index.bounds)) == num_bbox / 4 + assert len(test_ds.index.count(test_ds.index.bounds)) == num_bbox / 4 + + # Test list of fractions with remainder + train_ds, val_ds, test_ds = random_bbox_assignment( + ds, lengths=[1 / 3, 1 / 3, 1 / 3] + ) + assert len(train_ds.index.count(train_ds.index.bounds)) == floor(num_bbox / 3) + 1 + assert len(val_ds.index.count(val_ds.index.bounds)) == floor(num_bbox / 3) + assert len(test_ds.index.count(test_ds.index.bounds)) == floor(num_bbox / 3) + + # Test invalid input lenghts + with pytest.raises( + ValueError, + match="Sum of input lengths must equal 1 or the length of dataset's index.", + ): + random_bbox_assignment(ds, lengths=[2, 2, 1]) + with pytest.raises( + ValueError, match="All items in input lengths must be greater than 0." + ): + random_bbox_assignment(ds, lengths=[1 / 2, 3 / 4, -1 / 4]) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 9d02dcfec88..dfe538a893a 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -79,7 +79,7 @@ def random_bbox_assignment( hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) if sum(lengths) != 1 or sum(lengths) != len(hits): raise ValueError( - "Sum of input lengths must equal 1 or be the length of the dataset's index." + "Sum of input lengths must equal 1 or the length of dataset's index." ) if any(n <= 0 for n in lengths): From 49fc2d1a7d44b2f9c744311beae38f75274efc76 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 21 Dec 2022 16:00:30 -0300 Subject: [PATCH 16/49] add non-overlapping test --- tests/datasets/test_splits.py | 30 ++++++++++++------------------ torchgeo/datasets/splits.py | 9 +++++---- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 43a015e224b..928c94bc319 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -42,31 +42,25 @@ def test_random_bbox_assignment() -> None: | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) ) - num_bbox = len(ds.index.count(ds.index.bounds)) - # Test list of lengths train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths=[2, 1, 1]) - assert len(train_ds.index.count(train_ds.index.bounds)) == 2 - assert len(val_ds.index.count(val_ds.index.bounds)) == 1 - assert len(test_ds.index.count(test_ds.index.bounds)) == 1 - - # Test list of fractions - train_ds, val_ds, test_ds = random_bbox_assignment( - ds, lengths=[1 / 2, 1 / 4, 1 / 4] - ) - assert len(train_ds.index.count(train_ds.index.bounds)) == num_bbox / 2 - assert len(val_ds.index.count(val_ds.index.bounds)) == num_bbox / 4 - assert len(test_ds.index.count(test_ds.index.bounds)) == num_bbox / 4 + assert len(train_ds) == 2 + assert len(val_ds) == 1 + assert len(test_ds) == 1 + assert len(train_ds & val_ds & test_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test list of fractions with remainder + # Test list of fractions (with remainder) train_ds, val_ds, test_ds = random_bbox_assignment( ds, lengths=[1 / 3, 1 / 3, 1 / 3] ) - assert len(train_ds.index.count(train_ds.index.bounds)) == floor(num_bbox / 3) + 1 - assert len(val_ds.index.count(val_ds.index.bounds)) == floor(num_bbox / 3) - assert len(test_ds.index.count(test_ds.index.bounds)) == floor(num_bbox / 3) + assert len(train_ds) == floor(len(ds) / 3) + 1 + assert len(val_ds) == floor(len(ds) / 3) + assert len(test_ds) == floor(len(ds) / 3) + assert len(train_ds & val_ds & test_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test invalid input lenghts + # Test invalid input lengths with pytest.raises( ValueError, match="Sum of input lengths must equal 1 or the length of dataset's index.", diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index dfe538a893a..939a99451f8 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -76,8 +76,7 @@ def random_bbox_assignment( .. versionadded:: 0.4 """ - hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) - if sum(lengths) != 1 or sum(lengths) != len(hits): + if sum(lengths) != 1 or sum(lengths) != len(dataset): raise ValueError( "Sum of input lengths must equal 1 or the length of dataset's index." ) @@ -86,13 +85,15 @@ def random_bbox_assignment( raise ValueError("All items in input lengths must be greater than 0.") if sum(lengths) == 1: - lengths = [floor(frac * len(hits)) for frac in lengths] - remainder = int(len(hits) - sum(lengths)) + lengths = [floor(frac * len(dataset)) for frac in lengths] + remainder = int(len(dataset) - sum(lengths)) # add 1 to all the lengths in round-robin fashion until the remainder is 0 for i in range(remainder): idx_to_add_at = i % len(lengths) lengths[idx_to_add_at] += 1 + hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) + hits = [ hits[i] for i in randperm(sum(lengths), generator=generator) # type: ignore[arg-type] From d8ad1b4f72027cfaf8aeb8fa3c414bb32bd5c345 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 21 Dec 2022 16:39:47 -0300 Subject: [PATCH 17/49] more tests --- tests/datasets/test_splits.py | 76 ++++++++++++++++++++++++++++++++++- torchgeo/datasets/splits.py | 4 ++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 928c94bc319..68fd4b817e9 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -7,9 +7,12 @@ import torch from torch.utils.data import TensorDataset -from torchgeo.datasets.splits import ( # random_bbox_splitting,; roi_split, +from torchgeo.datasets import GeoDataset +from torchgeo.datasets.splits import ( random_bbox_assignment, + random_bbox_splitting, random_nongeo_split, + roi_split, ) from torchgeo.datasets.utils import BoundingBox @@ -70,3 +73,74 @@ def test_random_bbox_assignment() -> None: ValueError, match="All items in input lengths must be greater than 0." ): random_bbox_assignment(ds, lengths=[1 / 2, 3 / 4, -1 / 4]) + + +def get_total_area(dataset: GeoDataset) -> float: + + total_area = 0.0 + for hit in dataset.index.intersection(dataset.index.bounds, objects=True): + total_area += BoundingBox(*hit.bounds).area + + return total_area + + +def test_random_bbox_splitting() -> None: + ds = ( + CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + ) + + ds_area = get_total_area(ds) + + # Test list of fractions + train_ds, val_ds, test_ds = random_bbox_splitting( + ds, fractions=[1 / 2, 1 / 4, 1 / 4] + ) + train_ds_area = get_total_area(train_ds) + val_ds_area = get_total_area(val_ds) + test_ds_area = get_total_area(test_ds) + + assert train_ds_area == ds_area / 2 + assert val_ds_area == ds_area / 4 + assert test_ds_area == ds_area / 4 + assert len(train_ds & val_ds & test_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test invalid input fractions + with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): + random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4]) + with pytest.raises( + ValueError, match="All items in input lengths must be greater than 0." + ): + random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) + + +def test_roi_split() -> None: + ds = ( + CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + ) + + train_ds, val_ds, test_ds = roi_split( + ds, + rois=[ + BoundingBox(0, 2, 0, 1, 0, 0), + BoundingBox(2, 3.5, 0, 1, 0, 0), + BoundingBox(3.5, 4, 0, 1, 0, 0), + ], + ) + assert len(train_ds) == 2 + assert len(val_ds) == 2 + assert len(test_ds) == 1 + assert len(train_ds & val_ds & test_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test invalid input rois + with pytest.raises(ValueError, match="ROIs in input roi should not overlap."): + roi_split( + ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)] + ) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 939a99451f8..9ee7cc83319 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,6 +4,7 @@ """Dataset splitting utilities.""" from copy import deepcopy +from functools import reduce from math import floor from typing import Any, List, Optional, Sequence, Union @@ -174,6 +175,9 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas .. versionadded:: 0.4 """ + if reduce(lambda x, y: x & y, rois).area != 0: + raise ValueError("ROIs in input roi should not overlap.") + new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in rois ] From 8c80c9bbf75eca464b02dd25fd70d87a67527d12 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 21 Dec 2022 19:09:06 -0300 Subject: [PATCH 18/49] fix tests --- tests/datasets/test_splits.py | 36 ++++++++++++++++++----------------- torchgeo/datasets/geo.py | 6 ++++-- torchgeo/datasets/splits.py | 13 ++++++------- torchgeo/datasets/utils.py | 4 ++-- 4 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 68fd4b817e9..28acfeca87d 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -50,7 +50,9 @@ def test_random_bbox_assignment() -> None: assert len(train_ds) == 2 assert len(val_ds) == 1 assert len(test_ds) == 1 - assert len(train_ds & val_ds & test_ds) == 0 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 assert (train_ds | val_ds | test_ds).bounds == ds.bounds # Test list of fractions (with remainder) @@ -60,7 +62,9 @@ def test_random_bbox_assignment() -> None: assert len(train_ds) == floor(len(ds) / 3) + 1 assert len(val_ds) == floor(len(ds) / 3) assert len(test_ds) == floor(len(ds) / 3) - assert len(train_ds & val_ds & test_ds) == 0 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 assert (train_ds | val_ds | test_ds).bounds == ds.bounds # Test invalid input lengths @@ -105,14 +109,16 @@ def test_random_bbox_splitting() -> None: assert train_ds_area == ds_area / 2 assert val_ds_area == ds_area / 4 assert test_ds_area == ds_area / 4 - assert len(train_ds & val_ds & test_ds) == 0 - assert (train_ds | val_ds | test_ds).bounds == ds.bounds + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 + assert get_total_area(train_ds | val_ds | test_ds) == ds_area # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4]) with pytest.raises( - ValueError, match="All items in input lengths must be greater than 0." + ValueError, match="All items in input fractions must be greater than 0." ): random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) @@ -120,27 +126,23 @@ def test_random_bbox_splitting() -> None: def test_roi_split() -> None: ds = ( CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(4, 5, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(6, 7, 0, 1, 0, 0)) ) train_ds, val_ds, test_ds = roi_split( ds, rois=[ - BoundingBox(0, 2, 0, 1, 0, 0), - BoundingBox(2, 3.5, 0, 1, 0, 0), - BoundingBox(3.5, 4, 0, 1, 0, 0), + BoundingBox(0, 3, 0, 1, 0, 0), + BoundingBox(4, 6.5, 0, 1, 0, 0), + BoundingBox(6.5, 7, 0, 1, 0, 0), ], ) assert len(train_ds) == 2 assert len(val_ds) == 2 assert len(test_ds) == 1 - assert len(train_ds & val_ds & test_ds) == 0 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 assert (train_ds | val_ds | test_ds).bounds == ds.bounds - - # Test invalid input rois - with pytest.raises(ValueError, match="ROIs in input roi should not overlap."): - roi_split( - ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)] - ) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 60bf2f2b8f0..78653e83895 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -884,8 +884,10 @@ def _merge_dataset_indices(self) -> None: for hit2 in ds2.index.intersection(hit1.bounds, objects=True): box1 = BoundingBox(*hit1.bounds) box2 = BoundingBox(*hit2.bounds) - self.index.insert(i, tuple(box1 & box2)) - i += 1 + new_box = box1 & box2 + if new_box.area > 0: + self.index.insert(i, tuple(box1 & box2)) + i += 1 def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: """Retrieve image and metadata indexed by query. diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 9ee7cc83319..d29ab1cbc77 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,7 +4,6 @@ """Dataset splitting utilities.""" from copy import deepcopy -from functools import reduce from math import floor from typing import Any, List, Optional, Sequence, Union @@ -77,7 +76,7 @@ def random_bbox_assignment( .. versionadded:: 0.4 """ - if sum(lengths) != 1 or sum(lengths) != len(dataset): + if not (sum(lengths) == 1 or sum(lengths) == len(dataset)): raise ValueError( "Sum of input lengths must equal 1 or the length of dataset's index." ) @@ -147,18 +146,21 @@ def random_bbox_splitting( box = BoundingBox(*hit.bounds) fraction_left = 1.0 + horizontal, flip = randint(0, 2, (2,), generator=generator) for j, frac in enumerate(fractions): - horizontal, flip = randint(0, 2, (2,), generator=generator) if fraction_left == frac: new_box = box elif flip: - box, new_box = box.split((1 - frac) / fraction_left, horizontal) + box, new_box = box.split( + (fraction_left - frac) / fraction_left, horizontal + ) else: new_box, box = box.split(frac / fraction_left, horizontal) new_indexes[j].insert(i, tuple(new_box)) fraction_left -= frac + horizontal = not horizontal return [_create_geodataset_like(dataset, index) for index in new_indexes] @@ -175,9 +177,6 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas .. versionadded:: 0.4 """ - if reduce(lambda x, y: x & y, rois).area != 0: - raise ValueError("ROIs in input roi should not overlap.") - new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in rois ] diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 102ff7bbd24..0d4cb144bce 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -414,7 +414,7 @@ def split( if horizontal: w = self.maxx - self.minx - splitx = self.minx + int(w * proportion) + splitx = self.minx + w * proportion bbox1 = BoundingBox( self.minx, splitx, self.miny, self.maxy, self.mint, self.maxt ) @@ -423,7 +423,7 @@ def split( ) else: h = self.maxy - self.miny - splity = self.miny + int(h * proportion) + splity = self.miny + h * proportion bbox1 = BoundingBox( self.minx, self.maxx, self.miny, splity, self.mint, self.maxt ) From d676ed1b3da6520d1e4bc585666ce6057a4e470d Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 22 Dec 2022 10:00:32 -0300 Subject: [PATCH 19/49] additional tests --- tests/datasets/test_geo.py | 6 ++++++ tests/datasets/test_utils.py | 29 +++++++++++++++++++++++++++++ torchgeo/datasets/geo.py | 2 +- torchgeo/datasets/utils.py | 5 +++-- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b691961806e..5f02c1dfeb2 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -447,6 +447,12 @@ def test_no_overlap(self) -> None: ds = IntersectionDataset(ds1, ds2) assert len(ds) == 0 + def test_contiguous(self) -> None: + ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1)) + ds2 = CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 1)) + ds = IntersectionDataset(ds1, ds2) + assert len(ds) == 0 + def test_invalid_query(self, dataset: IntersectionDataset) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index abd37ce6930..e89794a9726 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -394,6 +394,35 @@ def test_intersects( bbox2 = BoundingBox(*test_input) assert bbox1.intersects(bbox2) == bbox2.intersects(bbox1) == expected + @pytest.mark.parametrize( + "proportion,horizontal,expected", + [ + (0.25, True, ((0, 0.25, 0, 1, 0, 1), (0.25, 1, 0, 1, 0, 1))), + (0.25, False, ((0, 1, 0, 0.25, 0, 1), (0, 1, 0.25, 1, 0, 1))), + ], + ) + def test_split( + self, + proportion: float, + horizontal: bool, + expected: Tuple[ + Tuple[float, float, float, float, float, float], + Tuple[float, float, float, float, float, float], + ], + ) -> None: + bbox = BoundingBox(0, 1, 0, 1, 0, 1) + bbox1, bbox2 = bbox.split(proportion, horizontal) + assert bbox1 == BoundingBox(*expected[0]) + assert bbox2 == BoundingBox(*expected[1]) + assert bbox1 | bbox2 == bbox + + def test_split_error(self) -> None: + bbox = BoundingBox(0, 1, 0, 1, 0, 1) + with pytest.raises( + ValueError, match="Input proportion must be between 0 and 1." + ): + bbox.split(1.5) + def test_picklable(self) -> None: bbox = BoundingBox(0, 1, 2, 3, 4, 5) x = pickle.dumps(bbox) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 78653e83895..5c77981d047 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -886,7 +886,7 @@ def _merge_dataset_indices(self) -> None: box2 = BoundingBox(*hit2.bounds) new_box = box1 & box2 if new_box.area > 0: - self.index.insert(i, tuple(box1 & box2)) + self.index.insert(i, tuple(new_box)) i += 1 def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 0d4cb144bce..7e3b0d9619d 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -402,7 +402,7 @@ def split( """Split BoundingBox in two. Args: - proportion: split proportion in range [0,1] + proportion: split proportion in range (0,1) horizontal: whether the split is horizontal or vertical Returns: @@ -410,7 +410,8 @@ def split( .. versionadded:: 0.4 """ - assert 0 < proportion < 1, "test_size must be between 0 and 1" + if not (0.0 < proportion < 1.0): + raise ValueError("Input proportion must be between 0 and 1.") if horizontal: w = self.maxx - self.minx From 503478cc8b7dfb46a2d5ce95983540c719ec853a Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 22 Dec 2022 10:47:27 -0300 Subject: [PATCH 20/49] check overlapping rois --- tests/datasets/test_splits.py | 16 +++++++++++----- torchgeo/datasets/splits.py | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 28acfeca87d..a9fe5da2926 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -126,17 +126,17 @@ def test_random_bbox_splitting() -> None: def test_roi_split() -> None: ds = ( CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(4, 5, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(6, 7, 0, 1, 0, 0)) + | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) ) train_ds, val_ds, test_ds = roi_split( ds, rois=[ - BoundingBox(0, 3, 0, 1, 0, 0), - BoundingBox(4, 6.5, 0, 1, 0, 0), - BoundingBox(6.5, 7, 0, 1, 0, 0), + BoundingBox(0, 2, 0, 1, 0, 0), + BoundingBox(2, 3.5, 0, 1, 0, 0), + BoundingBox(3.5, 4, 0, 1, 0, 0), ], ) assert len(train_ds) == 2 @@ -146,3 +146,9 @@ def test_roi_split() -> None: assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test invalid input rois + with pytest.raises(ValueError, match="ROIs in input rois can't overlap."): + roi_split( + ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)] + ) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index d29ab1cbc77..03ed919cfce 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -177,13 +177,23 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas .. versionadded:: 0.4 """ + _rois = list(rois).copy() + while len(_rois) > 1: + r = _rois.pop() + if any(r.intersects(x) and (r & x).area > 0 for x in _rois): + raise ValueError("ROIs in input rois can't overlap.") + new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in rois ] for i, roi in enumerate(rois): - for j, hit in enumerate(dataset.index.intersection(tuple(roi), objects=True)): + j = 0 + for hit in dataset.index.intersection(tuple(roi), objects=True): box = BoundingBox(*hit.bounds) - new_indexes[i].insert(j, tuple(box & roi)) + new_box = box & roi + if new_box.area > 0: + new_indexes[i].insert(j, tuple(new_box)) + j += 1 return [_create_geodataset_like(dataset, index) for index in new_indexes] From 4a96e4091c3a4078e732b35ee473c4469e627d46 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 29 Dec 2022 13:37:36 -0300 Subject: [PATCH 21/49] add time_series_split with tests --- tests/datasets/test_splits.py | 80 +++++++++++++++++++++++++++++ torchgeo/datasets/splits.py | 95 ++++++++++++++++++++++++++++++++--- 2 files changed, 168 insertions(+), 7 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index a9fe5da2926..af87710d665 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -13,6 +13,7 @@ random_bbox_splitting, random_nongeo_split, roi_split, + time_series_split, ) from torchgeo.datasets.utils import BoundingBox @@ -152,3 +153,82 @@ def test_roi_split() -> None: roi_split( ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)] ) + + +def test_time_series_split() -> None: + ds = ( + CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 10)) + | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 10, 20)) + | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 20, 30)) + | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 30, 40)) + ) + + # Test lengths input using timestamps + train_ds, val_ds, test_ds = time_series_split( + ds, lengths=[(0, 20), (20, 35), (35, 40)] + ) + + assert len(train_ds) == 2 + assert len(val_ds) == 2 + assert len(test_ds) == 1 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test lengths input using lengths + train_ds, val_ds, test_ds = time_series_split(ds, lengths=[20, 15, 5]) + + assert len(train_ds) == 2 + assert len(val_ds) == 2 + assert len(test_ds) == 1 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test lengths input using fractions + train_ds, val_ds, test_ds = time_series_split(ds, lengths=[1 / 2, 3 / 8, 1 / 8]) + + assert len(train_ds) == 2 + assert len(val_ds) == 2 + assert len(test_ds) == 1 + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 + assert (train_ds | val_ds | test_ds).bounds == ds.bounds + + # Test invalid input lengths + with pytest.raises( + ValueError, + match="Pairs of timestamps in lengths must have end greater than start.", + ): + time_series_split(ds, lengths=[(0, 20), (35, 20), (35, 40)]) + + with pytest.raises( + ValueError, + match="Pairs of timestamps in lengths must cover dataset's time bounds.", + ): + time_series_split(ds, lengths=[(0, 20), (20, 35)]) + + with pytest.raises( + ValueError, + match="Pairs of timestamps in lengths can't be out of dataset's time bounds.", + ): + time_series_split(ds, lengths=[(0, 20), (20, 45)]) + + with pytest.raises( + ValueError, match="Pairs of timestamps in lengths can't overlap." + ): + time_series_split(ds, lengths=[(0, 10), (10, 20), (15, 40)]) + + with pytest.raises( + ValueError, + match="Sum of input lengths must equal 1 or the dataset's time length.", + ): + time_series_split(ds, lengths=[1 / 2, 1 / 2, 1 / 2]) + + with pytest.raises( + ValueError, match="All items in input lengths must be greater than 0." + ): + time_series_split(ds, lengths=[20, 25, -5]) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 03ed919cfce..7472a5ba268 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -5,10 +5,11 @@ from copy import deepcopy from math import floor -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm +from torch._utils import _accumulate from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import GeoDataset, NonGeoDataset @@ -19,6 +20,7 @@ "random_bbox_assignment", "random_bbox_splitting", "roi_split", + "time_series_split", ) @@ -177,17 +179,14 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas .. versionadded:: 0.4 """ - _rois = list(rois).copy() - while len(_rois) > 1: - r = _rois.pop() - if any(r.intersects(x) and (r & x).area > 0 for x in _rois): - raise ValueError("ROIs in input rois can't overlap.") - new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in rois ] for i, roi in enumerate(rois): + if any(roi.intersects(x) and (roi & x).area > 0 for x in rois[i + 1 :]): + raise ValueError("ROIs in input rois can't overlap.") + j = 0 for hit in dataset.index.intersection(tuple(roi), objects=True): box = BoundingBox(*hit.bounds) @@ -197,3 +196,85 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas j += 1 return [_create_geodataset_like(dataset, index) for index in new_indexes] + + +def time_series_split( + dataset: GeoDataset, lengths: Sequence[Union[int, float, Tuple[int, int]]] +) -> List[GeoDataset]: + """Split a GeoDataset on it's time dimension to create non-overlapping GeoDatasets. + + Args: + dataset: dataset to be split + lengths: lengths, fractions or pairs of timestamps (start, end) of splits + to be produced + + Returns + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + minx, maxx, miny, maxy, mint, maxt = dataset.bounds + + totalt = maxt - mint + + if not all(isinstance(x, tuple) for x in lengths): + + if not (sum(lengths) == 1 or sum(lengths) == totalt): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths must equal 1 or the dataset's time length." + ) + + if any(n <= 0 for n in lengths): # type: ignore[operator] + raise ValueError("All items in input lengths must be greater than 0.") + + if sum(lengths) == 1: # type: ignore[arg-type] + lengths = [totalt * f for f in lengths] # type: ignore[operator] + + lengths = [ + (mint + offset - length, mint + offset) + for offset, length in zip( + _accumulate(lengths), lengths # type: ignore[no-untyped-call] + ) + ] + + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths + ] + + _totalt = 0 + for i, (start, end) in enumerate(lengths): # type: ignore[misc] + + if start >= end: + raise ValueError( + "Pairs of timestamps in lengths must have end greater than start." + ) + + if start < mint or end > maxt: + raise ValueError( + "Pairs of timestamps in lengths can't be out of dataset's time bounds." + ) + + if any( # type: ignore[misc] + start < x < end or start < y < end for x, y in lengths[i + 1 :] + ): + raise ValueError("Pairs of timestamps in lengths can't overlap.") + + # remove one second from each BoundingBox's maxt to avoid overlapping + offset = 0 if i == len(lengths) - 1 else 1 + roi = BoundingBox(minx, maxx, miny, maxy, start, end - offset) + j = 0 + for hit in dataset.index.intersection(tuple(roi), objects=True): + box = BoundingBox(*hit.bounds) + new_box = box & roi + if new_box.volume > 0: + new_indexes[i].insert(j, tuple(new_box)) + j += 1 + + _totalt += end - start + + if not _totalt == totalt: + raise ValueError( + "Pairs of timestamps in lengths must cover dataset's time bounds." + ) + + return [_create_geodataset_like(dataset, index) for index in new_indexes] From 8842d783ddfdfc38b04f7c664546f44773e8bf5d Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 29 Dec 2022 13:43:33 -0300 Subject: [PATCH 22/49] fix random_nongeo_split to work with fractions in torch 1.9 --- torchgeo/datasets/splits.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 7472a5ba268..63335476f74 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -41,6 +41,13 @@ def random_nongeo_split( .. versionadded:: 0.4 """ + if sum(lengths) == 1: + lengths = [floor(frac * len(dataset)) for frac in lengths] + remainder = int(len(dataset) - sum(lengths)) + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(lengths) + lengths[idx_to_add_at] += 1 return random_split(dataset, lengths, generator) From c9b8b19abc25867a132a111de70a78c5aa2daa4c Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Thu, 29 Dec 2022 14:00:31 -0300 Subject: [PATCH 23/49] modify random_nongeo_split test for coverage --- tests/datasets/test_splits.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index af87710d665..9ea37402cc9 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -21,7 +21,7 @@ def test_random_nongeo_split() -> None: - num_samples = 24 + num_samples = 26 x = torch.ones(num_samples, 5) y = torch.randint(low=0, high=2, size=(num_samples,)) ds = TensorDataset(x, y) @@ -31,11 +31,11 @@ def test_random_nongeo_split() -> None: assert len(train_ds) == round(num_samples / 2) assert len(val_ds) == round(num_samples / 2) - # Test train/val/test set split + # Test train/val/test set split with remainder train_ds, val_ds, test_ds = random_nongeo_split(ds, lengths=[1 / 3, 1 / 3, 1 / 3]) - assert len(train_ds) == round(num_samples / 3) - assert len(val_ds) == round(num_samples / 3) - assert len(test_ds) == round(num_samples / 3) + assert len(train_ds) == floor(num_samples / 3) + 1 + assert len(val_ds) == floor(num_samples / 3) + 1 + assert len(test_ds) == floor(num_samples / 3) def test_random_bbox_assignment() -> None: From 54a1781a299c3b166f6dbb976977dd2e1f16ddba Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 3 Jan 2023 11:10:31 -0300 Subject: [PATCH 24/49] add random_grid_cell_assignment with tests --- tests/datasets/test_splits.py | 41 +++++++++-- torchgeo/datasets/splits.py | 128 +++++++++++++++++++++++++++------- 2 files changed, 138 insertions(+), 31 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 9ea37402cc9..78ef8860e94 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from math import floor +from math import floor, isclose import pytest import torch @@ -11,6 +11,7 @@ from torchgeo.datasets.splits import ( random_bbox_assignment, random_bbox_splitting, + random_grid_cell_assignment, random_nongeo_split, roi_split, time_series_split, @@ -80,7 +81,7 @@ def test_random_bbox_assignment() -> None: random_bbox_assignment(ds, lengths=[1 / 2, 3 / 4, -1 / 4]) -def get_total_area(dataset: GeoDataset) -> float: +def _get_total_area(dataset: GeoDataset) -> float: total_area = 0.0 for hit in dataset.index.intersection(dataset.index.bounds, objects=True): @@ -97,15 +98,15 @@ def test_random_bbox_splitting() -> None: | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) ) - ds_area = get_total_area(ds) + ds_area = _get_total_area(ds) # Test list of fractions train_ds, val_ds, test_ds = random_bbox_splitting( ds, fractions=[1 / 2, 1 / 4, 1 / 4] ) - train_ds_area = get_total_area(train_ds) - val_ds_area = get_total_area(val_ds) - test_ds_area = get_total_area(test_ds) + train_ds_area = _get_total_area(train_ds) + val_ds_area = _get_total_area(val_ds) + test_ds_area = _get_total_area(test_ds) assert train_ds_area == ds_area / 2 assert val_ds_area == ds_area / 4 @@ -113,7 +114,7 @@ def test_random_bbox_splitting() -> None: assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 - assert get_total_area(train_ds | val_ds | test_ds) == ds_area + assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area) # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): @@ -124,6 +125,32 @@ def test_random_bbox_splitting() -> None: random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) +def test_random_grid_cell_assignment() -> None: + ds = CustomGeoDataset(BoundingBox(0, 12, 0, 12, 0, 0)) | CustomGeoDataset( + BoundingBox(12, 24, 0, 12, 0, 0) + ) + + train_ds, val_ds, test_ds = random_grid_cell_assignment( + ds, fractions=[1 / 2, 1 / 4, 1 / 4], size=5 + ) + + assert len(train_ds) == 1 / 2 * 2 * 5**2 + 1 + assert len(val_ds) == floor(1 / 4 * 2 * 5**2) + assert len(test_ds) == floor(1 / 4 * 2 * 5**2) + assert len(train_ds & val_ds) == 0 + assert len(val_ds & test_ds) == 0 + assert len(test_ds & train_ds) == 0 + assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + + # Test invalid input fractions + with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): + random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 3, 1 / 4]) + with pytest.raises( + ValueError, match="All items in input fractions must be greater than 0." + ): + random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) + + def test_roi_split() -> None: ds = ( CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 63335476f74..46f7f1af563 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -19,36 +19,31 @@ "random_nongeo_split", "random_bbox_assignment", "random_bbox_splitting", + "random_grid_cell_assignment", "roi_split", "time_series_split", ) -def random_nongeo_split( - dataset: Union[TensorDataset, NonGeoDataset], - lengths: Sequence[Union[int, float]], - generator: Optional[Generator] = default_generator, -) -> List[Subset[Any]]: - """Randomly split a NonGeoDataset into non-overlapping new NonGeoDatasets. +def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[int]: + """Utility to divide a number into a list of integers according to fractions. Args: - dataset: dataset to be split - lengths: lengths or fractions of splits to be produced - generator: (optional) generator used for the random permutation + fractions: list of fractions + total: total to be divided Returns: - A list of the subset datasets. + List of lengths. .. versionadded:: 0.4 """ - if sum(lengths) == 1: - lengths = [floor(frac * len(dataset)) for frac in lengths] - remainder = int(len(dataset) - sum(lengths)) - # add 1 to all the lengths in round-robin fashion until the remainder is 0 - for i in range(remainder): - idx_to_add_at = i % len(lengths) - lengths[idx_to_add_at] += 1 - return random_split(dataset, lengths, generator) + lengths = [floor(frac * total) for frac in fractions] + remainder = int(total - sum(lengths)) + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(lengths) + lengths[idx_to_add_at] += 1 + return lengths def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: @@ -68,6 +63,28 @@ def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: return new_dataset +def random_nongeo_split( + dataset: Union[TensorDataset, NonGeoDataset], + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> List[Subset[Any]]: + """Randomly split a NonGeoDataset into non-overlapping new NonGeoDatasets. + + Args: + dataset: dataset to be split + lengths: lengths or fractions of splits to be produced + generator: (optional) generator used for the random permutation + + Returns: + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + if sum(lengths) == 1: + lengths = _fractions_to_lengths(lengths, len(dataset)) + return random_split(dataset, lengths, generator) + + def random_bbox_assignment( dataset: GeoDataset, lengths: Sequence[Union[int, float]], @@ -94,12 +111,7 @@ def random_bbox_assignment( raise ValueError("All items in input lengths must be greater than 0.") if sum(lengths) == 1: - lengths = [floor(frac * len(dataset)) for frac in lengths] - remainder = int(len(dataset) - sum(lengths)) - # add 1 to all the lengths in round-robin fashion until the remainder is 0 - for i in range(remainder): - idx_to_add_at = i % len(lengths) - lengths[idx_to_add_at] += 1 + lengths = _fractions_to_lengths(lengths, len(dataset)) hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) @@ -174,6 +186,74 @@ def random_bbox_splitting( return [_create_geodataset_like(dataset, index) for index in new_indexes] +def random_grid_cell_assignment( + dataset: GeoDataset, + fractions: Sequence[float], + size: int = 6, + generator: Optional[Generator] = default_generator, +) -> List[GeoDataset]: + """Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets. + + This function will go through each BoundingBox in the GeoDataset's index, overlay + a grid over it, and randomly assign each cell to new GeoDatasets. + + Args: + dataset: dataset to be split + fractions: fractions of splits to be produced + size: (optional) size of the grid + generator: (optional) generator used for the random permutation + + Returns + A list of the subset datasets. + + .. versionadded:: 0.4 + """ + if sum(fractions) != 1: + raise ValueError("Sum of input fractions must equal 1.") + + if any(n <= 0 for n in fractions): + raise ValueError("All items in input fractions must be greater than 0.") + + new_indexes = [ + Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions + ] + + lengths = _fractions_to_lengths(fractions, len(dataset) * size**2) + + cells = [] + + for i, hit in enumerate( + dataset.index.intersection(dataset.index.bounds, objects=True) + ): + minx, maxx, miny, maxy, mint, maxt = hit.bounds + + stridex = (maxx - minx) / size + stridey = (maxy - miny) / size + + cells.extend( + [ + ( + minx + x * stridex, + minx + (x + 1) * stridex, + miny + y * stridey, + miny + (y + 1) * stridey, + mint, + maxt, + ) + for x in range(size) + for y in range(size) + ] + ) + + cells = [cells[i] for i in randperm(len(cells), generator=generator)] + + for i, length in enumerate(lengths): + for j in range(length): + new_indexes[i].insert(j, cells.pop()) + + return [_create_geodataset_like(dataset, index) for index in new_indexes] + + def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: """Split a GeoDataset intersecting it with a ROI for each desired new GeoDataset. From 2997b95567ba4841955026aab776a470e03dc3e5 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 3 Jan 2023 11:14:31 -0300 Subject: [PATCH 25/49] add test --- tests/datasets/test_splits.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 78ef8860e94..478ff63b85d 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -115,6 +115,7 @@ def test_random_bbox_splitting() -> None: assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area) + assert (train_ds | val_ds | test_ds).bounds == ds.bounds # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): @@ -141,6 +142,7 @@ def test_random_grid_cell_assignment() -> None: assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + assert (train_ds | val_ds | test_ds).bounds == ds.bounds # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): From a27d09d784961bb21f851eb9db779d2629821463 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 09:20:56 -0300 Subject: [PATCH 26/49] insert object into new indexes --- torchgeo/datasets/splits.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 46f7f1af563..b1c5b4b8ad1 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -126,7 +126,8 @@ def random_bbox_assignment( for i, length in enumerate(lengths): for j in range(length): # type: ignore[arg-type] - new_indexes[i].insert(j, hits.pop().bounds) + hit = hits.pop() + new_indexes[i].insert(j, hit.bounds, hit.object) return [_create_geodataset_like(dataset, index) for index in new_indexes] @@ -179,7 +180,7 @@ def random_bbox_splitting( else: new_box, box = box.split(frac / fraction_left, horizontal) - new_indexes[j].insert(i, tuple(new_box)) + new_indexes[j].insert(i, tuple(new_box), hit.object) fraction_left -= frac horizontal = not horizontal @@ -233,12 +234,15 @@ def random_grid_cell_assignment( cells.extend( [ ( - minx + x * stridex, - minx + (x + 1) * stridex, - miny + y * stridey, - miny + (y + 1) * stridey, - mint, - maxt, + ( + minx + x * stridex, + minx + (x + 1) * stridex, + miny + y * stridey, + miny + (y + 1) * stridey, + mint, + maxt, + ), + hit.object, ) for x in range(size) for y in range(size) @@ -249,7 +253,8 @@ def random_grid_cell_assignment( for i, length in enumerate(lengths): for j in range(length): - new_indexes[i].insert(j, cells.pop()) + cell = cells.pop() + new_indexes[i].insert(j, cell[0], cell[1]) return [_create_geodataset_like(dataset, index) for index in new_indexes] @@ -279,7 +284,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas box = BoundingBox(*hit.bounds) new_box = box & roi if new_box.area > 0: - new_indexes[i].insert(j, tuple(new_box)) + new_indexes[i].insert(j, tuple(new_box), hit.object) j += 1 return [_create_geodataset_like(dataset, index) for index in new_indexes] @@ -354,7 +359,7 @@ def time_series_split( box = BoundingBox(*hit.bounds) new_box = box & roi if new_box.volume > 0: - new_indexes[i].insert(j, tuple(new_box)) + new_indexes[i].insert(j, tuple(new_box), hit.object) j += 1 _totalt += end - start From 3a805c008935be5e14ba4f82b77699d0a1d3b41f Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 09:26:04 -0300 Subject: [PATCH 27/49] check grid_size --- tests/datasets/test_splits.py | 4 +++- torchgeo/datasets/splits.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 478ff63b85d..eef21d2f639 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -132,7 +132,7 @@ def test_random_grid_cell_assignment() -> None: ) train_ds, val_ds, test_ds = random_grid_cell_assignment( - ds, fractions=[1 / 2, 1 / 4, 1 / 4], size=5 + ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5 ) assert len(train_ds) == 1 / 2 * 2 * 5**2 + 1 @@ -151,6 +151,8 @@ def test_random_grid_cell_assignment() -> None: ValueError, match="All items in input fractions must be greater than 0." ): random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) + with pytest.raises(ValueError, match="Input grid_size must be greater than 1."): + random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=1) def test_roi_split() -> None: diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index b1c5b4b8ad1..ae9dddb29bc 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -190,7 +190,7 @@ def random_bbox_splitting( def random_grid_cell_assignment( dataset: GeoDataset, fractions: Sequence[float], - size: int = 6, + grid_size: int = 6, generator: Optional[Generator] = default_generator, ) -> List[GeoDataset]: """Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets. @@ -201,7 +201,7 @@ def random_grid_cell_assignment( Args: dataset: dataset to be split fractions: fractions of splits to be produced - size: (optional) size of the grid + grid_size: (optional) size of the grid generator: (optional) generator used for the random permutation Returns @@ -215,11 +215,14 @@ def random_grid_cell_assignment( if any(n <= 0 for n in fractions): raise ValueError("All items in input fractions must be greater than 0.") + if grid_size < 2: + raise ValueError("Input grid_size must be greater than 1.") + new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions ] - lengths = _fractions_to_lengths(fractions, len(dataset) * size**2) + lengths = _fractions_to_lengths(fractions, len(dataset) * grid_size**2) cells = [] @@ -228,8 +231,8 @@ def random_grid_cell_assignment( ): minx, maxx, miny, maxy, mint, maxt = hit.bounds - stridex = (maxx - minx) / size - stridey = (maxy - miny) / size + stridex = (maxx - minx) / grid_size + stridey = (maxy - miny) / grid_size cells.extend( [ @@ -244,8 +247,8 @@ def random_grid_cell_assignment( ), hit.object, ) - for x in range(size) - for y in range(size) + for x in range(grid_size) + for y in range(grid_size) ] ) From 56a3692b76a2e0dc363cd67291a517f6642b2515 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 15:01:21 -0300 Subject: [PATCH 28/49] better tests --- .DS_Store | Bin 0 -> 8196 bytes tests/.DS_Store | Bin 0 -> 8196 bytes tests/data/.DS_Store | Bin 0 -> 10244 bytes tests/datasets/test_splits.py | 231 ++++++++++++++++++++++------------ 4 files changed, 151 insertions(+), 80 deletions(-) create mode 100644 .DS_Store create mode 100644 tests/.DS_Store create mode 100644 tests/data/.DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..85b0a68d3f8560363a6f44985ea0a211dacc1b69 GIT binary patch literal 8196 zcmeI1TTc@~6vxj%ku3#8qmdVrO?*WJG+yFMxtWkC7z!Gr2D@}AEA1|}+ae*DKI?bz zEBNXs@niVzlm5?{rJ3EfJ{kiv&CHqIJ^#7<=CaH}L~5y3nJ1biA_tM}Q~^n!z|T4M zgrVe)3#*_f@@bbMYSSzo_jubImVhN-30MM_fFxHZR8d=-S^}0pF98>~#&KTfb9*Iqey>w)trqQiwa9}HFRMkw z+20b7Bg!YbHgpXtQ?{7TR@` z0TRhoNPD=d!si&}^|EJ@W@!^r+OZ~rbYMG#t*N{MCAXrC@s7FOun}6+98|Yr!e$BM z7z-b)xYu=*)1U+0nm(-4x@B6UEs)uQ-Jo@-cG2&>v`%=Nx~sGd?-g1_FE{83Qodft zoo(c{P)eOD=`xtUnl$L=ggGc@{J=Him(WQ2F~yYMWg*bUm^DBzL<^WmEDU=_j|ryK z1$ehXu7(-PGu$7b!OL&Q2Jz08l|TgnPre(rxLBE z84$BFy>n=PggG0f{e+e2Hz_NH=YB-*(fVEC6yvHtDN~!LO|)haWBm-a&qyn*=v1r5 zQIAg_K$geG|0^f6rULgEecwinJWF5ot4)agG>GM~ZhZ8og+A)pTusvuYNlcg^=#8~ z{5DWi-`Oh!P0fhRktWj2Xv@qR;OQMT50l;dyLVU%n(^G+4>dHL8@V*I-(IZu{~t~`h}{4H literal 0 HcmV?d00001 diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b68d86b5d60aa1300a78c447d6dd205f3e03a50c GIT binary patch literal 8196 zcmeHM&u`O66n?LS)*)N!Lc0_Yl0_~^RBQ`~U4&MIz-p-wYPYFMD}H31WRhBSXCm8a zDOIJM;lv-n-uJZEmG}qt!j&81Pk<{Y_RS9$ns?xe8YQ9k9ysQId$c&7UMdAYlkU*V27oaW zXP@W3p0%x2z$$Pc6_ED_kE*m4WBb~wTL(2N1t43XVO6M4d4S@0##W5&YpbZ_Q)dro zsAy6Q)O4EnB%Ex;*uJ)!4yfsX=FDg^6jWy?o~Ph|6>V*66|f2nDj;+BlFHcw;GgC1 z+jF64HbT)Lgxpj&s9yM}3iVP}7q+1ZKGii~(oc@{PmWa}q@XjkKgZ#j0bGl`CgA6A z&-TMLw&Jdd_TNS~4B}*F<~0r-DjXgjaYmey^RTuiwrYu=^x~=?-Iqs8LWGU9?>CkM zwpDj0ZVI0G0cVjS1T9982loQr5?fW#b6YkI;dI>ov>wX-uj^UJh%arWokX>aH1 z#hGdE@}(=gyG3WR{L{6?`g*v*#Z!u%3YuFOjOydPR<5l8X5B$^)2s(yrR#a?y6K9F zM@m189vySXj~ySM7@wT{@x;{R$y29K)!pL_zq#ItWTZDmD57fEWA{5z(5iL)X23fu z345H>f+wu*9?L-&S?kfKp{4CEVJ1EeD0ja{jB({}zTJ*lvXbeq*4?qV8Hq4l5OI(M zg0DV^N$0tQZOrkupfhfg#`9e{6)&?WCOg1cS@omH=N(qvei-@<7S8jkND{HJ5;O(Z zk&h|2bhr>P>9TzNy_T_Dxs!Iz>5dC!-2o_9Vo&v*iFC+VT;hVWV!5)Q2a5DKI!|Bu zyRZ!F@C2U23-}ve!hi4v-oXbPK^M>BdAy9j;&r@@f8ZiEaSa2Ew1c7aBFis!reI75 zdlz}L7V%K<4vk5CDZTixL7SunJqT%2#kno;fsoma=Xl&v5pm@$NWMDE-Fi@N zj4ab3PKsyL%SDD#p)CH0!E-uWo7NstlsS%4Da#=jwpazeX9W)F#Z8sS{ zyEd!9fmZ+vw`#YlBt!a}C9jCgwLPkTQk6yG`r0ZAH7Xq^tI~0@dw&>G-y^8xRE+Iw UOCwai{)+(n{8z%;=0Gd(KaM1H`Tzg` literal 0 HcmV?d00001 diff --git a/tests/data/.DS_Store b/tests/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..21f1469a02093ca4271ac8818209732394a3674c GIT binary patch literal 10244 zcmeHN&u`;I6n<{g*lD&)wJZoBfg&Lx4%L!$7g&U}Dy4hiz-5czz>lmGr>(WF9mUCp zQWfj$+0wF0#QwF0#QwF3VK1(36OL~Ud1 zYpp=7K&`-10nQH{oU{j(y*5^-4qS8#0NcQ0J#deGfO0yOJ*e!pu}XzcvwH|CYEW%4 z0z=394&kr|mAy9B&?PW*396EVYKIbN>ELw~F2RAZ^|e-@R$x^D?%hL)C6VX2TKo56 zfHxdWALG3Dw;s3?c@?-*8Ax9;Nx>T7ZY~9W61l*Af+x9@csG>`^DSWM>sQZL3%*L= zMlyjFM3?)Kq4^0hi^d0TR5n6Mqi~&l@w1tV* zf$=hNeW=2=j}SgifZmrb&V3incnoe^+J^9$;D^eNrS#B25QFYZ;FA1rN} z&@tAw_WY40RK~{M0=9r>wseZR#HuJpnpbyfWAhs5X^u&MbWsxO23a(8JiWV_o=o3ZE(v_W&X63X#<^8K~*zG6r^mo$D&3yYfD*-gV;rC z#MsF8YSgnQJQg*wS<#ZEy(@4tm$cQlE`)Eq%73 zdI~&s;kaiMdHxwe4yOqYvj9iDr$6;^u#L*+*sTXg(t6xrUXBNu&al4?##vfUyWKy6 zjm^eZGmOHh9es0rR?Ls5@pM)m#*?r3+vi1*4HozD;EQxL@9*CIq{ydnnvW(*NRtsD zUw@hANijbxW_eO7_t?CMT2ZUN``YDYulsPXv)}vRYOizo;eL0o^WgsbS68j*mABsg z=t=)scAghMVlKoH1jfovWful*!)n)PIevaT$+IG7b|LMj`UD`v++XKZFyX5tf^%rnduD>FXh#PPEO3^QAfaXkYgL!WAv@q1OH=65S< z8eo88Vjn+P zna2^|4!!{#*BOP|@&NifxKMF9GQ&Q!_oAHIJLp!0HY*>_HS3i9#ZtBjMgiaJGA5O= zu@O*=P5d3}rl@rAFKI-wpBQai7tYZGT=nD}GhWN7kXzae%yXb8TA~+ss+5fnF&^0p zuGJ@Nm9o(-#igi~8NH~JTOYKJ@HpVnCl&6Q@#T> zRc*Y!;(UK2_P$#Ks^MT;@;&sRb{u~??py|A$6+jH;@Z{<)Cz1G2|8`q$2b4|zZR-i zpjO}oRv>5`9UmQ{fs2ic-&THW@8EojlLv9v#wrCD{TvUgpX2chKgYj=lh`gT58NMA g_S#rH!SNsbF`(Z6>wjt4Uo4$HV7>o;!T10F07e*Rc>n+a literal 0 HcmV?d00001 diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index eef21d2f639..7d49d33ecac 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -2,9 +2,11 @@ # Licensed under the MIT License. from math import floor, isclose +from typing import Dict, List, Sequence, Tuple, Union import pytest import torch +from rasterio.crs import CRS from torch.utils.data import TensorDataset from torchgeo.datasets import GeoDataset @@ -18,7 +20,24 @@ ) from torchgeo.datasets.utils import BoundingBox -from .test_geo import CustomGeoDataset + +class CustomGeoDataset(GeoDataset): + def __init__( + self, + items: List[Tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), "")], + crs: CRS = CRS.from_epsg(3005), + res: float = 1, + ) -> None: + super().__init__() + for box, content in items: + self.index.insert(0, tuple(box), content) + self._crs = crs + self.res = res + + def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: + hits = self.index.intersection(tuple(query), objects=True) + hit = next(iter(hits)) + return {"content": hit.object} def test_random_nongeo_split() -> None: @@ -39,46 +58,58 @@ def test_random_nongeo_split() -> None: assert len(test_ds) == floor(num_samples / 3) -def test_random_bbox_assignment() -> None: - ds = ( - CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) +@pytest.mark.parametrize( + "lengths,expected_lengths", + [ + # List of lengths + ([2, 1, 1], [2, 1, 1]), + # List of fractions (with remainder) + ([1 / 3, 1 / 3, 1 / 3], [2, 1, 1]), + ], +) +def test_random_bbox_assignment( + lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int] +) -> None: + ds = CustomGeoDataset( + [ + (BoundingBox(0, 1, 0, 1, 0, 0), "a"), + (BoundingBox(1, 2, 0, 1, 0, 0), "b"), + (BoundingBox(2, 3, 0, 1, 0, 0), "c"), + (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + ] ) - # Test list of lengths - train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths=[2, 1, 1]) - assert len(train_ds) == 2 - assert len(val_ds) == 1 - assert len(test_ds) == 1 - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 - assert (train_ds | val_ds | test_ds).bounds == ds.bounds + train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths) - # Test list of fractions (with remainder) - train_ds, val_ds, test_ds = random_bbox_assignment( - ds, lengths=[1 / 3, 1 / 3, 1 / 3] - ) - assert len(train_ds) == floor(len(ds) / 3) + 1 - assert len(val_ds) == floor(len(ds) / 3) - assert len(test_ds) == floor(len(ds) / 3) + # Check datasets lengths + assert len(train_ds) == expected_lengths[0] + assert len(val_ds) == expected_lengths[1] + assert len(test_ds) == expected_lengths[2] + + # No overlap assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 + + # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test invalid input lengths + # Test __get_item__ + x = train_ds[train_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["content"], str) + + +def test_random_bbox_assignment_invalid_inputs() -> None: with pytest.raises( ValueError, match="Sum of input lengths must equal 1 or the length of dataset's index.", ): - random_bbox_assignment(ds, lengths=[2, 2, 1]) + random_bbox_assignment(CustomGeoDataset(), lengths=[2, 2, 1]) with pytest.raises( ValueError, match="All items in input lengths must be greater than 0." ): - random_bbox_assignment(ds, lengths=[1 / 2, 3 / 4, -1 / 4]) + random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4]) def _get_total_area(dataset: GeoDataset) -> float: @@ -91,16 +122,17 @@ def _get_total_area(dataset: GeoDataset) -> float: def test_random_bbox_splitting() -> None: - ds = ( - CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + ds = CustomGeoDataset( + [ + (BoundingBox(0, 1, 0, 1, 0, 0), "a"), + (BoundingBox(1, 2, 0, 1, 0, 0), "b"), + (BoundingBox(2, 3, 0, 1, 0, 0), "c"), + (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + ] ) ds_area = _get_total_area(ds) - # Test list of fractions train_ds, val_ds, test_ds = random_bbox_splitting( ds, fractions=[1 / 2, 1 / 4, 1 / 4] ) @@ -108,14 +140,24 @@ def test_random_bbox_splitting() -> None: val_ds_area = _get_total_area(val_ds) test_ds_area = _get_total_area(test_ds) + # Check datasets areas assert train_ds_area == ds_area / 2 assert val_ds_area == ds_area / 4 assert test_ds_area == ds_area / 4 + + # No overlap assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 - assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area) + + # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds + assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area) + + # Test __get_item__ + x = train_ds[train_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["content"], str) # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): @@ -127,22 +169,35 @@ def test_random_bbox_splitting() -> None: def test_random_grid_cell_assignment() -> None: - ds = CustomGeoDataset(BoundingBox(0, 12, 0, 12, 0, 0)) | CustomGeoDataset( - BoundingBox(12, 24, 0, 12, 0, 0) + ds = CustomGeoDataset( + [ + (BoundingBox(0, 12, 0, 12, 0, 0), "a"), + (BoundingBox(12, 24, 0, 12, 0, 0), "b"), + ] ) train_ds, val_ds, test_ds = random_grid_cell_assignment( ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5 ) + # Check datasets lengths assert len(train_ds) == 1 / 2 * 2 * 5**2 + 1 assert len(val_ds) == floor(1 / 4 * 2 * 5**2) assert len(test_ds) == floor(1 / 4 * 2 * 5**2) + + # No overlap assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 - assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + + # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds + assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + + # Test __get_item__ + x = train_ds[train_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["content"], str) # Test invalid input fractions with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): @@ -156,11 +211,13 @@ def test_random_grid_cell_assignment() -> None: def test_roi_split() -> None: - ds = ( - CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(2, 3, 0, 1, 0, 0)) - | CustomGeoDataset(BoundingBox(3, 4, 0, 1, 0, 0)) + ds = CustomGeoDataset( + [ + (BoundingBox(0, 1, 0, 1, 0, 0), "a"), + (BoundingBox(1, 2, 0, 1, 0, 0), "b"), + (BoundingBox(2, 3, 0, 1, 0, 0), "c"), + (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + ] ) train_ds, val_ds, test_ds = roi_split( @@ -171,13 +228,25 @@ def test_roi_split() -> None: BoundingBox(3.5, 4, 0, 1, 0, 0), ], ) + + # Check datasets lengths assert len(train_ds) == 2 assert len(val_ds) == 2 assert len(test_ds) == 1 + + # No overlap assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 + + # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds + assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + + # Test __get_item__ + x = train_ds[train_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["content"], str) # Test invalid input rois with pytest.raises(ValueError, match="ROIs in input rois can't overlap."): @@ -186,80 +255,82 @@ def test_roi_split() -> None: ) -def test_time_series_split() -> None: - ds = ( - CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 10)) - | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 10, 20)) - | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 20, 30)) - | CustomGeoDataset(BoundingBox(0, 1, 0, 1, 30, 40)) +@pytest.mark.parametrize( + "lengths,expected_lengths", + [ + # List of timestamps + ([(0, 20), (20, 35), (35, 40)], [2, 2, 1]), + # List of lengths + ([20, 15, 5], [2, 2, 1]), + # List of fractions (with remainder) + ([1 / 2, 3 / 8, 1 / 8], [2, 2, 1]), + ], +) +def test_time_series_split( + lengths: Sequence[Union[Tuple[int, int], int, float]], + expected_lengths: Sequence[int], +) -> None: + ds = CustomGeoDataset( + [ + (BoundingBox(0, 1, 0, 1, 0, 10), "a"), + (BoundingBox(0, 1, 0, 1, 10, 20), "b"), + (BoundingBox(0, 1, 0, 1, 20, 30), "c"), + (BoundingBox(0, 1, 0, 1, 30, 40), "d"), + ] ) - # Test lengths input using timestamps - train_ds, val_ds, test_ds = time_series_split( - ds, lengths=[(0, 20), (20, 35), (35, 40)] - ) + train_ds, val_ds, test_ds = time_series_split(ds, lengths) - assert len(train_ds) == 2 - assert len(val_ds) == 2 - assert len(test_ds) == 1 - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 - assert (train_ds | val_ds | test_ds).bounds == ds.bounds + # Check datasets lengths + assert len(train_ds) == expected_lengths[0] + assert len(val_ds) == expected_lengths[1] + assert len(test_ds) == expected_lengths[2] - # Test lengths input using lengths - train_ds, val_ds, test_ds = time_series_split(ds, lengths=[20, 15, 5]) - - assert len(train_ds) == 2 - assert len(val_ds) == 2 - assert len(test_ds) == 1 + # No overlap assert len(train_ds & val_ds) == 0 assert len(val_ds & test_ds) == 0 assert len(test_ds & train_ds) == 0 + + # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test lengths input using fractions - train_ds, val_ds, test_ds = time_series_split(ds, lengths=[1 / 2, 3 / 8, 1 / 8]) + # Test __get_item__ + x = train_ds[train_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["content"], str) - assert len(train_ds) == 2 - assert len(val_ds) == 2 - assert len(test_ds) == 1 - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 - assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test invalid input lengths +def test_time_series_split_invalid_input() -> None: with pytest.raises( ValueError, match="Pairs of timestamps in lengths must have end greater than start.", ): - time_series_split(ds, lengths=[(0, 20), (35, 20), (35, 40)]) + time_series_split(CustomGeoDataset(), lengths=[(0, 20), (35, 20), (35, 40)]) with pytest.raises( ValueError, match="Pairs of timestamps in lengths must cover dataset's time bounds.", ): - time_series_split(ds, lengths=[(0, 20), (20, 35)]) + time_series_split(CustomGeoDataset(), lengths=[(0, 20), (20, 35)]) with pytest.raises( ValueError, match="Pairs of timestamps in lengths can't be out of dataset's time bounds.", ): - time_series_split(ds, lengths=[(0, 20), (20, 45)]) + time_series_split(CustomGeoDataset(), lengths=[(0, 20), (20, 45)]) with pytest.raises( ValueError, match="Pairs of timestamps in lengths can't overlap." ): - time_series_split(ds, lengths=[(0, 10), (10, 20), (15, 40)]) + time_series_split(CustomGeoDataset(), lengths=[(0, 10), (10, 20), (15, 40)]) with pytest.raises( ValueError, match="Sum of input lengths must equal 1 or the dataset's time length.", ): - time_series_split(ds, lengths=[1 / 2, 1 / 2, 1 / 2]) + time_series_split(CustomGeoDataset(), lengths=[1 / 2, 1 / 2, 1 / 2]) with pytest.raises( ValueError, match="All items in input lengths must be greater than 0." ): - time_series_split(ds, lengths=[20, 25, -5]) + time_series_split(CustomGeoDataset(), lengths=[20, 25, -5]) From 6cda093b6aee4442acd8216def567fe52fca5f59 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 15:17:26 -0300 Subject: [PATCH 29/49] small type fix --- tests/datasets/test_splits.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 7d49d33ecac..1dfdacd3219 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -34,7 +34,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> Dict[str, str]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) return {"content": hit.object} From 5d87c71a336ea6fd992fcae03a001d06b0d453c9 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 15:25:56 -0300 Subject: [PATCH 30/49] fix again --- tests/datasets/test_splits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 1dfdacd3219..33100e6448b 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from math import floor, isclose -from typing import Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import pytest import torch @@ -34,7 +34,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, str]: + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) return {"content": hit.object} From 7adca072f6f525fecdf0ef884c6237a5cac40f9a Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 4 Jan 2023 15:53:39 -0300 Subject: [PATCH 31/49] rm .DS_Store --- .DS_Store | Bin 8196 -> 0 bytes tests/.DS_Store | Bin 8196 -> 0 bytes tests/data/.DS_Store | Bin 10244 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 tests/.DS_Store delete mode 100644 tests/data/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 85b0a68d3f8560363a6f44985ea0a211dacc1b69..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeI1TTc@~6vxj%ku3#8qmdVrO?*WJG+yFMxtWkC7z!Gr2D@}AEA1|}+ae*DKI?bz zEBNXs@niVzlm5?{rJ3EfJ{kiv&CHqIJ^#7<=CaH}L~5y3nJ1biA_tM}Q~^n!z|T4M zgrVe)3#*_f@@bbMYSSzo_jubImVhN-30MM_fFxHZR8d=-S^}0pF98>~#&KTfb9*Iqey>w)trqQiwa9}HFRMkw z+20b7Bg!YbHgpXtQ?{7TR@` z0TRhoNPD=d!si&}^|EJ@W@!^r+OZ~rbYMG#t*N{MCAXrC@s7FOun}6+98|Yr!e$BM z7z-b)xYu=*)1U+0nm(-4x@B6UEs)uQ-Jo@-cG2&>v`%=Nx~sGd?-g1_FE{83Qodft zoo(c{P)eOD=`xtUnl$L=ggGc@{J=Him(WQ2F~yYMWg*bUm^DBzL<^WmEDU=_j|ryK z1$ehXu7(-PGu$7b!OL&Q2Jz08l|TgnPre(rxLBE z84$BFy>n=PggG0f{e+e2Hz_NH=YB-*(fVEC6yvHtDN~!LO|)haWBm-a&qyn*=v1r5 zQIAg_K$geG|0^f6rULgEecwinJWF5ot4)agG>GM~ZhZ8og+A)pTusvuYNlcg^=#8~ z{5DWi-`Oh!P0fhRktWj2Xv@qR;OQMT50l;dyLVU%n(^G+4>dHL8@V*I-(IZu{~t~`h}{4H diff --git a/tests/.DS_Store b/tests/.DS_Store deleted file mode 100644 index b68d86b5d60aa1300a78c447d6dd205f3e03a50c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHM&u`O66n?LS)*)N!Lc0_Yl0_~^RBQ`~U4&MIz-p-wYPYFMD}H31WRhBSXCm8a zDOIJM;lv-n-uJZEmG}qt!j&81Pk<{Y_RS9$ns?xe8YQ9k9ysQId$c&7UMdAYlkU*V27oaW zXP@W3p0%x2z$$Pc6_ED_kE*m4WBb~wTL(2N1t43XVO6M4d4S@0##W5&YpbZ_Q)dro zsAy6Q)O4EnB%Ex;*uJ)!4yfsX=FDg^6jWy?o~Ph|6>V*66|f2nDj;+BlFHcw;GgC1 z+jF64HbT)Lgxpj&s9yM}3iVP}7q+1ZKGii~(oc@{PmWa}q@XjkKgZ#j0bGl`CgA6A z&-TMLw&Jdd_TNS~4B}*F<~0r-DjXgjaYmey^RTuiwrYu=^x~=?-Iqs8LWGU9?>CkM zwpDj0ZVI0G0cVjS1T9982loQr5?fW#b6YkI;dI>ov>wX-uj^UJh%arWokX>aH1 z#hGdE@}(=gyG3WR{L{6?`g*v*#Z!u%3YuFOjOydPR<5l8X5B$^)2s(yrR#a?y6K9F zM@m189vySXj~ySM7@wT{@x;{R$y29K)!pL_zq#ItWTZDmD57fEWA{5z(5iL)X23fu z345H>f+wu*9?L-&S?kfKp{4CEVJ1EeD0ja{jB({}zTJ*lvXbeq*4?qV8Hq4l5OI(M zg0DV^N$0tQZOrkupfhfg#`9e{6)&?WCOg1cS@omH=N(qvei-@<7S8jkND{HJ5;O(Z zk&h|2bhr>P>9TzNy_T_Dxs!Iz>5dC!-2o_9Vo&v*iFC+VT;hVWV!5)Q2a5DKI!|Bu zyRZ!F@C2U23-}ve!hi4v-oXbPK^M>BdAy9j;&r@@f8ZiEaSa2Ew1c7aBFis!reI75 zdlz}L7V%K<4vk5CDZTixL7SunJqT%2#kno;fsoma=Xl&v5pm@$NWMDE-Fi@N zj4ab3PKsyL%SDD#p)CH0!E-uWo7NstlsS%4Da#=jwpazeX9W)F#Z8sS{ zyEd!9fmZ+vw`#YlBt!a}C9jCgwLPkTQk6yG`r0ZAH7Xq^tI~0@dw&>G-y^8xRE+Iw UOCwai{)+(n{8z%;=0Gd(KaM1H`Tzg` diff --git a/tests/data/.DS_Store b/tests/data/.DS_Store deleted file mode 100644 index 21f1469a02093ca4271ac8818209732394a3674c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10244 zcmeHN&u`;I6n<{g*lD&)wJZoBfg&Lx4%L!$7g&U}Dy4hiz-5czz>lmGr>(WF9mUCp zQWfj$+0wF0#QwF0#QwF3VK1(36OL~Ud1 zYpp=7K&`-10nQH{oU{j(y*5^-4qS8#0NcQ0J#deGfO0yOJ*e!pu}XzcvwH|CYEW%4 z0z=394&kr|mAy9B&?PW*396EVYKIbN>ELw~F2RAZ^|e-@R$x^D?%hL)C6VX2TKo56 zfHxdWALG3Dw;s3?c@?-*8Ax9;Nx>T7ZY~9W61l*Af+x9@csG>`^DSWM>sQZL3%*L= zMlyjFM3?)Kq4^0hi^d0TR5n6Mqi~&l@w1tV* zf$=hNeW=2=j}SgifZmrb&V3incnoe^+J^9$;D^eNrS#B25QFYZ;FA1rN} z&@tAw_WY40RK~{M0=9r>wseZR#HuJpnpbyfWAhs5X^u&MbWsxO23a(8JiWV_o=o3ZE(v_W&X63X#<^8K~*zG6r^mo$D&3yYfD*-gV;rC z#MsF8YSgnQJQg*wS<#ZEy(@4tm$cQlE`)Eq%73 zdI~&s;kaiMdHxwe4yOqYvj9iDr$6;^u#L*+*sTXg(t6xrUXBNu&al4?##vfUyWKy6 zjm^eZGmOHh9es0rR?Ls5@pM)m#*?r3+vi1*4HozD;EQxL@9*CIq{ydnnvW(*NRtsD zUw@hANijbxW_eO7_t?CMT2ZUN``YDYulsPXv)}vRYOizo;eL0o^WgsbS68j*mABsg z=t=)scAghMVlKoH1jfovWful*!)n)PIevaT$+IG7b|LMj`UD`v++XKZFyX5tf^%rnduD>FXh#PPEO3^QAfaXkYgL!WAv@q1OH=65S< z8eo88Vjn+P zna2^|4!!{#*BOP|@&NifxKMF9GQ&Q!_oAHIJLp!0HY*>_HS3i9#ZtBjMgiaJGA5O= zu@O*=P5d3}rl@rAFKI-wpBQai7tYZGT=nD}GhWN7kXzae%yXb8TA~+ss+5fnF&^0p zuGJ@Nm9o(-#igi~8NH~JTOYKJ@HpVnCl&6Q@#T> zRc*Y!;(UK2_P$#Ks^MT;@;&sRb{u~??py|A$6+jH;@Z{<)Cz1G2|8`q$2b4|zZR-i zpjO}oRv>5`9UmQ{fs2ic-&THW@8EojlLv9v#wrCD{TvUgpX2chKgYj=lh`gT58NMA g_S#rH!SNsbF`(Z6>wjt4Uo4$HV7>o;!T10F07e*Rc>n+a From e19c7adcc7e32069dc940b9f19c29b033e6e6b61 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 14:41:20 -0300 Subject: [PATCH 32/49] fix typo Co-authored-by: Adam J. Stewart --- torchgeo/datasets/splits.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index ae9dddb29bc..9afe6d45230 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -296,7 +296,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas def time_series_split( dataset: GeoDataset, lengths: Sequence[Union[int, float, Tuple[int, int]]] ) -> List[GeoDataset]: - """Split a GeoDataset on it's time dimension to create non-overlapping GeoDatasets. + """Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets. Args: dataset: dataset to be split From cab469f01a61dade58f1d798723434fdb4c6be19 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 14:43:52 -0300 Subject: [PATCH 33/49] bump version added --- torchgeo/datasets/splits.py | 16 ++++++++-------- torchgeo/datasets/utils.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 9afe6d45230..e1802046f79 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -35,7 +35,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in Returns: List of lengths. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ lengths = [floor(frac * total) for frac in fractions] remainder = int(total - sum(lengths)) @@ -56,7 +56,7 @@ def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: Returns: A new GeoDataset. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ new_dataset = deepcopy(dataset) new_dataset.index = index @@ -78,7 +78,7 @@ def random_nongeo_split( Returns: A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ if sum(lengths) == 1: lengths = _fractions_to_lengths(lengths, len(dataset)) @@ -100,7 +100,7 @@ def random_bbox_assignment( Returns A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ if not (sum(lengths) == 1 or sum(lengths) == len(dataset)): raise ValueError( @@ -150,7 +150,7 @@ def random_bbox_splitting( Returns A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ if sum(fractions) != 1: raise ValueError("Sum of input fractions must equal 1.") @@ -207,7 +207,7 @@ def random_grid_cell_assignment( Returns A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ if sum(fractions) != 1: raise ValueError("Sum of input fractions must equal 1.") @@ -272,7 +272,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas Returns A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in rois @@ -306,7 +306,7 @@ def time_series_split( Returns A list of the subset datasets. - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ minx, maxx, miny, maxy, mint, maxt = dataset.bounds diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 9610910e78b..a9dc6627264 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -408,7 +408,7 @@ def split( Returns: A tuple with the resulting BoundingBoxes - .. versionadded:: 0.4 + .. versionadded:: 0.5 """ if not (0.0 < proportion < 1.0): raise ValueError("Input proportion must be between 0 and 1.") From 771f88599350de357f78a4d9d0b563f8b6567648 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 14:48:26 -0300 Subject: [PATCH 34/49] add to __init__ --- torchgeo/datasets/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 01d0d70176e..724fafad675 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -92,6 +92,13 @@ SpaceNet6, SpaceNet7, ) +from .splits import ( + random_bbox_assignment, + random_bbox_splitting, + random_grid_cell_assignment, + roi_split, + time_series_split, +) from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -207,4 +214,10 @@ "merge_samples", "stack_samples", "unbind_samples", + # Splits + "random_bbox_assignment", + "random_bbox_splitting", + "random_grid_cell_assignment", + "roi_split", + "time_series_split", ) From 41f730861f5a50cc517a98f60fd7901bd161d5b6 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 14:52:49 -0300 Subject: [PATCH 35/49] add to datasets.rst --- docs/api/datasets.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index af761f33d90..87549c9719f 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -381,3 +381,12 @@ Collation Functions .. autofunction:: concat_samples .. autofunction:: merge_samples .. autofunction:: unbind_samples + +Splitting Functions +^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: random_bbox_assignment +.. autofunction:: random_bbox_splitting +.. autofunction:: random_grid_cell_assignment +.. autofunction:: roi_split +.. autofunction:: time_series_split From 71f75031f4fb9d8019437209a6aa1d25f040c883 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 15:11:35 -0300 Subject: [PATCH 36/49] use accumulate from itertools --- torchgeo/datasets/splits.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index e1802046f79..d9d693490ba 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -4,12 +4,12 @@ """Dataset splitting utilities.""" from copy import deepcopy +from itertools import accumulate from math import floor from typing import Any, List, Optional, Sequence, Tuple, Union from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm -from torch._utils import _accumulate from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import GeoDataset, NonGeoDataset @@ -326,10 +326,8 @@ def time_series_split( lengths = [totalt * f for f in lengths] # type: ignore[operator] lengths = [ - (mint + offset - length, mint + offset) - for offset, length in zip( - _accumulate(lengths), lengths # type: ignore[no-untyped-call] - ) + (mint + offset - length, mint + offset) # type: ignore[operator, misc] + for offset, length in zip(accumulate(lengths), lengths) ] new_indexes = [ From 91445b7f1ab177b59bec42bb98b2702c1f914221 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 15:15:07 -0300 Subject: [PATCH 37/49] clarify grid_size --- torchgeo/datasets/splits.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index d9d693490ba..cafe70d851d 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -201,7 +201,7 @@ def random_grid_cell_assignment( Args: dataset: dataset to be split fractions: fractions of splits to be produced - grid_size: (optional) size of the grid + grid_size: (optional) number of rows and columns for the grid generator: (optional) generator used for the random permutation Returns From 4fcd141bca0019812303caeb0856447905ad1e88 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 15:58:54 -0300 Subject: [PATCH 38/49] remove random_nongeo_split --- tests/datasets/test_splits.py | 21 --------------------- torchgeo/datasets/splits.py | 23 ----------------------- 2 files changed, 44 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 33100e6448b..e8d848d2456 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -5,16 +5,13 @@ from typing import Any, Dict, List, Sequence, Tuple, Union import pytest -import torch from rasterio.crs import CRS -from torch.utils.data import TensorDataset from torchgeo.datasets import GeoDataset from torchgeo.datasets.splits import ( random_bbox_assignment, random_bbox_splitting, random_grid_cell_assignment, - random_nongeo_split, roi_split, time_series_split, ) @@ -40,24 +37,6 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: return {"content": hit.object} -def test_random_nongeo_split() -> None: - num_samples = 26 - x = torch.ones(num_samples, 5) - y = torch.randint(low=0, high=2, size=(num_samples,)) - ds = TensorDataset(x, y) - - # Test only train/val set split - train_ds, val_ds = random_nongeo_split(ds, lengths=[1 / 2, 1 / 2]) - assert len(train_ds) == round(num_samples / 2) - assert len(val_ds) == round(num_samples / 2) - - # Test train/val/test set split with remainder - train_ds, val_ds, test_ds = random_nongeo_split(ds, lengths=[1 / 3, 1 / 3, 1 / 3]) - assert len(train_ds) == floor(num_samples / 3) + 1 - assert len(val_ds) == floor(num_samples / 3) + 1 - assert len(test_ds) == floor(num_samples / 3) - - @pytest.mark.parametrize( "lengths,expected_lengths", [ diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index cafe70d851d..7a785e3a9de 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -16,7 +16,6 @@ from .utils import BoundingBox __all__ = ( - "random_nongeo_split", "random_bbox_assignment", "random_bbox_splitting", "random_grid_cell_assignment", @@ -63,28 +62,6 @@ def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: return new_dataset -def random_nongeo_split( - dataset: Union[TensorDataset, NonGeoDataset], - lengths: Sequence[Union[int, float]], - generator: Optional[Generator] = default_generator, -) -> List[Subset[Any]]: - """Randomly split a NonGeoDataset into non-overlapping new NonGeoDatasets. - - Args: - dataset: dataset to be split - lengths: lengths or fractions of splits to be produced - generator: (optional) generator used for the random permutation - - Returns: - A list of the subset datasets. - - .. versionadded:: 0.5 - """ - if sum(lengths) == 1: - lengths = _fractions_to_lengths(lengths, len(dataset)) - return random_split(dataset, lengths, generator) - - def random_bbox_assignment( dataset: GeoDataset, lengths: Sequence[Union[int, float]], From fa513b5ca5c7e08d54ef070df1024b967b3b1ba5 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 16:09:20 -0300 Subject: [PATCH 39/49] remove _create_geodataset_like --- torchgeo/datasets/splits.py | 62 ++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 7a785e3a9de..4b0c54124c2 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -6,13 +6,12 @@ from copy import deepcopy from itertools import accumulate from math import floor -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm -from torch.utils.data import Subset, TensorDataset, random_split -from ..datasets import GeoDataset, NonGeoDataset +from ..datasets import GeoDataset from .utils import BoundingBox __all__ = ( @@ -45,23 +44,6 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in return lengths -def _create_geodataset_like(dataset: GeoDataset, index: Index) -> GeoDataset: - """Utility to create a new GeoDataset from an existing one with a different index. - - Args: - dataset: dataset to copy - index: new index - - Returns: - A new GeoDataset. - - .. versionadded:: 0.5 - """ - new_dataset = deepcopy(dataset) - new_dataset.index = index - return new_dataset - - def random_bbox_assignment( dataset: GeoDataset, lengths: Sequence[Union[int, float]], @@ -106,7 +88,13 @@ def random_bbox_assignment( hit = hits.pop() new_indexes[i].insert(j, hit.bounds, hit.object) - return [_create_geodataset_like(dataset, index) for index in new_indexes] + new_datasets = [] + for index in new_indexes: + ds = deepcopy(dataset) + ds.index = index + new_datasets.append(ds) + + return new_datasets def random_bbox_splitting( @@ -161,7 +149,13 @@ def random_bbox_splitting( fraction_left -= frac horizontal = not horizontal - return [_create_geodataset_like(dataset, index) for index in new_indexes] + new_datasets = [] + for index in new_indexes: + ds = deepcopy(dataset) + ds.index = index + new_datasets.append(ds) + + return new_datasets def random_grid_cell_assignment( @@ -236,7 +230,13 @@ def random_grid_cell_assignment( cell = cells.pop() new_indexes[i].insert(j, cell[0], cell[1]) - return [_create_geodataset_like(dataset, index) for index in new_indexes] + new_datasets = [] + for index in new_indexes: + ds = deepcopy(dataset) + ds.index = index + new_datasets.append(ds) + + return new_datasets def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: @@ -267,7 +267,13 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas new_indexes[i].insert(j, tuple(new_box), hit.object) j += 1 - return [_create_geodataset_like(dataset, index) for index in new_indexes] + new_datasets = [] + for index in new_indexes: + ds = deepcopy(dataset) + ds.index = index + new_datasets.append(ds) + + return new_datasets def time_series_split( @@ -347,4 +353,10 @@ def time_series_split( "Pairs of timestamps in lengths must cover dataset's time bounds." ) - return [_create_geodataset_like(dataset, index) for index in new_indexes] + new_datasets = [] + for index in new_indexes: + ds = deepcopy(dataset) + ds.index = index + new_datasets.append(ds) + + return new_datasets From 3f46921c2c8f2d7486def0fec3e86764cf0fda4e Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Wed, 15 Feb 2023 16:16:05 -0300 Subject: [PATCH 40/49] black reformatting --- tests/datasets/test_splits.py | 1 - torchgeo/datasets/splits.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index e8d848d2456..605a6b8649a 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -92,7 +92,6 @@ def test_random_bbox_assignment_invalid_inputs() -> None: def _get_total_area(dataset: GeoDataset) -> float: - total_area = 0.0 for hit in dataset.index.intersection(dataset.index.bounds, objects=True): total_area += BoundingBox(*hit.bounds).area diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 4b0c54124c2..df0d44b4c44 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -135,7 +135,6 @@ def random_bbox_splitting( horizontal, flip = randint(0, 2, (2,), generator=generator) for j, frac in enumerate(fractions): - if fraction_left == frac: new_box = box elif flip: @@ -296,7 +295,6 @@ def time_series_split( totalt = maxt - mint if not all(isinstance(x, tuple) for x in lengths): - if not (sum(lengths) == 1 or sum(lengths) == totalt): # type: ignore[arg-type] raise ValueError( "Sum of input lengths must equal 1 or the dataset's time length." @@ -319,7 +317,6 @@ def time_series_split( _totalt = 0 for i, (start, end) in enumerate(lengths): # type: ignore[misc] - if start >= end: raise ValueError( "Pairs of timestamps in lengths must have end greater than start." From 5d2f866f34e953950ee41c7e191d58150d29d9fd Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 10:16:29 -0300 Subject: [PATCH 41/49] Update tests/datasets/test_splits.py Co-authored-by: Adam J. Stewart --- tests/datasets/test_splits.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 605a6b8649a..6498a15f3e2 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -73,7 +73,7 @@ def test_random_bbox_assignment( # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - # Test __get_item__ + # Test __getitem__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) assert isinstance(x["content"], str) From 28c24fbb7fcc2419f831ed0678851a3f2303e138 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 10:16:15 -0300 Subject: [PATCH 42/49] change import --- tests/datasets/test_splits.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 6498a15f3e2..72130329080 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -7,15 +7,15 @@ import pytest from rasterio.crs import CRS -from torchgeo.datasets import GeoDataset -from torchgeo.datasets.splits import ( +from torchgeo.datasets import ( + BoundingBox, + GeoDataset, random_bbox_assignment, random_bbox_splitting, random_grid_cell_assignment, roi_split, time_series_split, ) -from torchgeo.datasets.utils import BoundingBox class CustomGeoDataset(GeoDataset): From 358a679d57f33cf5edf527006cbb88039fa7bf61 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 12:55:00 -0300 Subject: [PATCH 43/49] docstrings --- torchgeo/datasets/splits.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index df0d44b4c44..1abf29e9eab 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -26,6 +26,8 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[int]: """Utility to divide a number into a list of integers according to fractions. + Implementation based on :meth:`torch.utils.data.random_split`. + Args: fractions: list of fractions total: total to be divided @@ -51,6 +53,9 @@ def random_bbox_assignment( ) -> List[GeoDataset]: """Split a GeoDataset randomly assigning its index's BoundingBoxes. + This function will go through each BoundingBox in the GeoDataset's index and + randomly assign it to new GeoDatasets. + Args: dataset: dataset to be split lengths: lengths or fractions of splits to be produced @@ -104,13 +109,14 @@ def random_bbox_splitting( ) -> List[GeoDataset]: """Split a GeoDataset randomly splitting its index's BoundingBoxes. - This function will go through each BoundingBox in the GeoDataset's index and - split it in a random direction. + This function will go through each BoundingBox in the GeoDataset's index, + split it in a random direction and assign the resulting BoundingBoxes to + new GeoDatasets. Args: dataset: dataset to be split fractions: fractions of splits to be produced - generator: (optional) generator used for the random permutation + generator: generator used for the random permutation Returns A list of the subset datasets. @@ -171,8 +177,8 @@ def random_grid_cell_assignment( Args: dataset: dataset to be split fractions: fractions of splits to be produced - grid_size: (optional) number of rows and columns for the grid - generator: (optional) generator used for the random permutation + grid_size: number of rows and columns for the grid + generator: generator used for the random permutation Returns A list of the subset datasets. From 3610c1f708f3c6e2cf2f5e0f20a7dd92344b1754 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 15:02:11 -0300 Subject: [PATCH 44/49] undo intersection change --- tests/datasets/test_geo.py | 6 ------ tests/datasets/test_splits.py | 24 ++++++++++++------------ torchgeo/datasets/geo.py | 6 ++---- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 61ed6f4c780..a24dd3653f2 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -421,12 +421,6 @@ def test_no_overlap(self) -> None: ds = IntersectionDataset(ds1, ds2) assert len(ds) == 0 - def test_contiguous(self) -> None: - ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1)) - ds2 = CustomGeoDataset(BoundingBox(1, 2, 0, 1, 0, 1)) - ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 0 - def test_invalid_query(self, dataset: IntersectionDataset) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 72130329080..cba52a54737 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -66,9 +66,9 @@ def test_random_bbox_assignment( assert len(test_ds) == expected_lengths[2] # No overlap - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 + assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) + assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -124,9 +124,9 @@ def test_random_bbox_splitting() -> None: assert test_ds_area == ds_area / 4 # No overlap - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 + assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) + assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -164,9 +164,9 @@ def test_random_grid_cell_assignment() -> None: assert len(test_ds) == floor(1 / 4 * 2 * 5**2) # No overlap - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 + assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) + assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -213,9 +213,9 @@ def test_roi_split() -> None: assert len(test_ds) == 1 # No overlap - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 + assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) + assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index c085092e497..96ae41e20a3 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -852,10 +852,8 @@ def _merge_dataset_indices(self) -> None: for hit2 in ds2.index.intersection(hit1.bounds, objects=True): box1 = BoundingBox(*hit1.bounds) box2 = BoundingBox(*hit2.bounds) - new_box = box1 & box2 - if new_box.area > 0: - self.index.insert(i, tuple(new_box)) - i += 1 + self.index.insert(i, tuple(box1 & box2)) + i += 1 def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: """Retrieve image and metadata indexed by query. From d589ed18170b528ab27d91ce04414f2f7da4b954 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 15:33:48 -0300 Subject: [PATCH 45/49] use microsecond --- torchgeo/datasets/splits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 1abf29e9eab..0967966937f 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -338,8 +338,8 @@ def time_series_split( ): raise ValueError("Pairs of timestamps in lengths can't overlap.") - # remove one second from each BoundingBox's maxt to avoid overlapping - offset = 0 if i == len(lengths) - 1 else 1 + # remove one microsecond from each BoundingBox's maxt to avoid overlapping + offset = 0 if i == len(lengths) - 1 else 1e-6 roi = BoundingBox(minx, maxx, miny, maxy, start, end - offset) j = 0 for hit in dataset.index.intersection(tuple(roi), objects=True): From d6fc7781df7e1a2c3834da4dc10f8f2eee1c4c82 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 15:40:34 -0300 Subject: [PATCH 46/49] use isclose --- torchgeo/datasets/splits.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 0967966937f..01d8dc6750f 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -5,7 +5,7 @@ from copy import deepcopy from itertools import accumulate -from math import floor +from math import floor, isclose from typing import List, Optional, Sequence, Tuple, Union from rtree.index import Index, Property @@ -66,7 +66,7 @@ def random_bbox_assignment( .. versionadded:: 0.5 """ - if not (sum(lengths) == 1 or sum(lengths) == len(dataset)): + if not (isclose(sum(lengths), 1) or isclose(sum(lengths), len(dataset))): raise ValueError( "Sum of input lengths must equal 1 or the length of dataset's index." ) @@ -74,7 +74,7 @@ def random_bbox_assignment( if any(n <= 0 for n in lengths): raise ValueError("All items in input lengths must be greater than 0.") - if sum(lengths) == 1: + if isclose(sum(lengths), 1): lengths = _fractions_to_lengths(lengths, len(dataset)) hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) @@ -123,7 +123,7 @@ def random_bbox_splitting( .. versionadded:: 0.5 """ - if sum(fractions) != 1: + if not isclose(sum(fractions), 1): raise ValueError("Sum of input fractions must equal 1.") if any(n <= 0 for n in fractions): @@ -185,7 +185,7 @@ def random_grid_cell_assignment( .. versionadded:: 0.5 """ - if sum(fractions) != 1: + if not isclose(sum(fractions), 1): raise ValueError("Sum of input fractions must equal 1.") if any(n <= 0 for n in fractions): @@ -301,7 +301,7 @@ def time_series_split( totalt = maxt - mint if not all(isinstance(x, tuple) for x in lengths): - if not (sum(lengths) == 1 or sum(lengths) == totalt): # type: ignore[arg-type] + if not (isclose(sum(lengths), 1) or isclose(sum(lengths), totalt)): # type: ignore[arg-type] raise ValueError( "Sum of input lengths must equal 1 or the dataset's time length." ) @@ -309,7 +309,7 @@ def time_series_split( if any(n <= 0 for n in lengths): # type: ignore[operator] raise ValueError("All items in input lengths must be greater than 0.") - if sum(lengths) == 1: # type: ignore[arg-type] + if isclose(sum(lengths), 1): # type: ignore[arg-type] lengths = [totalt * f for f in lengths] # type: ignore[operator] lengths = [ @@ -351,7 +351,7 @@ def time_series_split( _totalt += end - start - if not _totalt == totalt: + if not isclose(_totalt, totalt): raise ValueError( "Pairs of timestamps in lengths must cover dataset's time bounds." ) From 9c9fd26580ee84019921a4db50f6717658d763f6 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Mon, 20 Feb 2023 16:04:32 -0300 Subject: [PATCH 47/49] black --- tests/datasets/test_splits.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index cba52a54737..c03cfbd103e 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -68,7 +68,9 @@ def test_random_bbox_assignment( # No overlap assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose( + _get_total_area(test_ds & train_ds), 0 + ) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -126,7 +128,9 @@ def test_random_bbox_splitting() -> None: # No overlap assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose( + _get_total_area(test_ds & train_ds), 0 + ) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -166,7 +170,9 @@ def test_random_grid_cell_assignment() -> None: # No overlap assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose( + _get_total_area(test_ds & train_ds), 0 + ) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -215,7 +221,9 @@ def test_roi_split() -> None: # No overlap assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose(_get_total_area(test_ds & train_ds), 0) + assert len(test_ds & train_ds) == 0 or isclose( + _get_total_area(test_ds & train_ds), 0 + ) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds From 15ffb2436bacd112dc1696bcd422b6a96f82e496 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 21 Feb 2023 11:02:24 -0300 Subject: [PATCH 48/49] fix typing --- torchgeo/datasets/splits.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 01d8dc6750f..7e512ce4db9 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -6,7 +6,7 @@ from copy import deepcopy from itertools import accumulate from math import floor, isclose -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union, cast from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm @@ -48,7 +48,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in def random_bbox_assignment( dataset: GeoDataset, - lengths: Sequence[Union[int, float]], + lengths: Sequence[float], generator: Optional[Generator] = default_generator, ) -> List[GeoDataset]: """Split a GeoDataset randomly assigning its index's BoundingBoxes. @@ -76,20 +76,18 @@ def random_bbox_assignment( if isclose(sum(lengths), 1): lengths = _fractions_to_lengths(lengths, len(dataset)) + lengths = cast(Sequence[int], lengths) hits = list(dataset.index.intersection(dataset.index.bounds, objects=True)) - hits = [ - hits[i] - for i in randperm(sum(lengths), generator=generator) # type: ignore[arg-type] - ] + hits = [hits[i] for i in randperm(sum(lengths), generator=generator)] new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths ] for i, length in enumerate(lengths): - for j in range(length): # type: ignore[arg-type] + for j in range(length): hit = hits.pop() new_indexes[i].insert(j, hit.bounds, hit.object) @@ -282,7 +280,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas def time_series_split( - dataset: GeoDataset, lengths: Sequence[Union[int, float, Tuple[int, int]]] + dataset: GeoDataset, lengths: Sequence[Union[float, Tuple[float, float]]] ) -> List[GeoDataset]: """Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets. @@ -301,28 +299,32 @@ def time_series_split( totalt = maxt - mint if not all(isinstance(x, tuple) for x in lengths): - if not (isclose(sum(lengths), 1) or isclose(sum(lengths), totalt)): # type: ignore[arg-type] + lengths = cast(Sequence[float], lengths) + + if not (isclose(sum(lengths), 1) or isclose(sum(lengths), totalt)): raise ValueError( "Sum of input lengths must equal 1 or the dataset's time length." ) - if any(n <= 0 for n in lengths): # type: ignore[operator] + if any(n <= 0 for n in lengths): raise ValueError("All items in input lengths must be greater than 0.") - if isclose(sum(lengths), 1): # type: ignore[arg-type] - lengths = [totalt * f for f in lengths] # type: ignore[operator] + if isclose(sum(lengths), 1): + lengths = [totalt * f for f in lengths] lengths = [ - (mint + offset - length, mint + offset) # type: ignore[operator, misc] + (mint + offset - length, mint + offset) # type: ignore[operator] for offset, length in zip(accumulate(lengths), lengths) ] + lengths = cast(Sequence[Tuple[float, float]], lengths) + new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths ] - _totalt = 0 - for i, (start, end) in enumerate(lengths): # type: ignore[misc] + _totalt = 0.0 + for i, (start, end) in enumerate(lengths): if start >= end: raise ValueError( "Pairs of timestamps in lengths must have end greater than start." @@ -333,9 +335,7 @@ def time_series_split( "Pairs of timestamps in lengths can't be out of dataset's time bounds." ) - if any( # type: ignore[misc] - start < x < end or start < y < end for x, y in lengths[i + 1 :] - ): + if any(start < x < end or start < y < end for x, y in lengths[i + 1 :]): raise ValueError("Pairs of timestamps in lengths can't overlap.") # remove one microsecond from each BoundingBox's maxt to avoid overlapping From 01821800f6a3e318fee5845e7d046b29a6a50ed1 Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 21 Feb 2023 11:33:01 -0300 Subject: [PATCH 49/49] add comments --- torchgeo/datasets/splits.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 7e512ce4db9..d85d56b63dd 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -39,7 +39,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in """ lengths = [floor(frac * total) for frac in fractions] remainder = int(total - sum(lengths)) - # add 1 to all the lengths in round-robin fashion until the remainder is 0 + # Add 1 to all the lengths in round-robin fashion until the remainder is 0 for i in range(remainder): idx_to_add_at = i % len(lengths) lengths[idx_to_add_at] += 1 @@ -137,19 +137,25 @@ def random_bbox_splitting( box = BoundingBox(*hit.bounds) fraction_left = 1.0 + # Randomly choose the split direction horizontal, flip = randint(0, 2, (2,), generator=generator) - for j, frac in enumerate(fractions): - if fraction_left == frac: + for j, fraction in enumerate(fractions): + if fraction_left == fraction: + # For the last fraction, no need to split again new_box = box elif flip: + # new_box corresponds to fraction, box is the remainder that we might + # split again in the next iteration. Each split is done according to + # fraction wrt what's left box, new_box = box.split( - (fraction_left - frac) / fraction_left, horizontal + (fraction_left - fraction) / fraction_left, horizontal ) else: - new_box, box = box.split(frac / fraction_left, horizontal) + # Same as above, but without flipping + new_box, box = box.split(fraction / fraction_left, horizontal) new_indexes[j].insert(i, tuple(new_box), hit.object) - fraction_left -= frac + fraction_left -= fraction horizontal = not horizontal new_datasets = [] @@ -200,6 +206,7 @@ def random_grid_cell_assignment( cells = [] + # Generate the grid's cells for each bbox in index for i, hit in enumerate( dataset.index.intersection(dataset.index.bounds, objects=True) ): @@ -226,6 +233,7 @@ def random_grid_cell_assignment( ] ) + # Randomly assign cells to each new index cells = [cells[i] for i in randperm(len(cells), generator=generator)] for i, length in enumerate(lengths): @@ -338,7 +346,7 @@ def time_series_split( if any(start < x < end or start < y < end for x, y in lengths[i + 1 :]): raise ValueError("Pairs of timestamps in lengths can't overlap.") - # remove one microsecond from each BoundingBox's maxt to avoid overlapping + # Remove one microsecond from each BoundingBox's maxt to avoid overlapping offset = 0 if i == len(lengths) - 1 else 1e-6 roi = BoundingBox(minx, maxx, miny, maxy, start, end - offset) j = 0