Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache reads in RasterDataset #85

Merged
merged 1 commit into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
) -> None:
Expand All @@ -71,6 +72,7 @@ def __init__(
(defaults to the resolution of the first file found)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)

Expand All @@ -90,7 +92,7 @@ def __init__(
+ "You can use download=True to download it"
)

super().__init__(root, crs, res, transforms)
super().__init__(root, crs, res, transforms, cache)

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
) -> None:
Expand All @@ -84,6 +85,7 @@ def __init__(
(defaults to the resolution of the first file found)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)

Expand All @@ -103,7 +105,7 @@ def __init__(
+ "You can use download=True to download it"
)

super().__init__(root, crs, res, transforms)
super().__init__(root, crs, res, transforms, cache)

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Expand Down
50 changes: 40 additions & 10 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base classes for all :mod:`torchgeo` datasets."""

import abc
import functools
import glob
import math
import os
Expand All @@ -17,6 +18,7 @@
import rasterio.merge
import torch
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
Expand Down Expand Up @@ -173,6 +175,7 @@ def __init__(
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.

Expand All @@ -184,13 +187,15 @@ def __init__(
(defaults to the resolution of the first file found)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling

Raises:
FileNotFoundError: if no files are found in ``root``
"""
super().__init__(transforms)

self.root = root
self.cache = cache

# Populate the dataset index
i = 0
Expand Down Expand Up @@ -304,24 +309,49 @@ def _merge_files(self, filepaths: Sequence[str], query: BoundingBox) -> Tensor:
Returns:
image/mask at that index
"""
# Open files
src_fhs = [rasterio.open(fn) for fn in filepaths]

# Warp to a possibly new CRS
vrt_fhs = [WarpedVRT(src, crs=self.crs) for src in src_fhs]
if self.cache:
vrt_fhs = [self._cached_load_warp_file(fp) for fp in filepaths]
else:
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]

# Merge files
bounds = (query.minx, query.miny, query.maxx, query.maxy)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
dest = dest.astype(np.int32)

# Close file handles
[fh.close() for fh in src_fhs]
[fh.close() for fh in vrt_fhs]

tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
return tensor

@functools.lru_cache(maxsize=128)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoding a cache size of 128 doesn't seem wise?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cache size doesn't actually matter here. We're caching file handles, not images. Each image is ~0.5 GB, but each file handle is tiny. GDAL has its own cache size that controls how large the cache actually is. Users can change this if they need to.

def _cached_load_warp_file(self, filepath: str) -> DatasetReader:
    """Delegate to :meth:`_load_warp_file` and hand back its result.

    This method exists as a separate entry point so that a caching
    decorator can be attached to it without affecting direct callers of
    :meth:`_load_warp_file`.

    Args:
        filepath: file to load and warp

    Returns:
        file handle produced by :meth:`_load_warp_file`
    """
    vrt_fh = self._load_warp_file(filepath)
    return vrt_fh

def _load_warp_file(self, filepath: str) -> DatasetReader:
    """Open a raster file, reprojecting it when its CRS differs from ours.

    Args:
        filepath: file to load and warp

    Returns:
        a :class:`WarpedVRT` in ``self.crs`` when the file's CRS differs
        from ``self.crs``, otherwise the file handle returned by
        :func:`rasterio.open`
    """
    handle = rasterio.open(filepath)

    if handle.crs != self.crs:
        # CRS mismatch: wrap the file in a virtual warped dataset, then
        # close the original handle before returning the wrapper
        warped = WarpedVRT(handle, crs=self.crs)
        handle.close()
        return warped

    # CRS already matches, so no warping is needed
    return handle

def plot(self, data: Tensor) -> None:
"""Plot a data sample.

Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
res: Optional[float] = None,
bands: Sequence[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.

Expand All @@ -64,13 +65,14 @@ def __init__(
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling

Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.bands = bands if bands else self.all_bands

super().__init__(root, crs, res, transforms)
super().__init__(root, crs, res, transforms, cache)


class Landsat1(Landsat):
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
res: Optional[float] = None,
bands: Sequence[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.

Expand All @@ -86,10 +87,11 @@ def __init__(
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling

Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.bands = bands if bands else self.all_bands

super().__init__(root, crs, res, transforms)
super().__init__(root, crs, res, transforms, cache)