100% test coverage for samplers
sfalkena committed Aug 28, 2024
1 parent daab71c commit fb85941
Showing 9 changed files with 184 additions and 21 deletions.
7 changes: 7 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
Binary file added tests/data/samplers/filtering_4x4.feather
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.cpg
@@ -0,0 +1 @@
ISO-8859-1
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.dbf
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.prj
@@ -0,0 +1 @@
PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]]
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shp
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shx
116 changes: 116 additions & 0 deletions tests/samplers/test_single.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import math
from collections.abc import Iterator
from itertools import product
@@ -20,6 +21,7 @@
tile_to_chips,
)

import geopandas as gpd
from geopandas import GeoDataFrame
from shapely.geometry import box

@@ -58,6 +60,16 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class CustomGeoDatasetSITS(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
        self.return_as_ts = True

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}

class TestGeoSampler:
@pytest.fixture(scope='class')
def dataset(self) -> CustomGeoDataset:
@@ -69,6 +81,14 @@ def dataset(self) -> CustomGeoDataset:
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()

@pytest.fixture(scope='class')
def datadir(self) -> str:
return os.path.join('tests', 'data', 'samplers')

def test_no_get_chips_implemented(self) -> None:
with pytest.raises(TypeError):
GeoSampler()

def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)

@@ -79,6 +99,69 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.parametrize(
"filtering_file",
[
"filtering_4x4",
"filtering_4x4.feather",
],
)
    def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

assert len(sampler) == 4
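        # Dropping the first chip: the filtering geometry intersects it, so only
        # three of the four 5x5 chips should remain.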
filtering_path = os.path.join(datadir, filtering_file)
sampler.filter_chips(filtering_path, "intersects", "drop")
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

    def test_filtering_from_gdf(self, datadir: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

# Dropping first chip
assert len(sampler) == 4
        filtering_gdf = gpd.read_file(os.path.join(datadir, "filtering_4x4"))
sampler.filter_chips(filtering_gdf, "intersects", "drop")
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

# Keeping only first chip
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
sampler.filter_chips(filtering_gdf, "intersects", "keep")
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)

def test_set_worker_split(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
assert len(sampler) == 4
sampler.set_worker_split(total_workers=4, worker_num=1)
assert len(sampler) == 1

    def test_save_chips(self, tmpdir_factory: pytest.TempdirFactory) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
sampler.save(str(tmpdir_factory.mktemp("out").join("chips")))
sampler.save(str(tmpdir_factory.mktemp("out").join("chips.feather")))


@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
@@ -154,6 +237,15 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = RandomGeoSampler(ds, 1, 5)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20
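        # With return_as_ts=True on the dataset, each sampled query spans the
        # dataset's full temporal extent, covering both index entries above.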

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
@@ -260,6 +352,30 @@ def test_float_multiple(self) -> None:
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_dataset_has_regex(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'.*(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), "filepath_containing_key_test")
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert "my_key" in sampler.chips.columns

def test_dataset_has_regex_no_match(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), "no_matching_key")
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert "my_key" not in sampler.chips.columns

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = GridGeoSampler(ds, 1, 1)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20


@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
80 changes: 59 additions & 21 deletions torchgeo/samplers/single.py
@@ -19,6 +31,
import re
import pandas as pd
import warnings
import geopandas as gpd
import numpy as np

def load_file(path: str | GeoDataFrame) -> GeoDataFrame:
    """Load a filtering file from the given path.

    Args:
        path: path to a file readable by geopandas, or an already-loaded
            GeoDataFrame

    Returns:
        the file loaded as a GeoDataFrame (a GeoDataFrame input is returned
        unchanged)
    """
    if isinstance(path, GeoDataFrame):
        return path
    if path.endswith(".feather"):
        print(f"Reading feather file: {path}")
        return gpd.read_feather(path)
    else:
        print(f"Reading shapefile: {path}")
        return gpd.read_file(path)
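
# A minimal usage sketch for load_file (illustrative only; the paths are the
# test fixtures added under tests/data/samplers in this commit):
#
#     gdf = load_file("tests/data/samplers/filtering_4x4.feather")  # feather branch
#     gdf = load_file("tests/data/samplers/filtering_4x4")          # read_file branch
#     gdf = load_file(gdf)  # a GeoDataFrame input is returned unchanged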

def _get_regex_groups_as_df(dataset, hits):
"""
@@ -40,6 +65,8 @@ def _get_regex_groups_as_df(dataset, hits):
match = re.match(filename_regex, str(hit.object))
if match:
meta = match.groupdict()
else:
meta = {}
else:
meta = {}
meta.update(
@@ -91,14 +118,13 @@ def get_chips(self, **kwargs) -> GeoDataFrame:
        """Determine the extent of the chips (samples) of the dataset.

        Should return a GeoDataFrame with the extent of each chip, with the columns
        geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip."""
raise NotImplementedError
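
        # Concrete samplers typically build one dict per chip and wrap the list in
        # a GeoDataFrame, e.g. (a sketch of the expected shape):
        #
        #     chip = {
        #         'geometry': box(minx, miny, maxx, maxy),  # shapely.geometry.box
        #         'minx': minx, 'miny': miny, 'maxx': maxx, 'maxy': maxy,
        #         'mint': mint, 'maxt': maxt,
        #     }
        #     chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
        #     chips_gdf['fid'] = chips_gdf.index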


def filter_chips(
self,
filter_by: str | GeoDataFrame,
predicate: str = "intersects",
        action: str = "keep",
) -> None:
"""Filter the default set of chips in the sampler down to a specific subset by
specifying files supported by geopandas such as shapefiles, geodatabases or
@@ -112,9 +138,28 @@
"""
prefilter_leng = len(self.chips)
filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs)
self.chips = filter_tiles(
self.chips, filtering_gdf, predicate, action
).reset_index(drop=True)

if action == "keep":
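            # sindex.query_bulk(geoms, predicate=...) returns a (2, n) array of
            # index pairs; row 1 holds the positional indices of the chips that
            # satisfy the predicate against any filtering geometry. Those
            # positions are kept here and dropped in the "drop" branch below.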
self.chips = self.chips.iloc[
list(
set(
self.chips.sindex.query_bulk(
filtering_gdf.geometry, predicate=predicate
)[1]
)
)
].reset_index(drop=True)
elif action == "drop":
self.chips = self.chips.drop(
index=list(
set(
self.chips.sindex.query_bulk(
filtering_gdf.geometry, predicate=predicate
)[1]
)
)
).reset_index(drop=True)

self.chips.fid = self.chips.index
print(f"Filter step reduced chips from {prefilter_leng} to {len(self.chips)}")
assert not self.chips.empty, "No chips left after filtering!"
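
        # Example call pattern, mirroring tests/samplers/test_single.py in this
        # commit (a sketch; ``ds`` is any indexed GeoDataset):
        #
        #     sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
        #     sampler.filter_chips("tests/data/samplers/filtering_4x4.feather",
        #                          predicate="intersects", action="drop")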
@@ -160,7 +205,7 @@ def __len__(self) -> int:
return len(self.chips)

class RandomGeoSampler(GeoSampler):
"""Differs from TrochGeo's RandomGeoSampler in that it can sample SITS data.
"""Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data.
Documentation from TorchGeo:
Samples elements from a region of interest randomly.
@@ -266,14 +311,10 @@ def get_chips(self, num_samples) -> GeoDataFrame:
}
chips.append(chip)

if chips:
print("creating geodataframe... ")
chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
chips_gdf["fid"] = chips_gdf.index
print("creating geodataframe... ")
chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
chips_gdf["fid"] = chips_gdf.index

else:
warnings.warn("Sampler has no chips, check your inputs")
chips_gdf = GeoDataFrame()
return chips_gdf


@@ -352,7 +393,7 @@ def __init__(

def get_chips(self) -> GeoDataFrame:
print("generating samples... ")
optional_keys = ["tile", "date"]
        optional_keys = set(self.df_path.keys()) - {
            'geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'
        }
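        # i.e. any extra per-sample metadata columns in df_path (e.g. tile or date
        # identifiers) beyond the core bounding-box fields; these are presumably
        # carried over onto each chip below.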
chips = []
for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)):
bounds = BoundingBox(
@@ -459,12 +500,9 @@ def get_chips(self) -> GeoDataFrame:
}
chips.append(chip)

if chips:
print("creating geodataframe... ")
chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
chips_gdf["fid"] = chips_gdf.index
else:
warnings.warn("Sampler has no chips, check your inputs")
chips_gdf = GeoDataFrame()
print("creating geodataframe... ")
chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
chips_gdf["fid"] = chips_gdf.index

return chips_gdf
