diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 8b5d8b210af..5c114bb86d7 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -98,6 +98,14 @@ def test_roi(self) -> None: for query in batch: assert query in roi + def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + sampler = RandomBatchGeoSampler(ds, 2, 2, 10) + for _ in sampler: + continue + @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None: diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 936856df110..aa13b8b56aa 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -95,6 +95,14 @@ def test_roi(self) -> None: for query in sampler: assert query in roi + def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + sampler = RandomGeoSampler(ds, 2, 10) + for _ in sampler: + continue + @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None: @@ -145,6 +153,14 @@ def test_roi(self) -> None: for query in sampler: assert query in roi + def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + sampler = GridGeoSampler(ds, 2, 10) + for _ in sampler: + continue + @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None: diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 3dc5dcc4b8e..488649b9785 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -93,7 +93,14 @@ def __init__( self.size = _to_tuple(size) self.batch_size = batch_size self.length = length - self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx > self.size[1] + and bounds.maxy - bounds.miny > self.size[0] + ): + self.hits.append(hit) def __iter__(self) -> Iterator[List[BoundingBox]]: """Return the indices of a dataset. diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index d507f698e3b..1f8eb459f4e 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -93,7 +93,14 @@ def __init__( super().__init__(dataset, roi) self.size = _to_tuple(size) self.length = length - self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx > self.size[1] + and bounds.maxy - bounds.miny > self.size[0] + ): + self.hits.append(hit) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -161,7 +168,14 @@ def __init__( super().__init__(dataset, roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) - self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx > self.size[1] + and bounds.maxy - bounds.miny > self.size[0] + ): + self.hits.append(hit) self.length: int = 0 for hit in self.hits: