Skip to content

Commit

Permalink
Merge remote-tracking branch 'torchgeo/main' into geosampler_prechipping
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 24, 2024
2 parents c51a63b + e932902 commit 25ce0e1
Show file tree
Hide file tree
Showing 13 changed files with 99 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/style.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
- name: Clone repo
uses: actions/checkout@v4.1.7
- name: Set up nodejs
uses: actions/setup-node@v4.0.3
uses: actions/setup-node@v4.0.4
with:
node-version: '20'
cache: 'npm'
Expand Down
4 changes: 4 additions & 0 deletions docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@
"\n",
"If your data only contains model inputs (such as images), use `is_image = True`. If your data only contains ground truth model outputs (such as segmentation masks), use `is_image = False` instead.\n",
"\n",
"Consequently, the sample returned by the dataset/data loader will use the \"image\" key if *is_image* is True, otherwise it will use the \"mask\" key.\n",
"\n",
"For datasets with both model inputs and outputs, the recommended approach is to use 2 `RasterDataset` instances and combine them using an `IntersectionDataset`. See L7 Irish, L8 Biome, and I/O Bench for examples of this in `torchgeo/datasets`.\n",
"\n",
"### `dtype`\n",
"\n",
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n",
Expand Down
11 changes: 6 additions & 5 deletions docs/user/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,14 @@ Datasets
A major component of TorchGeo is the large collection of :mod:`torchgeo.datasets` that have been implemented. Adding new datasets to this list is a great way to contribute to the library. A brief checklist to follow when implementing a new dataset:

* Implement the dataset extending either :class:`~torchgeo.datasets.GeoDataset` or :class:`~torchgeo.datasets.NonGeoDataset`
* Add the dataset definition to ``torchgeo/datasets/__init__.py``
* Add a ``data.py`` script to ``tests/data/<new dataset>/`` that generates test data with the same directory structure/file naming conventions as the new dataset
* Add appropriate tests with 100% test coverage to ``tests/datasets/``
* Add the dataset definition to ``torchgeo/datasets/foo.py``, where *foo* is the name of the dataset
* Add an import alias to this dataset in ``torchgeo/datasets/__init__.py``
* Add a ``tests/data/foo/data.py`` script that generates fake test data with the same directory structure/file naming conventions as the real dataset
* Add appropriate tests with 100% test coverage to ``tests/datasets/test_foo.py``
* Add the dataset to ``docs/api/datasets.rst``
* Add the dataset metadata to either ``docs/api/geo_datasets.csv`` or ``docs/api/non_geo_datasets.csv``
* Add the dataset metadata to either ``docs/api/datasets/geo_datasets.csv`` or ``docs/api/datasets/non_geo_datasets.csv``

A good way to get started is by looking at some of the existing implementations that are most closely related to the dataset that you are implementing (e.g. if you are implementing a semantic segmentation dataset, looking at the LandCover.ai dataset implementation would be a good starting point).
A good way to get started is by looking at some of the existing implementations that are most closely related to the dataset that you are implementing (e.g., if you are implementing a semantic segmentation dataset, looking at the LandCover.ai dataset implementation would be a good starting point).

I/O Benchmarking
----------------
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ filterwarnings = [
"ignore:torch.is_autocast_cpu_enabled\\(\\) is deprecated.:DeprecationWarning:kornia.utils.helpers",
# https://github.com/pytorch/pytorch/pull/129239
"ignore:You are using `torch.load` with `weights_only=False`:FutureWarning",
# https://github.com/pytorch/pytorch/issues/136264
"ignore:__array__ implementation doesn't accept a copy keyword:DeprecationWarning",
"ignore:__array_wrap__ must accept context and return_scalar arguments:DeprecationWarning",

# Expected warnings
# Lightning warns us about using num_workers=0, but it's faster on macOS
Expand Down
2 changes: 1 addition & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ lightly==1.5.12
lightning[pytorch-extra]==2.4.0
matplotlib==3.9.2
numpy==2.1.1
pandas==2.2.2
pandas==2.2.3
pillow==10.4.0
pyarrow==17.0.0
pyproj==3.6.1
Expand Down
2 changes: 1 addition & 1 deletion requirements/style.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# style
mypy==1.11.2
ruff==0.6.5
ruff==0.6.7
10 changes: 10 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,15 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
24 changes: 24 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import geopandas as gpd
import pytest
import torch
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
Expand Down Expand Up @@ -222,6 +223,15 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -371,6 +381,20 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
sampler1 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
sampler2 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 != sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 5 additions & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ class RasterDataset(GeoDataset):
#: The sample returned by the dataset/data loader will use the "image" key if
#: *is_image* is True, otherwise it will use the "mask" key.
#:
#: For datasets with both model inputs and outputs, a custom
#: :func:`~RasterDataset.__getitem__` method must be implemented.
#: For datasets with both model inputs and outputs, the recommended approach is
#: to use 2 `RasterDataset` instances and combine them using an `IntersectionDataset`.
is_image = True

#: True if data is stored in a separate file for each band, else False.
Expand Down
11 changes: 10 additions & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from rtree.index import Index, Property
from torch import Generator
from torch.utils.data import Sampler

from ..datasets import BoundingBox, GeoDataset
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -86,6 +88,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -97,9 +102,11 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -144,7 +151,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
batch.append(bounding_box)

yield batch
Expand Down
24 changes: 21 additions & 3 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import geopandas as gpd
import numpy as np
import torch
from geopandas import GeoDataFrame
from rtree.index import Index, Property
from shapely.geometry import box
from torch import Generator
from torch.utils.data import Sampler
from tqdm import tqdm

Expand Down Expand Up @@ -210,6 +212,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -226,6 +229,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -236,13 +242,15 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -304,7 +312,7 @@ def get_chips(self) -> GeoDataFrame:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bbox = get_random_bounding_box(bounds, self.size, self.res)
bbox = get_random_bounding_box(bounds, self.size, self.res, self.generator)
minx, maxx, miny, maxy, mint, maxt = tuple(bbox)
chip = {
'geometry': box(minx, miny, maxx, maxy),
Expand Down Expand Up @@ -447,20 +455,30 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
.. versionadded:: 0.3
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: pseudo-random number generator (PRNG) used in combination with shuffle.
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -477,7 +495,7 @@ def get_chips(self) -> GeoDataFrame:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
generator = partial(torch.randperm, generator=self.generator)

print('generating samples... ')
chips = []
Expand Down
14 changes: 11 additions & 3 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import overload

import torch
from torch import Generator

from ..datasets import BoundingBox

Expand Down Expand Up @@ -35,7 +36,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox, size: tuple[float, float] | float, res: float
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: Generator | None = None,
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
Expand All @@ -46,10 +50,14 @@ def get_random_bounding_box(
* a ``tuple`` of two floats - in which case, the first *float* is used for the
height dimension, and the second *float* for the width dimension
.. versionadded:: 0.7
The *generator* parameter.
Args:
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: pseudo-random number generator (PRNG).
Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -64,8 +72,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down

0 comments on commit 25ce0e1

Please sign in to comment.