Skip to content

Commit

Permalink
Add pixel sampling mode (microsoft#294)
Browse files Browse the repository at this point in the history
* Add pixel sampling mode

* Fix maxy indexing error

Co-authored-by: Ashwin Nair <ash1995@gmail.com>

* Add sample_mode docstrings, default to PIXELS

* Replace sample_mode with units

* Update to use enum

* Fix mypy, tuple, and flake8 issues

* Fix isort and pydocstyle problems

* Update sampler docs to discuss unit sampling mode

* Various fixes

* Add units arg to GridGeoSampler

* Update benchmark script

* Add tests

* Document enum values

* mypy fixes

Co-authored-by: Ashwin Nair <ash1995@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
3 people authored Feb 24, 2022
1 parent ee0b7f7 commit 7153739
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 82 deletions.
17 changes: 7 additions & 10 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ def set_up_parser() -> argparse.ArgumentParser:
"--patch-size",
default=224,
type=int,
help="height/width of each patch",
metavar="SIZE",
help="height/width of each patch in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-s",
"--stride",
default=112,
type=int,
help="sampling stride for GridGeoSampler",
help="sampling stride for GridGeoSampler in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-w",
Expand Down Expand Up @@ -139,15 +140,11 @@ def main(args: argparse.Namespace) -> None:
length = args.num_batches * args.batch_size
num_batches = args.num_batches

# Convert from pixel coords to CRS coords
size = args.patch_size * cdl.res
stride = args.stride * cdl.res

samplers = [
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomGeoSampler(landsat, size=args.patch_size, length=length),
GridGeoSampler(landsat, size=args.patch_size, stride=args.stride),
RandomBatchGeoSampler(
landsat, size=size, batch_size=args.batch_size, length=length
landsat, size=args.patch_size, batch_size=args.batch_size, length=length
),
]

Expand Down
29 changes: 27 additions & 2 deletions docs/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :
from torchgeo.samplers import RandomGeoSampler
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset, size=1000, length=100)
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, sampler=sampler)
This data loader will return 256x256 px images, and has an epoch length of 10,000.

Random Geo Sampler
^^^^^^^^^^^^^^^^^^

Expand All @@ -43,10 +45,12 @@ When working with large tile-based datasets, randomly sampling patches from each
from torchgeo.samplers import RandomBatchGeoSampler
dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=128, length=10000)
dataloader = DataLoader(dataset, batch_sampler=sampler)
This data loader will return 256x256 px images, and has a batch size of 128 and an epoch length of 10,000.

Random Batch Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^^

Expand All @@ -66,3 +70,24 @@ Batch Geo Sampler
^^^^^^^^^^^^^^^^^

.. autoclass:: BatchGeoSampler

Units
-----

By default, the ``size`` parameter specifies the size of the image in *pixel* units. If you would instead like to specify the size in *CRS* units, you can change the ``units`` parameter like so:

.. code-block:: python
from torch.utils.data import DataLoader
from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomGeoSampler, Units
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset, size=256 * 30, length=10000, units=Units.CRS)
dataloader = DataLoader(dataset, sampler=sampler)
Assuming that each pixel in the dataset is 30 m wide in the units of the CRS, this data loader will return 256x256 px images and have an epoch length of 10,000.

.. autoclass:: Units
74 changes: 50 additions & 24 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import math
from itertools import product
from typing import Dict, Iterator, List

import pytest
Expand All @@ -10,7 +11,7 @@
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units


class CustomBatchGeoSampler(BatchGeoSampler):
Expand All @@ -26,7 +27,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand All @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:


class TestBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
return CustomGeoDataset()

@pytest.fixture(scope="function")
def sampler(self) -> CustomBatchGeoSampler:
return CustomBatchGeoSampler()
Expand All @@ -49,28 +54,45 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: CustomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue

def test_abstract(self) -> None:
ds = CustomGeoDataset()
def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(ds) # type: ignore[abstract]
BatchGeoSampler(dataset) # type: ignore[abstract]


class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
scope="function",
params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]),
)
def sampler(
self, dataset: CustomGeoDataset, request: SubRequest
) -> RandomBatchGeoSampler:
size, units = request.param
return RandomBatchGeoSampler(
dataset, size, batch_size=2, length=10, units=units
)

def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
Expand All @@ -88,18 +110,15 @@ def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size

def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi)
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = RandomBatchGeoSampler(dataset, 2, 2, 10, roi=roi)
for batch in sampler:
for query in batch:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds = CustomGeoDataset(res=1)
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
sampler = RandomBatchGeoSampler(ds, 2, 2, 10)
Expand All @@ -108,10 +127,17 @@ def test_small_area(self) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: RandomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue
Loading

0 comments on commit 7153739

Please sign in to comment.