Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Feb 23, 2022
1 parent cb72d0a commit d0f2502
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 66 deletions.
74 changes: 50 additions & 24 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import math
from itertools import product
from typing import Dict, Iterator, List

import pytest
Expand All @@ -10,7 +11,7 @@
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units


class CustomBatchGeoSampler(BatchGeoSampler):
Expand All @@ -26,7 +27,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand All @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:


class TestBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
return CustomGeoDataset()

@pytest.fixture(scope="function")
def sampler(self) -> CustomBatchGeoSampler:
return CustomBatchGeoSampler()
Expand All @@ -49,28 +54,45 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: CustomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue

def test_abstract(self) -> None:
ds = CustomGeoDataset()
def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(ds) # type: ignore[abstract]
BatchGeoSampler(dataset) # type: ignore[abstract]


class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self):
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
scope="function",
params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]),
)
def sampler(
self, dataset: CustomGeoDataset, request: SubRequest
) -> RandomBatchGeoSampler:
size, units = request.param
return RandomBatchGeoSampler(
dataset, size, batch_size=2, length=10, units=units
)

def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
Expand All @@ -88,18 +110,15 @@ def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size

def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi)
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = RandomBatchGeoSampler(dataset, 2, 2, 10, roi=roi)
for batch in sampler:
for query in batch:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds = CustomGeoDataset(res=1)
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
sampler = RandomBatchGeoSampler(ds, 2, 2, 10)
Expand All @@ -108,10 +127,17 @@ def test_small_area(self) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: RandomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue
101 changes: 59 additions & 42 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import math
from itertools import product
from typing import Dict, Iterator

import pytest
Expand All @@ -10,7 +11,7 @@
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units


class CustomGeoSampler(GeoSampler):
Expand All @@ -26,7 +27,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand All @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:


class TestGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
return CustomGeoDataset()

@pytest.fixture(scope="function")
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()
Expand All @@ -46,30 +51,39 @@ def test_iter(self, sampler: CustomGeoSampler) -> None:
def test_len(self, sampler: CustomGeoSampler) -> None:
assert len(sampler) == 2

def test_abstract(self) -> None:
ds = CustomGeoDataset()
def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(ds) # type: ignore[abstract]
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self, dataset: CustomGeoDataset, sampler: CustomGeoSampler, num_workers: int
) -> None:
dl = DataLoader(
ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
)
for _ in dl:
continue


class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomGeoSampler(ds, size, length=10)
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
scope="function",
params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]),
)
def sampler(
self, dataset: CustomGeoDataset, request: SubRequest
) -> RandomGeoSampler:
size, units = request.param
return RandomGeoSampler(dataset, size, length=10, units=units)

def test_iter(self, sampler: RandomGeoSampler) -> None:
for query in sampler:
Expand All @@ -86,17 +100,14 @@ def test_iter(self, sampler: RandomGeoSampler) -> None:
def test_len(self, sampler: RandomGeoSampler) -> None:
assert len(sampler) == sampler.length

def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomGeoSampler(ds, 2, 10, roi=roi)
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = RandomGeoSampler(dataset, 2, 10, roi=roi)
for query in sampler:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds = CustomGeoDataset(res=1)
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
sampler = RandomGeoSampler(ds, 2, 10)
Expand All @@ -105,26 +116,34 @@ def test_small_area(self) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self, dataset: CustomGeoDataset, sampler: RandomGeoSampler, num_workers: int
) -> None:
dl = DataLoader(
ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
)
for _ in dl:
continue


class TestGridGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
scope="function",
params=[(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))],
params=product(
[(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))],
[Units.PIXELS, Units.CRS],
),
)
def sampler(self, request: SubRequest) -> GridGeoSampler:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 20, 0, 10, 40, 50))
ds.index.insert(1, (0, 20, 0, 10, 40, 50))
size, stride = request.param
return GridGeoSampler(ds, size, stride)
def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSampler:
(size, stride), units = request.param
return GridGeoSampler(dataset, size, stride, units=units)

def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
Expand All @@ -139,17 +158,14 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1
rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
assert len(sampler) == length

def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = GridGeoSampler(ds, 2, 1, roi=roi)
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = GridGeoSampler(dataset, 2, 1, roi=roi)
for query in sampler:
assert query in roi

Expand All @@ -163,10 +179,11 @@ def test_small_area(self) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self, dataset: CustomGeoDataset, sampler: GridGeoSampler, num_workers: int
) -> None:
dl = DataLoader(
ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
)
for _ in dl:
continue

0 comments on commit d0f2502

Please sign in to comment.