Skip to content

Commit

Permalink
Run pre-commit (tg ruleset) WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Aug 28, 2024
1 parent fb85941 commit 889735c
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 134 deletions.
12 changes: 5 additions & 7 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
{
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
80 changes: 34 additions & 46 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import math
from collections.abc import Iterator
import os
from itertools import product

import geopandas as gpd
import pytest
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
from shapely.geometry import box
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
Expand All @@ -21,10 +23,6 @@
tile_to_chips,
)

import geopandas as gpd
from geopandas import GeoDataFrame
from shapely.geometry import box


class CustomGeoSampler(GeoSampler):
def __init__(self) -> None:
Expand All @@ -35,16 +33,16 @@ def get_chips(self) -> GeoDataFrame:
for i in range(len(self)):
chips.append(
{
"geometry": box(i, i, i, i),
"minx": i,
"miny": i,
"maxx": i,
"maxy": i,
"mint": i,
"maxt": i,
'geometry': box(i, i, i, i),
'minx': i,
'miny': i,
'maxx': i,
'maxy': i,
'mint': i,
'maxt': i,
}
)
return GeoDataFrame(chips, crs="3005")
return GeoDataFrame(chips, crs='3005')

def __len__(self) -> int:
return 2
Expand All @@ -65,11 +63,12 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
self.return_as_ts=True
self.return_as_ts = True

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class TestGeoSampler:
@pytest.fixture(scope='class')
def dataset(self) -> CustomGeoDataset:
Expand All @@ -85,9 +84,9 @@ def sampler(self) -> CustomGeoSampler:
def datadir(self) -> str:
return os.path.join('tests', 'data', 'samplers')

def test_no_get_chips_implemented(self) -> None:
def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError):
GeoSampler()
GeoSampler(dataset)

def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)
Expand All @@ -100,13 +99,9 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None:
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.parametrize(
"filtering_file",
[
"filtering_4x4",
"filtering_4x4.feather",
],
'filtering_file', ['filtering_4x4', 'filtering_4x4.feather']
)
def test_filtering_from_path(self, datadir, filtering_file) -> None:
def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
Expand All @@ -116,11 +111,11 @@ def test_filtering_from_path(self, datadir, filtering_file) -> None:

assert len(sampler) == 4
filtering_path = os.path.join(datadir, filtering_file)
sampler.filter_chips(filtering_path, "intersects", "drop")
sampler.filter_chips(filtering_path, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_filtering_from_gdf(self, datadir) -> None:
def test_filtering_from_gdf(self, datadir: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
Expand All @@ -130,17 +125,15 @@ def test_filtering_from_gdf(self, datadir) -> None:

# Dropping first chip
assert len(sampler) == 4
filtering_gdf = gpd.read_file(
os.path.join(datadir, "filtering_4x4")
)
sampler.filter_chips(filtering_gdf, "intersects", "drop")
filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4'))
sampler.filter_chips(filtering_gdf, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

# Keeping only first chip
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
sampler.filter_chips(filtering_gdf, "intersects", "keep")
sampler.filter_chips(filtering_gdf, 'intersects', 'keep')
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)

Expand All @@ -158,9 +151,8 @@ def test_save_chips(self, tmpdir_factory) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
sampler.save(str(tmpdir_factory.mktemp("out").join("chips")))
sampler.save(str(tmpdir_factory.mktemp("out").join("chips.feather")))

sampler.save(str(tmpdir_factory.mktemp('out').join('chips')))
sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather')))

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
Expand Down Expand Up @@ -304,16 +296,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None:

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert (
sampler.roi.mint
<= query.mint
<= query.maxt
<= sampler.roi.maxt
)
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt

def test_len(self, sampler: GridGeoSampler) -> None:
rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride)
length = rows * cols * 2 # two spatially but not temporally overlapping items in dataset
length = (
rows * cols * 2
) # two spatially but not temporally overlapping items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand Down Expand Up @@ -352,20 +341,20 @@ def test_float_multiple(self) -> None:
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_dataset_has_regex(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'.*(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), "filepath_containing_key_test")
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert "my_key" in sampler.chips.columns
assert 'my_key' in sampler.chips.columns

def test_dataset_has_regex_no_match(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), "no_matching_key")
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert "my_key" not in sampler.chips.columns
assert 'my_key' not in sampler.chips.columns

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
Expand All @@ -376,7 +365,6 @@ def test_return_as_ts(self) -> None:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20


@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = '*'

# Whether to return the dataset as a time series; doing so adds another dimension to the dataset.
return_as_ts = False

Expand Down
Loading

0 comments on commit 889735c

Please sign in to comment.