diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index d969f962b02..00000000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-  "python.testing.pytestArgs": ["tests"],
-  "python.testing.unittestEnabled": false,
-  "python.testing.pytestEnabled": true
-}
diff --git a/tests/data/samplers/filtering_4x4.feather b/tests/data/samplers/filtering_4x4.feather
deleted file mode 100644
index 305d37e4fa6..00000000000
Binary files a/tests/data/samplers/filtering_4x4.feather and /dev/null differ
diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.cpg b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg
deleted file mode 100644
index 57decb48120..00000000000
--- a/tests/data/samplers/filtering_4x4/filtering_4x4.cpg
+++ /dev/null
@@ -1 +0,0 @@
-ISO-8859-1
diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.dbf b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf
deleted file mode 100644
index 499d67bcec4..00000000000
Binary files a/tests/data/samplers/filtering_4x4/filtering_4x4.dbf and /dev/null differ
diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.prj b/tests/data/samplers/filtering_4x4/filtering_4x4.prj
deleted file mode 100644
index 42fd4b91b78..00000000000
--- a/tests/data/samplers/filtering_4x4/filtering_4x4.prj
+++ /dev/null
@@ -1 +0,0 @@
-PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]]
diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shp b/tests/data/samplers/filtering_4x4/filtering_4x4.shp
deleted file mode 100644
index 65606c26dd6..00000000000
Binary files a/tests/data/samplers/filtering_4x4/filtering_4x4.shp and /dev/null differ
diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shx b/tests/data/samplers/filtering_4x4/filtering_4x4.shx
deleted file mode 100644
index b2028e759e5..00000000000
Binary files a/tests/data/samplers/filtering_4x4/filtering_4x4.shx and /dev/null differ
diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py
index 7cf54f69000..1416368098a 100644
--- a/tests/samplers/test_single.py
+++ b/tests/samplers/test_single.py
@@ -2,15 +2,12 @@
 # Licensed under the MIT License.
 
 import math
-import os
+from collections.abc import Iterator
 from itertools import product
 
-import geopandas as gpd
 import pytest
 from _pytest.fixtures import SubRequest
-from geopandas import GeoDataFrame
 from rasterio.crs import CRS
-from shapely.geometry import box
 from torch.utils.data import DataLoader
 
 from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
@@ -26,23 +23,11 @@ class CustomGeoSampler(GeoSampler):
     def __init__(self) -> None:
-        self.chips = self.get_chips()
+        pass
 
-    def get_chips(self) -> GeoDataFrame:
-        chips = []
+    def __iter__(self) -> Iterator[BoundingBox]:
         for i in range(len(self)):
-            chips.append(
-                {
-                    'geometry': box(i, i, i, i),
-                    'minx': i,
-                    'miny': i,
-                    'maxx': i,
-                    'maxy': i,
-                    'mint': i,
-                    'maxt': i,
-                }
-            )
-        return GeoDataFrame(chips, crs='3005')
+            yield BoundingBox(i, i, i, i, i, i)
 
     def __len__(self) -> int:
         return 2
@@ -58,17 +43,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
         return {'index': query}
 
 
-class CustomGeoDatasetSITS(GeoDataset):
-    def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
-        super().__init__()
-        self._crs = crs
-        self.res = res
-        self.return_as_ts = True
-
-    def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
-        return {'index': query}
-
-
 class TestGeoSampler:
     @pytest.fixture(scope='class')
     def dataset(self) -> CustomGeoDataset:
@@ -80,14 +54,6 @@ def dataset(self) -> CustomGeoDataset:
     def sampler(self) -> CustomGeoSampler:
         return CustomGeoSampler()
 
-    @pytest.fixture(scope='class')
-    def datadir(self) -> str:
-        return os.path.join('tests', 'data', 'samplers')
-
-    def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None:
-        with pytest.raises(TypeError):
-            GeoSampler(dataset)
-
     def test_iter(self, sampler: CustomGeoSampler) -> None:
         assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)
@@ -98,62 +64,6 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None:
         with pytest.raises(TypeError, match="Can't instantiate abstract class"):
             GeoSampler(dataset)  # type: ignore[abstract]
 
-    @pytest.mark.parametrize(
-        'filtering_file', ['filtering_4x4', 'filtering_4x4.feather']
-    )
-    def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
-        ds = CustomGeoDataset()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        sampler = GridGeoSampler(
-            ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
-        )
-        iterator = iter(sampler)
-
-        assert len(sampler) == 4
-        filtering_path = os.path.join(datadir, filtering_file)
-        sampler.filter_chips(filtering_path, 'intersects', 'drop')
-        assert len(sampler) == 3
-        assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)
-
-    def test_filtering_from_gdf(self, datadir: str) -> None:
-        ds = CustomGeoDataset()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        sampler = GridGeoSampler(
-            ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
-        )
-        iterator = iter(sampler)
-
-        # Dropping first chip
-        assert len(sampler) == 4
-        filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4'))
-        sampler.filter_chips(filtering_gdf, 'intersects', 'drop')
-        assert len(sampler) == 3
-        assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)
-
-        # Keeping only first chip
-        sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
-        iterator = iter(sampler)
-        sampler.filter_chips(filtering_gdf, 'intersects', 'keep')
-        assert len(sampler) == 1
-        assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
-
-    def test_set_worker_split(self) -> None:
-        ds = CustomGeoDataset()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        sampler = GridGeoSampler(
-            ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
-        )
-        assert len(sampler) == 4
-        sampler.set_worker_split(total_workers=4, worker_num=1)
-        assert len(sampler) == 1
-
-    def test_save_chips(self, tmpdir_factory) -> None:
-        ds = CustomGeoDataset()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
-        sampler.save(str(tmpdir_factory.mktemp('out').join('chips')))
-        sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather')))
-
     @pytest.mark.slow
     @pytest.mark.parametrize('num_workers', [0, 1, 2])
     def test_dataloader(
@@ -229,15 +139,6 @@ def test_weighted_sampling(self) -> None:
         for bbox in sampler:
             assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)
 
-    def test_return_as_ts(self) -> None:
-        ds = CustomGeoDatasetSITS()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        ds.index.insert(1, (0, 10, 0, 10, 15, 20))
-        sampler = RandomGeoSampler(ds, 1, 5)
-        for query in sampler:
-            assert query.mint == ds.bounds.mint == 0
-            assert query.maxt == ds.bounds.maxt == 20
-
     @pytest.mark.slow
     @pytest.mark.parametrize('num_workers', [0, 1, 2])
     def test_dataloader(
@@ -255,7 +156,7 @@ class TestGridGeoSampler:
     def dataset(self) -> CustomGeoDataset:
         ds = CustomGeoDataset()
         ds.index.insert(0, (0, 100, 200, 300, 400, 500))
-        ds.index.insert(1, (0, 100, 200, 300, 500, 600))
+        ds.index.insert(1, (0, 100, 200, 300, 400, 500))
         return ds
 
     @pytest.fixture(
@@ -296,13 +197,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
         assert math.isclose(query.maxx - query.minx, sampler.size[1])
         assert math.isclose(query.maxy - query.miny, sampler.size[0])
-        assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
+        assert math.isclose(
+            query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
+        )
 
     def test_len(self, sampler: GridGeoSampler) -> None:
         rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride)
-        length = (
-            rows * cols * 2
-        )  # two spatially but not temporally overlapping items in dataset
+        length = rows * cols * 2  # two items in dataset
         assert len(sampler) == length
 
     def test_roi(self, dataset: CustomGeoDataset) -> None:
@@ -342,29 +243,6 @@ def test_float_multiple(self) -> None:
         assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
         assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)
 
-    def test_dataset_has_regex(self) -> None:
-        ds = CustomGeoDataset()
-        ds.filename_regex = r'.*(?P<my_key>test)'
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
-        sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
-        assert 'my_key' in sampler.chips.columns
-
-    def test_dataset_has_regex_no_match(self) -> None:
-        ds = CustomGeoDataset()
-        ds.filename_regex = r'(?P<my_key>test)'
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key')
-        sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
-        assert 'my_key' not in sampler.chips.columns
-
-    def test_return_as_ts(self) -> None:
-        ds = CustomGeoDatasetSITS()
-        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
-        ds.index.insert(1, (0, 10, 0, 10, 15, 20))
-        sampler = GridGeoSampler(ds, 1, 1)
-        for query in sampler:
-            assert query.mint == ds.bounds.mint == 0
-            assert query.maxt == ds.bounds.maxt == 20
-
     @pytest.mark.slow
     @pytest.mark.parametrize('num_workers', [0, 1, 2])
     def test_dataloader(
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index 09f9589e7fa..68a7d853969 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -98,9 +98,6 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
     #: a different file format than what it was originally downloaded as.
     filename_glob = '*'
 
-    # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset
-    return_as_ts = False
-
     # NOTE: according to the Python docs:
     #
     # * https://docs.python.org/3/library/exceptions.html#NotImplementedError
@@ -986,7 +983,6 @@ def __init__(
         if not isinstance(ds, GeoDataset):
             raise ValueError('IntersectionDataset only supports GeoDatasets')
 
-        self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts
         self.crs = dataset1.crs
         self.res = dataset1.res
 
@@ -1147,7 +1143,6 @@ def __init__(
         if not isinstance(ds, GeoDataset):
             raise ValueError('UnionDataset only supports GeoDatasets')
 
-        self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts
         self.crs = dataset1.crs
         self.res = dataset1.res
 
diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py
index 2924edad7f4..094142cb647 100644
--- a/torchgeo/samplers/single.py
+++ b/torchgeo/samplers/single.py
@@ -4,86 +4,17 @@
 """TorchGeo samplers."""
 
 import abc
-import re
-import warnings
 from collections.abc import Callable, Iterable, Iterator
-from typing import Any
 
-import geopandas as gpd
-import numpy as np
-import pandas as pd
 import torch
-from geopandas import GeoDataFrame
 from rtree.index import Index, Property
-from shapely.geometry import box
 from torch.utils.data import Sampler
-from tqdm import tqdm
 
 from ..datasets import BoundingBox, GeoDataset
 from .constants import Units
 from .utils import _to_tuple, get_random_bounding_box, tile_to_chips
 
 
-def load_file(path: str | GeoDataFrame) -> GeoDataFrame:
-    """Load a file from the given path.
-
-    Parameters:
-        path (str or GeoDataFrame): The path to the file or a GeoDataFrame object.
-
-    Returns:
-        GeoDataFrame: The loaded file as a GeoDataFrame.
-
-    Raises:
-        None
-
-    """
-    if isinstance(path, GeoDataFrame):
-        return path
-    if path.endswith('.feather'):
-        print(f'Reading feather file: {path}')
-        return gpd.read_feather(path)
-    else:
-        print(f'Reading shapefile: {path}')
-        return gpd.read_file(path)
-
-
-def _get_regex_groups_as_df(dataset: GeoDataset, hits: list) -> pd.DataFrame:
-    """Extracts the regex metadata from a list of hits.
-
-    Args:
-        dataset (GeoDataset): The dataset to sample from.
-        hits (list): A list of hits.
-
-    Returns:
-        pandas.DataFrame: A DataFrame containing the extracted file metadata.
-    """
-    has_filename_regex = hasattr(dataset, 'filename_regex')
-    if has_filename_regex:
-        filename_regex = re.compile(dataset.filename_regex, re.VERBOSE)
-    file_metadata = []
-    for hit in hits:
-        if has_filename_regex:
-            match = re.match(filename_regex, str(hit.object))
-            if match:
-                meta = match.groupdict()
-            else:
-                meta = {}
-        else:
-            meta = {}
-        meta.update(
-            {
-                'minx': hit.bounds[0],
-                'maxx': hit.bounds[1],
-                'miny': hit.bounds[2],
-                'maxy': hit.bounds[3],
-                'mint': hit.bounds[4],
-                'maxt': hit.bounds[5],
-            }
-        )
-        file_metadata.append(meta)
-    return pd.DataFrame(file_metadata)
-
-
 class GeoSampler(Sampler[BoundingBox], abc.ABC):
     """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
@@ -113,103 +44,18 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None
         self.res = dataset.res
         self.roi = roi
-        self.dataset = dataset
 
     @abc.abstractmethod
-    def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame:
-        """Determines the way to get the extend of the chips (samples) of the dataset.
-
-        Should return a GeoDataFrame with the extend of the chips with the columns
-        geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.
-        """
-
-    def filter_chips(
-        self,
-        filter_by: str | GeoDataFrame,
-        predicate: str = 'intersects',
-        action: str = 'keep',
-    ) -> None:
-        """Filter the default set of chips in the sampler down to a specific subset by specifying files
-        supported by geopandas such as shapefiles, geodatabases or feather files.
-
-        Args:
-            filter_by: The file or geodataframe for which the geometries will be used during filtering
-            predicate: Predicate as used in Geopandas sindex.query_bulk
-            action: What to do with the chips that satisfy the condition by the predicacte.
-                Can either be "drop" or "keep".
-        """
-        prefilter_leng = len(self.chips)
-        filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs)
-
-        if action == 'keep':
-            self.chips = self.chips.iloc[
-                list(
-                    set(
-                        self.chips.sindex.query_bulk(
-                            filtering_gdf.geometry, predicate=predicate
-                        )[1]
-                    )
-                )
-            ].reset_index(drop=True)
-        elif action == 'drop':
-            self.chips = self.chips.drop(
-                index=list(
-                    set(
-                        self.chips.sindex.query_bulk(
-                            filtering_gdf.geometry, predicate=predicate
-                        )[1]
-                    )
-                )
-            ).reset_index(drop=True)
-
-        self.chips.fid = self.chips.index
-        print(f'Filter step reduced chips from {prefilter_leng} to {len(self.chips)}')
-        assert not self.chips.empty, 'No chips left after filtering!'
-
-    def set_worker_split(self, total_workers: int, worker_num: int) -> None:
-        """Splits the chips in n equal parts for the number of workers and keeps the set of
-        chips for the specific worker id, convenient if you want to split the chips across
-        multiple dataloaders for multi-gpu inference.
-
-        Args:
-            total_workers: The total number of parts to split the chips
-            worker_num: The id of the worker (which part to keep), starts from 0
-
-        """
-        self.chips = np.array_split(self.chips, total_workers)[worker_num]
-
-    def save(self, path: str, driver: str) -> None:
-        """Save the chips as a shapefile or feather file"""
-        if path.endswith('.feather'):
-            self.chips.to_feather(path)
-        else:
-            self.chips.to_file(path, driver=driver)
-
     def __iter__(self) -> Iterator[BoundingBox]:
         """Return the index of a dataset.
 
         Returns:
             (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
         """
-        for _, chip in self.chips.iterrows():
-            yield BoundingBox(
-                chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt
-            )
-
-    def __len__(self) -> int:
-        """Return the number of samples over the ROI.
-
-        Returns:
-            number of patches that will be sampled
-        """
-        return len(self.chips)
 
 
 class RandomGeoSampler(GeoSampler):
-    """Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data.
-
-    Documentation from TorchGeo:
-    Samples elements from a region of interest randomly.
+    """Samples elements from a region of interest randomly.
 
     This is particularly useful during training when you want to maximize the size of
     the dataset and return as many random :term:`chips <chip>` as possible. Note that
@@ -259,7 +105,7 @@ def __init__(
         if units == Units.PIXELS:
             self.size = (self.size[0] * self.res, self.size[1] * self.res)
 
-        num_chips = 0
+        self.length = 0
         self.hits = []
         areas = []
         for hit in self.index.intersection(tuple(self.roi), objects=True):
@@ -270,53 +116,43 @@ def __init__(
             ):
                 if bounds.area > 0:
                     rows, cols = tile_to_chips(bounds, self.size)
-                    num_chips += rows * cols
+                    self.length += rows * cols
                 else:
-                    num_chips += 1
+                    self.length += 1
                 self.hits.append(hit)
                 areas.append(bounds.area)
         if length is not None:
-            num_chips = length
-        self.length = num_chips
+            self.length = length
 
         # torch.multinomial requires float probabilities > 0
         self.areas = torch.tensor(areas, dtype=torch.float)
         if torch.sum(self.areas) == 0:
             self.areas += 1
 
-        self.chips = self.get_chips(num_samples=num_chips)
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
 
-    def get_chips(self, num_samples) -> GeoDataFrame:
-        chips = []
-        for _ in tqdm(range(num_samples)):
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        for _ in range(len(self)):
             # Choose a random tile, weighted by area
             idx = torch.multinomial(self.areas, 1)
             hit = self.hits[idx]
-            hit_bounds = hit.bounds
-            if self.dataset.return_as_ts:
-                hit_bounds[-2] = self.dataset.bounds.mint
-                hit_bounds[-1] = self.dataset.bounds.maxt
+            bounds = BoundingBox(*hit.bounds)
 
-            bounds = BoundingBox(*hit_bounds)
             # Choose a random index within that tile
-            bbox = get_random_bounding_box(bounds, self.size, self.res)
-            minx, maxx, miny, maxy, mint, maxt = tuple(bbox)
-            chip = {
-                'geometry': box(minx, miny, maxx, maxy),
-                'minx': minx,
-                'miny': miny,
-                'maxx': maxx,
-                'maxy': maxy,
-                'mint': mint,
-                'maxt': maxt,
-            }
-            chips.append(chip)
-
-        print('creating geodataframe... ')
-        chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
-        chips_gdf['fid'] = chips_gdf.index
-
-        return chips_gdf
+            bounding_box = get_random_bounding_box(bounds, self.size, self.res)
+
+            yield bounding_box
+
+    def __len__(self) -> int:
+        """Return the number of samples in a single epoch.
+
+        Returns:
+            length of the epoch
+        """
+        return self.length
 
 
 class GridGeoSampler(GeoSampler):
@@ -370,38 +206,33 @@ def __init__(
             self.size = (self.size[0] * self.res, self.size[1] * self.res)
             self.stride = (self.stride[0] * self.res, self.stride[1] * self.res)
 
-        hits = self.index.intersection(tuple(self.roi), objects=True)
-        df_path = _get_regex_groups_as_df(self.dataset, hits)
+        self.hits = []
+        for hit in self.index.intersection(tuple(self.roi), objects=True):
+            bounds = BoundingBox(*hit.bounds)
+            if (
+                bounds.maxx - bounds.minx >= self.size[1]
+                and bounds.maxy - bounds.miny >= self.size[0]
+            ):
+                self.hits.append(hit)
 
-        # Filter out tiles smaller than the chip size
-        self.df_path = df_path[
-            (df_path.maxx - df_path.minx >= self.size[1])
-            & (df_path.maxy - df_path.miny >= self.size[0])
-        ]
+        self.length = 0
+        for hit in self.hits:
+            bounds = BoundingBox(*hit.bounds)
+            rows, cols = tile_to_chips(bounds, self.size, self.stride)
+            self.length += rows * cols
 
-        # Filter out hits in the index that share the same extent
-        if self.dataset.return_as_ts:
-            self.df_path.drop_duplicates(
-                subset=['minx', 'maxx', 'miny', 'maxy'], inplace=True
-            )
-        else:
-            self.df_path.drop_duplicates(
-                subset=['minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'], inplace=True
-            )
-
-        self.chips = self.get_chips()
-
-    def get_chips(self) -> GeoDataFrame:
-        print('generating samples... ')
-        optional_keys = set(self.df_path.keys()) - set(
-            ['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt']
-        )
-        chips = []
-        for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)):
-            bounds = BoundingBox(
-                row.minx, row.maxx, row.miny, row.maxy, row.mint, row.maxt
-            )
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        # For each tile...
+        for hit in self.hits:
+            bounds = BoundingBox(*hit.bounds)
             rows, cols = tile_to_chips(bounds, self.size, self.stride)
+            mint = bounds.mint
+            maxt = bounds.maxt
 
             # For each row...
             for i in range(rows):
@@ -413,37 +244,15 @@ def get_chips(self) -> GeoDataFrame:
                     minx = bounds.minx + j * self.stride[1]
                     maxx = minx + self.size[1]
 
-                    if self.dataset.return_as_ts:
-                        mint = self.dataset.bounds.mint
-                        maxt = self.dataset.bounds.maxt
-                    else:
-                        mint = bounds.mint
-                        maxt = bounds.maxt
-
-                    chip = {
-                        'geometry': box(minx, miny, maxx, maxy),
-                        'minx': minx,
-                        'miny': miny,
-                        'maxx': maxx,
-                        'maxy': maxy,
-                        'mint': mint,
-                        'maxt': maxt,
-                    }
-                    for key in optional_keys:
-                        if key in row.keys():
-                            chip[key] = row[key]
-
-                    chips.append(chip)
-
-        if chips:
-            print('creating geodataframe... ')
-            chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
-            chips_gdf['fid'] = chips_gdf.index
+                    yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
 
-        else:
-            warnings.warn('Sampler has no chips, check your inputs')
-            chips_gdf = GeoDataFrame()
-        return chips_gdf
+    def __len__(self) -> int:
+        """Return the number of samples over the ROI.
+
+        Returns:
+            number of patches that will be sampled
+        """
+        return self.length
 
 
 class PreChippedGeoSampler(GeoSampler):
@@ -480,29 +289,23 @@ def __init__(
         for hit in self.index.intersection(tuple(self.roi), objects=True):
             self.hits.append(hit)
 
-        self.chips = self.get_chips()
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
 
-    def get_chips(self) -> GeoDataFrame:
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
         generator: Callable[[int], Iterable[int]] = range
         if self.shuffle:
            generator = torch.randperm
 
-        chips = []
-        for idx in generator(len(self.hits)):
-            minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds
-            chip = {
-                'geometry': box(minx, miny, maxx, maxy),
-                'minx': minx,
-                'miny': miny,
-                'maxx': maxx,
-                'maxy': maxy,
-                'mint': mint,
-                'maxt': maxt,
-            }
-            chips.append(chip)
-
-        print('creating geodataframe... ')
-        chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
-        chips_gdf['fid'] = chips_gdf.index
-
-        return chips_gdf
+        for idx in generator(len(self)):
+            yield BoundingBox(*self.hits[idx].bounds)
+
+    def __len__(self) -> int:
+        """Return the number of samples over the ROI.
+
+        Returns:
+            number of patches that will be sampled
+        """
+        return len(self.hits)
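After this revert, the samplers are plain iterables of `BoundingBox` objects again: `__iter__` yields queries lazily and `__len__` reports the epoch length, with no intermediate `chips` GeoDataFrame, `filter_chips`, `set_worker_split`, or `save` step. A minimal usage sketch of the restored behavior, reusing `CustomGeoDataset` from the test file above (the 10x10 tile and the size/stride/length values are illustrative, mirroring the tests):

```python
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler, Units

# One 10 x 10 tile spanning times 0-10, as in the tests above.
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))

# Exhaustive grid: 5 x 5 chips with stride 5 -> 2 rows x 2 cols = 4 chips.
grid = GridGeoSampler(ds, size=5, stride=5, units=Units.CRS)
assert len(grid) == 4
assert next(iter(grid)) == BoundingBox(0, 5, 0, 5, 0, 10)

# Random sampling: a tile is drawn with probability proportional to its area
# (torch.multinomial over self.areas), then a random chip is cut from it.
random_sampler = RandomGeoSampler(ds, size=5, length=8, units=Units.CRS)

# Either sampler plugs straight into a DataLoader; stack_samples leaves
# non-tensor values (like these BoundingBox queries) as plain lists.
loader = DataLoader(ds, sampler=grid, collate_fn=stack_samples)
for batch in loader:
    print(batch['index'])
```

This mirrors `test_float_multiple` and the `test_dataloader` cases in the diff; nothing here relies on the removed GeoDataFrame-backed API.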