From 889735c8a4d260793075a6537f4159c19304e2fa Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 28 Aug 2024 17:00:38 +0200 Subject: [PATCH] Run pre-commit (tg ruleset) WIP --- .vscode/settings.json | 12 ++- tests/samplers/test_single.py | 80 ++++++++--------- torchgeo/datasets/geo.py | 2 +- torchgeo/samplers/single.py | 160 +++++++++++++++++----------------- 4 files changed, 120 insertions(+), 134 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ad7af29f625..d969f962b02 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,5 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file +{ + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 10634624227..7cf54f69000 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -1,14 +1,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os import math -from collections.abc import Iterator +import os from itertools import product +import geopandas as gpd import pytest from _pytest.fixtures import SubRequest +from geopandas import GeoDataFrame from rasterio.crs import CRS +from shapely.geometry import box from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples @@ -21,10 +23,6 @@ tile_to_chips, ) -import geopandas as gpd -from geopandas import GeoDataFrame -from shapely.geometry import box - class CustomGeoSampler(GeoSampler): def __init__(self) -> None: @@ -35,16 +33,16 @@ def get_chips(self) -> GeoDataFrame: for i in range(len(self)): chips.append( { - "geometry": box(i, i, i, i), - "minx": i, - "miny": i, - "maxx": i, - "maxy": i, - "mint": i, - "maxt": i, + 'geometry': box(i, i, i, i), + 'minx': i, + 'miny': i, + 'maxx': i, + 'maxy': i, + 'mint': i, + 'maxt': i, } ) - return GeoDataFrame(chips, crs="3005") + return GeoDataFrame(chips, crs='3005') def __len__(self) -> int: return 2 @@ -65,11 +63,12 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res - self.return_as_ts=True + self.return_as_ts = True def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {'index': query} + class TestGeoSampler: @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: @@ -85,9 +84,9 @@ def sampler(self) -> CustomGeoSampler: def datadir(self) -> str: return os.path.join('tests', 'data', 'samplers') - def test_no_get_chips_implemented(self) -> None: + def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError): - GeoSampler() + GeoSampler(dataset) def test_iter(self, sampler: CustomGeoSampler) -> None: assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0) @@ -100,13 +99,9 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: GeoSampler(dataset) # type: ignore[abstract] @pytest.mark.parametrize( - "filtering_file", - [ - "filtering_4x4", - "filtering_4x4.feather", - ], + 'filtering_file', ['filtering_4x4', 'filtering_4x4.feather'] ) - def test_filtering_from_path(self, datadir, filtering_file) -> None: + def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler( @@ -116,11 +111,11 @@ def test_filtering_from_path(self, datadir, filtering_file) -> None: assert len(sampler) == 4 filtering_path = os.path.join(datadir, filtering_file) - sampler.filter_chips(filtering_path, "intersects", "drop") + sampler.filter_chips(filtering_path, 'intersects', 'drop') assert len(sampler) == 3 assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) - def test_filtering_from_gdf(self, datadir) -> None: + def test_filtering_from_gdf(self, datadir: str) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler( @@ -130,17 +125,15 @@ def test_filtering_from_gdf(self, datadir) -> None: # Dropping first chip assert len(sampler) == 4 - filtering_gdf = gpd.read_file( - os.path.join(datadir, "filtering_4x4") - ) - sampler.filter_chips(filtering_gdf, "intersects", "drop") + filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4')) + sampler.filter_chips(filtering_gdf, 'intersects', 'drop') assert len(sampler) == 3 assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) # Keeping only first chip sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) iterator = iter(sampler) - sampler.filter_chips(filtering_gdf, "intersects", "keep") + sampler.filter_chips(filtering_gdf, 'intersects', 'keep') assert len(sampler) == 1 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) @@ -158,9 +151,8 @@ def test_save_chips(self, tmpdir_factory) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) - sampler.save(str(tmpdir_factory.mktemp("out").join("chips"))) - sampler.save(str(tmpdir_factory.mktemp("out").join("chips.feather"))) - + sampler.save(str(tmpdir_factory.mktemp('out').join('chips'))) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather'))) @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) @@ -304,16 +296,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None: assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) - assert ( - sampler.roi.mint - <= query.mint - <= query.maxt - <= sampler.roi.maxt - ) + assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt def test_len(self, sampler: GridGeoSampler) -> None: rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride) - length = rows * cols * 2 # two spatially but not temporally overlapping items in dataset + length = ( + rows * cols * 2 + ) # two spatially but not temporally overlapping items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: @@ -352,20 +341,20 @@ def test_float_multiple(self) -> None: assert len(sampler) == 2 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) - + def test_dataset_has_regex(self) -> None: ds = CustomGeoDataset() ds.filename_regex = r'.*(?Ptest)' - ds.index.insert(0, (0, 10, 0, 10, 0, 10), "filepath_containing_key_test") + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test') sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) - assert "my_key" in sampler.chips.columns + assert 'my_key' in sampler.chips.columns def test_dataset_has_regex_no_match(self) -> None: ds = CustomGeoDataset() ds.filename_regex = r'(?Ptest)' - ds.index.insert(0, (0, 10, 0, 10, 0, 10), "no_matching_key") + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key') sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) - assert "my_key" not in sampler.chips.columns + assert 'my_key' not in sampler.chips.columns def test_return_as_ts(self) -> None: ds = CustomGeoDatasetSITS() @@ -376,7 +365,6 @@ def test_return_as_ts(self) -> None: assert query.mint == ds.bounds.mint == 0 assert query.maxt == ds.bounds.maxt == 20 - @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f0d217ea640..5d5175c63b4 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -96,7 +96,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: other datasets. It should not include a file extension, as the dataset may be in #: a different file format than what it was originally downloaded as. filename_glob = '*' - + # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset return_as_ts = False diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 8bc6f628a6d..2924edad7f4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,27 +4,28 @@ """TorchGeo samplers.""" import abc +import re +import warnings from collections.abc import Callable, Iterable, Iterator +from typing import Any +import geopandas as gpd +import numpy as np +import pandas as pd import torch +from geopandas import GeoDataFrame from rtree.index import Index, Property +from shapely.geometry import box from torch.utils.data import Sampler +from tqdm import tqdm from ..datasets import BoundingBox, GeoDataset from .constants import Units from .utils import _to_tuple, get_random_bounding_box, tile_to_chips -from geopandas import GeoDataFrame -from tqdm import tqdm -from shapely.geometry import box -import re -import pandas as pd -import warnings -import geopandas as gpd -import numpy as np -def load_file(path: str|GeoDataFrame) -> GeoDataFrame: - """ - Load a file from the given path. + +def load_file(path: str | GeoDataFrame) -> GeoDataFrame: + """Load a file from the given path. Parameters: path (str or GeoDataFrame): The path to the file or a GeoDataFrame object. @@ -38,16 +39,16 @@ def load_file(path: str|GeoDataFrame) -> GeoDataFrame: """ if isinstance(path, GeoDataFrame): return path - if path.endswith(".feather"): - print(f"Reading feather file: {path}") + if path.endswith('.feather'): + print(f'Reading feather file: {path}') return gpd.read_feather(path) else: - print(f"Reading shapefile: {path}") + print(f'Reading shapefile: {path}') return gpd.read_file(path) -def _get_regex_groups_as_df(dataset, hits): - """ - Extracts the regex metadata from a list of hits. + +def _get_regex_groups_as_df(dataset: GeoDataset, hits: list) -> pd.DataFrame: + """Extracts the regex metadata from a list of hits. Args: dataset (GeoDataset): The dataset to sample from. @@ -56,7 +57,7 @@ def _get_regex_groups_as_df(dataset, hits): Returns: pandas.DataFrame: A DataFrame containing the extracted file metadata. """ - has_filename_regex = bool(getattr(dataset, "filename_regex", None)) + has_filename_regex = hasattr(dataset, 'filename_regex') if has_filename_regex: filename_regex = re.compile(dataset.filename_regex, re.VERBOSE) file_metadata = [] @@ -71,17 +72,18 @@ def _get_regex_groups_as_df(dataset, hits): meta = {} meta.update( { - "minx": hit.bounds[0], - "maxx": hit.bounds[1], - "miny": hit.bounds[2], - "maxy": hit.bounds[3], - "mint": hit.bounds[4], - "maxt": hit.bounds[5], + 'minx': hit.bounds[0], + 'maxx': hit.bounds[1], + 'miny': hit.bounds[2], + 'maxy': hit.bounds[3], + 'mint': hit.bounds[4], + 'maxt': hit.bounds[5], } ) file_metadata.append(meta) return pd.DataFrame(file_metadata) + class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. @@ -112,23 +114,23 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.res = dataset.res self.roi = roi self.dataset = dataset - - @abc.abstractmethod - def get_chips(self, **kwargs) -> GeoDataFrame: + + @abc.abstractmethod + def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. - Should return a GeoDataFrame with the extend of the chips with the columns - geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.""" + Should return a GeoDataFrame with the extend of the chips with the columns + geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip. + """ def filter_chips( self, filter_by: str | GeoDataFrame, - predicate: str = "intersects", - action: str = "keep" , + predicate: str = 'intersects', + action: str = 'keep', ) -> None: - """Filter the default set of chips in the sampler down to a specific subset by - specifying files supported by geopandas such as shapefiles, geodatabases or - feather files. + """Filter the default set of chips in the sampler down to a specific subset by specifying files + supported by geopandas such as shapefiles, geodatabases or feather files. Args: filter_by: The file or geodataframe for which the geometries will be used during filtering @@ -139,7 +141,7 @@ def filter_chips( prefilter_leng = len(self.chips) filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) - if action == "keep": + if action == 'keep': self.chips = self.chips.iloc[ list( set( @@ -149,7 +151,7 @@ def filter_chips( ) ) ].reset_index(drop=True) - elif action == "drop": + elif action == 'drop': self.chips = self.chips.drop( index=list( set( @@ -161,8 +163,8 @@ def filter_chips( ).reset_index(drop=True) self.chips.fid = self.chips.index - print(f"Filter step reduced chips from {prefilter_leng} to {len(self.chips)}") - assert not self.chips.empty, "No chips left after filtering!" + print(f'Filter step reduced chips from {prefilter_leng} to {len(self.chips)}') + assert not self.chips.empty, 'No chips left after filtering!' def set_worker_split(self, total_workers: int, worker_num: int) -> None: """Splits the chips in n equal parts for the number of workers and keeps the set of @@ -176,11 +178,9 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None: """ self.chips = np.array_split(self.chips, total_workers)[worker_num] - def save(self, - path: str, - driver: str = None) -> None: + def save(self, path: str, driver: str) -> None: """Save the chips as a shapefile or feather file""" - if path.endswith(".feather"): + if path.endswith('.feather'): self.chips.to_feather(path) else: self.chips.to_file(path, driver=driver) @@ -204,6 +204,7 @@ def __len__(self) -> int: """ return len(self.chips) + class RandomGeoSampler(GeoSampler): """Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data. @@ -301,19 +302,19 @@ def get_chips(self, num_samples) -> GeoDataFrame: bbox = get_random_bounding_box(bounds, self.size, self.res) minx, maxx, miny, maxy, mint, maxt = tuple(bbox) chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } chips.append(chip) - - print("creating geodataframe... ") + + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index return chips_gdf @@ -381,19 +382,20 @@ def __init__( # Filter out hits in the index that share the same extent if self.dataset.return_as_ts: self.df_path.drop_duplicates( - subset=["minx", "maxx", "miny", "maxy"], inplace=True + subset=['minx', 'maxx', 'miny', 'maxy'], inplace=True ) else: self.df_path.drop_duplicates( - subset=["minx", "maxx", "miny", "maxy", "mint", "maxt"], inplace=True + subset=['minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'], inplace=True ) - - self.chips = self.get_chips() + self.chips = self.get_chips() def get_chips(self) -> GeoDataFrame: - print("generating samples... ") - optional_keys = set(self.df_path.keys()) - set(['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt']) + print('generating samples... ') + optional_keys = set(self.df_path.keys()) - set( + ['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'] + ) chips = [] for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): bounds = BoundingBox( @@ -419,13 +421,13 @@ def get_chips(self) -> GeoDataFrame: maxt = bounds.maxt chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } for key in optional_keys: if key in row.keys(): @@ -434,12 +436,12 @@ def get_chips(self) -> GeoDataFrame: chips.append(chip) if chips: - print("creating geodataframe... ") + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index else: - warnings.warn("Sampler has no chips, check your inputs") + warnings.warn('Sampler has no chips, check your inputs') chips_gdf = GeoDataFrame() return chips_gdf @@ -477,11 +479,10 @@ def __init__( self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): self.hits.append(hit) - + self.chips = self.get_chips() def get_chips(self) -> GeoDataFrame: - generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = torch.randperm @@ -490,19 +491,18 @@ def get_chips(self) -> GeoDataFrame: for idx in generator(len(self.hits)): minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } chips.append(chip) - print("creating geodataframe... ") + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index return chips_gdf -