Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random generator #6

Merged
merged 8 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Double quote -> single quote
# Prettier: double quote -> single quote
6a5aaf4b93507072d40dcd78114893362c4eaf6e
# Ruff: double quote -> single quote
b09122f3e4a9cb422f6747bf33eca02993f67549
# Prettier
bd9c75798eede1a4b7d7ecd6203179d3cb5e54dd
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac
pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac .
pip cache purge
- name: List pip dependencies
run: pip list
Expand Down
8 changes: 4 additions & 4 deletions tests/conf/landcoverai100.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 3
num_classes: 5
num_filters: 1
Expand All @@ -13,4 +13,4 @@ data:
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/landcoverai"
root: 'tests/data/landcoverai'
15 changes: 15 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,20 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
    """Identically seeded generators must yield identical batches."""
    ds = CustomGeoDataset()
    ds.index.insert(0, (0, 10, 0, 10, 0, 10))

    # Use private generators instead of torch.manual_seed(): the latter
    # reseeds (and returns) the process-wide default generator, which
    # would leak RNG state into whichever tests run afterwards.
    generator1 = torch.Generator().manual_seed(0)
    generator2 = torch.Generator().manual_seed(0)
    sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=generator1)
    sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=generator2)

    # Only the first batch of each sampler is needed for the comparison.
    batch1 = next(iter(sampler1))
    batch2 = next(iter(sampler2))
    assert batch1 == batch2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
35 changes: 35 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
    """Identically seeded generators must yield identical samples."""
    ds = CustomGeoDataset()
    ds.index.insert(0, (0, 10, 0, 10, 0, 10))

    # Each sampler gets its own generator, seeded with the same value.
    # Sharing one generator object would make the second sampler draw
    # from an already-advanced RNG state, so the comparison would not
    # exercise seed reproducibility at all. Private generators also
    # avoid reseeding torch's global default generator.
    generator1 = torch.Generator().manual_seed(0)
    generator2 = torch.Generator().manual_seed(0)
    sampler1 = RandomGeoSampler(ds, 1, 1, generator=generator1)
    sampler2 = RandomGeoSampler(ds, 1, 1, generator=generator2)

    # Only the first sample of each sampler is needed for the comparison.
    sample1 = next(iter(sampler1))
    sample2 = next(iter(sampler2))
    assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -288,6 +304,25 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
    """Identically seeded generators must produce the same shuffle order."""
    ds = CustomGeoDataset()
    # Two distinct entries so that a shuffle has more than one outcome.
    ds.index.insert(0, (0, 10, 0, 10, 0, 10))
    ds.index.insert(1, (0, 11, 0, 11, 0, 11))

    # Private generators keep the test from reseeding torch's global
    # default generator (which torch.manual_seed() would do).
    generator1 = torch.Generator().manual_seed(2)
    generator2 = torch.Generator().manual_seed(2)
    sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator1)
    sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator2)

    # The first element of each shuffled sequence must match.
    sample1 = next(iter(sampler1))
    sample2 = next(iter(sampler2))
    assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 5 additions & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
batch_sampler=batch_sampler,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down Expand Up @@ -429,6 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
shuffle=split == 'train',
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand Down Expand Up @@ -97,9 +98,11 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: random number generator
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
batch.append(bounding_box)

yield batch
Expand Down
20 changes: 17 additions & 3 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import torch
from rtree.index import Index, Property
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand All @@ -98,13 +100,16 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: The random generator used for sampling.

"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)

yield bounding_box

Expand Down Expand Up @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand All @@ -281,9 +292,12 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: The random number generator used in combination with shuffle.

"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
generator = partial(torch.randperm, generator=self.generator)

for idx in generator(len(self)):
yield BoundingBox(*self.hits[idx].bounds)
Expand Down
10 changes: 7 additions & 3 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox, size: tuple[float, float] | float, res: float
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: torch.Generator | None = None,
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.

Expand All @@ -50,6 +53,7 @@ def get_random_bounding_box(
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: random number generator

Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -64,8 +68,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down