diff --git a/benchmark.py b/benchmark.py index d0d499e5fd5..2b3465bc3a9 100755 --- a/benchmark.py +++ b/benchmark.py @@ -77,15 +77,16 @@ def set_up_parser() -> argparse.ArgumentParser: "--patch-size", default=224, type=int, - help="height/width of each patch", - metavar="SIZE", + help="height/width of each patch in pixels", + metavar="PIXELS", ) parser.add_argument( "-s", "--stride", default=112, type=int, - help="sampling stride for GridGeoSampler", + help="sampling stride for GridGeoSampler in pixels", + metavar="PIXELS", ) parser.add_argument( "-w", @@ -139,15 +140,11 @@ def main(args: argparse.Namespace) -> None: length = args.num_batches * args.batch_size num_batches = args.num_batches - # Convert from pixel coords to CRS coords - size = args.patch_size * cdl.res - stride = args.stride * cdl.res - samplers = [ - RandomGeoSampler(landsat, size=size, length=length), - GridGeoSampler(landsat, size=size, stride=stride), + RandomGeoSampler(landsat, size=args.patch_size, length=length), + GridGeoSampler(landsat, size=args.patch_size, stride=args.stride), RandomBatchGeoSampler( - landsat, size=size, batch_size=args.batch_size, length=length + landsat, size=args.patch_size, batch_size=args.batch_size, length=length ), ] diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index fa1ca65227e..be7fbefb910 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -16,10 +16,12 @@ Samplers are used to index a dataset, retrieving a single query at a time. For : from torchgeo.samplers import RandomGeoSampler dataset = Landsat(...) - sampler = RandomGeoSampler(dataset, size=1000, length=100) + sampler = RandomGeoSampler(dataset, size=256, length=10000) dataloader = DataLoader(dataset, sampler=sampler) +This data loader will return 256x256 px images, and has an epoch length of 10,000. + Random Geo Sampler ^^^^^^^^^^^^^^^^^^ @@ -43,10 +45,12 @@ When working with large tile-based datasets, randomly sampling patches from each from torchgeo.samplers import RandomBatchGeoSampler dataset = Landsat(...) - sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100) + sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=128, length=10000) dataloader = DataLoader(dataset, batch_sampler=sampler) +This data loader will return 256x256 px images, and has a batch size of 128 and an epoch length of 10,000. + Random Batch Geo Sampler ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -66,3 +70,24 @@ Batch Geo Sampler ^^^^^^^^^^^^^^^^^ .. autoclass:: BatchGeoSampler + +Units +----- + +By default, the ``size`` parameter specifies the size of the image in *pixel* units. If you would instead like to specify the size in *CRS* units, you can change the ``units`` parameter like so: + +.. code-block:: python + + from torch.utils.data import DataLoader + + from torchgeo.datasets import Landsat + from torchgeo.samplers import RandomGeoSampler, Units + + dataset = Landsat(...) + sampler = RandomGeoSampler(dataset, size=256 * 30, length=10000, units=Units.CRS) + dataloader = DataLoader(dataset, sampler=sampler) + + +Assuming that each pixel in the CRS is 30 m, this data loader will return 256x256 px images, and has an epoch length of 10,000. + +.. autoclass:: Units diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 5c114bb86d7..952f36dd15d 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import math +from itertools import product from typing import Dict, Iterator, List import pytest @@ -10,7 +11,7 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler +from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units class CustomBatchGeoSampler(BatchGeoSampler): @@ -26,7 +27,7 @@ def __len__(self) -> int: class CustomGeoDataset(GeoDataset): - def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None: + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: class TestBatchGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + return CustomGeoDataset() + @pytest.fixture(scope="function") def sampler(self) -> CustomBatchGeoSampler: return CustomBatchGeoSampler() @@ -49,28 +54,45 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, + dataset: CustomGeoDataset, + sampler: CustomBatchGeoSampler, + num_workers: int, + ) -> None: dl = DataLoader( - ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, + batch_sampler=sampler, + num_workers=num_workers, + collate_fn=stack_samples, ) for _ in dl: continue - def test_abstract(self) -> None: - ds = CustomGeoDataset() + def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): - BatchGeoSampler(ds) # type: ignore[abstract] + BatchGeoSampler(dataset) # type: ignore[abstract] class TestRandomBatchGeoSampler: - @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) - def sampler(self, request: SubRequest) -> RandomBatchGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 20, 30, 40, 50)) - ds.index.insert(1, (0, 10, 20, 30, 40, 50)) - size = request.param - return RandomBatchGeoSampler(ds, size, batch_size=2, length=10) + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + + @pytest.fixture( + scope="function", + params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), + ) + def sampler( + self, dataset: CustomGeoDataset, request: SubRequest + ) -> RandomBatchGeoSampler: + size, units = request.param + return RandomBatchGeoSampler( + dataset, size, batch_size=2, length=10, units=units + ) def test_iter(self, sampler: RandomBatchGeoSampler) -> None: for batch in sampler: @@ -88,18 +110,15 @@ def test_iter(self, sampler: RandomBatchGeoSampler) -> None: def test_len(self, sampler: RandomBatchGeoSampler) -> None: assert len(sampler) == sampler.length // sampler.batch_size - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = RandomBatchGeoSampler(dataset, 2, 2, 10, roi=roi) for batch in sampler: for query in batch: assert query in roi def test_small_area(self) -> None: - ds = CustomGeoDataset() + ds = CustomGeoDataset(res=1) ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (20, 21, 20, 21, 20, 21)) sampler = RandomBatchGeoSampler(ds, 2, 2, 10) @@ -108,10 +127,17 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, + dataset: CustomGeoDataset, + sampler: RandomBatchGeoSampler, + num_workers: int, + ) -> None: dl = DataLoader( - ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, + batch_sampler=sampler, + num_workers=num_workers, + collate_fn=stack_samples, ) for _ in dl: continue diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index aa13b8b56aa..9dd49f24c79 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import math +from itertools import product from typing import Dict, Iterator import pytest @@ -10,7 +11,7 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler +from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units class CustomGeoSampler(GeoSampler): @@ -26,7 +27,7 @@ def __len__(self) -> int: class CustomGeoDataset(GeoDataset): - def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None: + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: class TestGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + return CustomGeoDataset() + @pytest.fixture(scope="function") def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() @@ -46,30 +51,39 @@ def test_iter(self, sampler: CustomGeoSampler) -> None: def test_len(self, sampler: CustomGeoSampler) -> None: assert len(sampler) == 2 - def test_abstract(self) -> None: - ds = CustomGeoDataset() + def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): - GeoSampler(ds) # type: ignore[abstract] + GeoSampler(dataset) # type: ignore[abstract] @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: CustomGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue class TestRandomGeoSampler: - @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) - def sampler(self, request: SubRequest) -> RandomGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 20, 30, 40, 50)) - ds.index.insert(1, (0, 10, 20, 30, 40, 50)) - size = request.param - return RandomGeoSampler(ds, size, length=10) + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + + @pytest.fixture( + scope="function", + params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), + ) + def sampler( + self, dataset: CustomGeoDataset, request: SubRequest + ) -> RandomGeoSampler: + size, units = request.param + return RandomGeoSampler(dataset, size, length=10, units=units) def test_iter(self, sampler: RandomGeoSampler) -> None: for query in sampler: @@ -86,17 +100,14 @@ def test_iter(self, sampler: RandomGeoSampler) -> None: def test_len(self, sampler: RandomGeoSampler) -> None: assert len(sampler) == sampler.length - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = RandomGeoSampler(ds, 2, 10, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = RandomGeoSampler(dataset, 2, 10, roi=roi) for query in sampler: assert query in roi def test_small_area(self) -> None: - ds = CustomGeoDataset() + ds = CustomGeoDataset(res=1) ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (20, 21, 20, 21, 20, 21)) sampler = RandomGeoSampler(ds, 2, 10) @@ -105,26 +116,34 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: RandomGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue class TestGridGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + @pytest.fixture( scope="function", - params=[(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))], + params=product( + [(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))], + [Units.PIXELS, Units.CRS], + ), ) - def sampler(self, request: SubRequest) -> GridGeoSampler: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 20, 0, 10, 40, 50)) - ds.index.insert(1, (0, 20, 0, 10, 40, 50)) - size, stride = request.param - return GridGeoSampler(ds, size, stride) + def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSampler: + (size, stride), units = request.param + return GridGeoSampler(dataset, size, stride, units=units) def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: @@ -139,17 +158,14 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1 + rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1 + cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1 length = rows * cols * 2 assert len(sampler) == length - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = GridGeoSampler(ds, 2, 1, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = GridGeoSampler(dataset, 2, 1, roi=roi) for query in sampler: assert query in roi @@ -163,10 +179,11 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: GridGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index da02f09d802..a6f63de1917 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -4,6 +4,7 @@ """TorchGeo samplers.""" from .batch import BatchGeoSampler, RandomBatchGeoSampler +from .constants import Units from .single import GeoSampler, GridGeoSampler, RandomGeoSampler __all__ = ( @@ -15,6 +16,8 @@ # Base classes "GeoSampler", "BatchGeoSampler", + # Constants + "Units", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 488649b9785..e269a748db6 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -11,6 +11,7 @@ from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset +from .constants import Units from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -71,6 +72,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -83,14 +85,22 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` batch_size: number of samples per batch length: number of samples per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + units: defines if ``size`` is in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.batch_size = batch_size self.length = length self.hits = [] diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py new file mode 100644 index 00000000000..18e1f598ddb --- /dev/null +++ b/torchgeo/samplers/constants.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Common sampler constants.""" + +from enum import Enum, auto + + +class Units(Enum): + """Enumeration defining units of ``size`` parameter. + + Used by :class:`~torchgeo.samplers.GeoSampler` and + :class:`~torchgeo.samplers.BatchGeoSampler`. + """ + + #: Units in number of pixels + PIXELS = auto() + + #: Units of :term:`coordinate reference system (CRS)` + CRS = auto() diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1f8eb459f4e..781cb5c38b1 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -11,6 +11,7 @@ from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset +from .constants import Units from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -73,6 +74,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -85,13 +87,21 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` length: number of random samples to draw per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + units: defines if ``size`` is in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.length = length self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -148,6 +158,7 @@ def __init__( size: Union[Tuple[float, float], float], stride: Union[Tuple[float, float], float], roi: Optional[BoundingBox] = None, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -160,14 +171,23 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` stride: distance to skip between each patch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + units: defines if ``size`` and ``stride`` are in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) + self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 265859eeb06..f8382626ee8 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -43,7 +43,7 @@ def get_random_bounding_box( Returns: randomly sampled bounding box from the extent of the input """ - t_size: Tuple[float, float] = _to_tuple(size) + t_size = _to_tuple(size) width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx