Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridGeoSampler: change stride of last patch to sample entire ROI #630

Merged
merged 23 commits into from
Sep 3, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ad093fb
Adjust minx/miny with a smaller stride for the last sample per row/co…
remtav Mar 2, 2022
855fee1
style and mypy fixes
remtav Mar 14, 2022
eb33fe0
black test fix
remtav Mar 14, 2022
cb554c6
Adjust minx/miny with a smaller stride for the last sample per row/co…
remtav Mar 2, 2022
1a0236c
style and mypy fixes
remtav Mar 14, 2022
bfadf76
black test fix
remtav Mar 14, 2022
6942b5f
single.py: adapt gridgeosampler to sample beyond limit of ROI for a p…
remtav Jun 28, 2022
138855b
format for black and flake8
remtav Jun 28, 2022
37d6055
format for black and flake8
remtav Jun 28, 2022
b0ab3fa
once again, format for black and flake8
remtav Jun 28, 2022
45b3490
Merge branch 'microsoft:main' into samplers/gridgeosampler_bounds
remtav Aug 25, 2022
720cf5b
Merge remote-tracking branch 'origin/samplers/gridgeosampler_bounds' …
remtav Aug 25, 2022
af5a3d1
Revert "Adjust minx/miny with a smaller stride for the last sample pe…
remtav Aug 29, 2022
e588385
Merge branch 'microsoft:main' into samplers/gridgeosampler_bounds
remtav Aug 29, 2022
0e61b1d
adapt unit tests, remove warnings
remtav Aug 29, 2022
6c623d8
flake8: remove warnings import
remtav Aug 30, 2022
fd9b69a
Merge branch 'main' into samplers/gridgeosampler_bounds
remtav Aug 30, 2022
13daca1
Merge branch 'main' into samplers/gridgeosampler_bounds
adamjstewart Sep 3, 2022
9d68d1e
Address some comments
adamjstewart Sep 3, 2022
8660f28
Simplify computation of # rows/cols
adamjstewart Sep 3, 2022
9536970
Document this new feature
adamjstewart Sep 3, 2022
32d877e
Fix size of ceiling symbol
adamjstewart Sep 3, 2022
a741129
Simplify tests
adamjstewart Sep 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,13 @@ def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSamp

def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
assert sampler.roi.minx <= query.minx
assert sampler.roi.miny <= query.miny
assert sampler.roi.mint <= query.mint
if query.maxx > sampler.roi.maxx:
assert (query.maxx - sampler.roi.maxx) < sampler.size[1]
if query.maxy > sampler.roi.maxy:
assert (query.maxy - sampler.roi.maxy) < sampler.size[0]

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
Expand All @@ -182,24 +186,74 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
rows = math.ceil(
(100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0]
)
cols = math.ceil(
(100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1]
)
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_len_larger(self, sampler: GridGeoSampler) -> None:
entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0]
entire_cols = (100 - sampler.size[1] + sampler.stride[1]) // sampler.stride[1]
leftover_row = (100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[
0
] - entire_rows
leftover_col = (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[
1
] - entire_cols
assert (
len(sampler)
== (entire_rows + math.ceil(leftover_row))
* (entire_cols + math.ceil(leftover_col))
* 2
)

def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = GridGeoSampler(dataset, 2, 1, roi=roi)
for query in sampler:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
assert len(sampler) == 0

# TODO: skip patches with area=0 when two tiles are
# side-by-side with an overlapping edge face.
def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
for bbox in sampler:
assert bbox.area > 0

def test_equal_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
assert len(sampler) == 1
for bbox in sampler:
assert bbox == BoundingBox(
minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0
)

def test_larger_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
assert len(sampler) == 2
assert list(sampler)[0] == BoundingBox(
minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0
)
assert list(sampler)[1] == BoundingBox(
minx=1.0, maxx=6.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0
)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
Expand Down
32 changes: 26 additions & 6 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""TorchGeo samplers."""

import abc
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -200,17 +201,24 @@ def __init__(
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)

self.length: int = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
# last patch samples outside the bounds
rows = math.ceil(
(bounds.maxy - bounds.miny - self.size[0] + self.stride[0])
/ self.stride[0]
)
cols = math.ceil(
(bounds.maxx - bounds.minx - self.size[1] + self.stride[1])
/ self.stride[1]
)
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -223,8 +231,14 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = math.ceil(
(bounds.maxy - bounds.miny - self.size[0] + self.stride[0])
/ self.stride[0]
)
cols = math.ceil(
(bounds.maxx - bounds.minx - self.size[1] + self.stride[1])
/ self.stride[1]
)

mint = bounds.mint
maxt = bounds.maxt
Expand All @@ -233,11 +247,17 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand Down