Skip to content

Commit

Permalink
RandomGeoSampler: several bug fixes (#477)
Browse files Browse the repository at this point in the history
* RandomGeoSampler: prevent area bias

* Use builtin PyTorch random

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
adamjstewart and calebrob6 authored Apr 5, 2022
1 parent 0a68c15 commit f262b00
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 15 deletions.
17 changes: 17 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,23 @@ def test_small_area(self) -> None:
for _ in sampler:
continue

def test_point_data(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
ds.index.insert(1, (1, 1, 1, 1, 1, 1))
sampler = RandomBatchGeoSampler(ds, 0, 2, 10)
for _ in sampler:
continue

def test_weighted_sampling(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
ds.index.insert(1, (0, 10, 0, 10, 0, 10))
sampler = RandomBatchGeoSampler(ds, 1, 2, 10)
for batch in sampler:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(
Expand Down
16 changes: 16 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ def test_small_area(self) -> None:
for _ in sampler:
continue

def test_point_data(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
ds.index.insert(1, (1, 1, 1, 1, 1, 1))
sampler = RandomGeoSampler(ds, 0, 10)
for _ in sampler:
continue

def test_weighted_sampling(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
ds.index.insert(1, (0, 10, 0, 10, 0, 10))
sampler = RandomGeoSampler(ds, 1, 10)
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(
Expand Down
18 changes: 13 additions & 5 deletions torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"""TorchGeo batch samplers."""

import abc
import random
from typing import Iterator, List, Optional, Tuple, Union

import torch
from rtree.index import Index, Property
from torch.utils.data import Sampler

Expand Down Expand Up @@ -104,13 +104,20 @@ def __init__(
self.batch_size = batch_size
self.length = length
self.hits = []
areas = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)
areas.append(bounds.area)

# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
if torch.sum(self.areas) == 0:
self.areas += 1

def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
Expand All @@ -119,8 +126,9 @@ def __iter__(self) -> Iterator[List[BoundingBox]]:
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for _ in range(len(self)):
# Choose a random tile
hit = random.choice(self.hits)
# Choose a random tile, weighted by area
idx = torch.multinomial(self.areas, 1)
hit = self.hits[idx]
bounds = BoundingBox(*hit.bounds)

# Choose random indices within that tile
Expand Down
17 changes: 12 additions & 5 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""TorchGeo samplers."""

import abc
import random
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -105,13 +104,20 @@ def __init__(

self.length = length
self.hits = []
areas = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)
areas.append(bounds.area)

# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
if torch.sum(self.areas) == 0:
self.areas += 1

def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
Expand All @@ -120,8 +126,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for _ in range(len(self)):
# Choose a random tile
hit = random.choice(self.hits)
# Choose a random tile, weighted by area
idx = torch.multinomial(self.areas, 1)
hit = self.hits[idx]
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
Expand Down
18 changes: 13 additions & 5 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

"""Common sampler utilities."""

import random
from typing import Tuple, Union

import torch

from ..datasets import BoundingBox


Expand Down Expand Up @@ -46,11 +47,18 @@ def get_random_bounding_box(
t_size = _to_tuple(size)

width = (bounds.maxx - bounds.minx - t_size[1]) // res
minx = random.randrange(int(width)) * res + bounds.minx
maxx = minx + t_size[1]

height = (bounds.maxy - bounds.miny - t_size[0]) // res
miny = random.randrange(int(height)) * res + bounds.miny

minx = bounds.minx
miny = bounds.miny

# random.randrange crashes for inputs <= 0
if width > 0:
minx += torch.rand(1).item() * width * res
if height > 0:
miny += torch.rand(1).item() * height * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]

mint = bounds.mint
Expand Down

0 comments on commit f262b00

Please sign in to comment.