Skip to content

Commit

Permalink
Revert "TimeSeries Support"
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena authored Sep 17, 2024
1 parent 45f1c7e commit 21552c4
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 410 deletions.
5 changes: 0 additions & 5 deletions .vscode/settings.json

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4.feather
Binary file not shown.
1 change: 0 additions & 1 deletion tests/data/samplers/filtering_4x4/filtering_4x4.cpg

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.dbf
Binary file not shown.
1 change: 0 additions & 1 deletion tests/data/samplers/filtering_4x4/filtering_4x4.prj

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.shp
Binary file not shown.
Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.shx
Binary file not shown.
140 changes: 9 additions & 131 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
# Licensed under the MIT License.

import math
import os
from collections.abc import Iterator
from itertools import product

import geopandas as gpd
import pytest
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
from shapely.geometry import box
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
Expand All @@ -26,23 +23,11 @@

class CustomGeoSampler(GeoSampler):
def __init__(self) -> None:
self.chips = self.get_chips()
pass

def get_chips(self) -> GeoDataFrame:
chips = []
def __iter__(self) -> Iterator[BoundingBox]:
for i in range(len(self)):
chips.append(
{
'geometry': box(i, i, i, i),
'minx': i,
'miny': i,
'maxx': i,
'maxy': i,
'mint': i,
'maxt': i,
}
)
return GeoDataFrame(chips, crs='3005')
yield BoundingBox(i, i, i, i, i, i)

def __len__(self) -> int:
return 2
Expand All @@ -58,17 +43,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class CustomGeoDatasetSITS(GeoDataset):
    """Stub GeoDataset that flags itself as a satellite image time series.

    With ``return_as_ts`` set, samplers are expected to return queries that
    span the dataset's full temporal extent rather than single timestamps.
    """

    def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
        super().__init__()
        self._crs = crs
        self.res = res
        # Flag consumed by samplers: keep the whole time dimension per query.
        self.return_as_ts = True

    def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
        # Echo the query back so tests can inspect the sampled bounding box.
        return {'index': query}


class TestGeoSampler:
@pytest.fixture(scope='class')
def dataset(self) -> CustomGeoDataset:
Expand All @@ -80,14 +54,6 @@ def dataset(self) -> CustomGeoDataset:
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()

@pytest.fixture(scope='class')
def datadir(self) -> str:
    """Return the directory that holds the chip-filtering test fixtures."""
    return os.path.join('tests', 'data', 'samplers')

def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None:
    """GeoSampler is abstract: instantiating it without get_chips must fail."""
    with pytest.raises(TypeError):
        GeoSampler(dataset)

def test_iter(self, sampler: CustomGeoSampler) -> None:
    """The first chip yielded is the degenerate bounding box at the origin."""
    expected = BoundingBox(0, 0, 0, 0, 0, 0)
    assert next(iter(sampler)) == expected

Expand All @@ -98,62 +64,6 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.parametrize(
    'filtering_file', ['filtering_4x4', 'filtering_4x4.feather']
)
def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
    """Chips intersecting geometries loaded from a file can be dropped."""
    dataset = CustomGeoDataset()
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    grid_sampler = GridGeoSampler(
        dataset, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
    )
    chip_iter = iter(grid_sampler)

    # A 10x10 extent with 5x5 chips yields a 2x2 grid of four chips.
    assert len(grid_sampler) == 4
    filtering_path = os.path.join(datadir, filtering_file)
    grid_sampler.filter_chips(filtering_path, 'intersects', 'drop')
    # One chip intersects the filter geometry and is removed.
    assert len(grid_sampler) == 3
    assert next(chip_iter) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_filtering_from_gdf(self, datadir: str) -> None:
    """Chips can be filtered with an in-memory GeoDataFrame, drop or keep."""
    dataset = CustomGeoDataset()
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    grid_sampler = GridGeoSampler(
        dataset, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
    )
    chip_iter = iter(grid_sampler)

    # Dropping first chip
    assert len(grid_sampler) == 4
    filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4'))
    grid_sampler.filter_chips(filtering_gdf, 'intersects', 'drop')
    assert len(grid_sampler) == 3
    assert next(chip_iter) == BoundingBox(5, 10, 0, 5, 0, 10)

    # Keeping only first chip
    grid_sampler = GridGeoSampler(dataset, 5, 5, units=Units.CRS)
    chip_iter = iter(grid_sampler)
    grid_sampler.filter_chips(filtering_gdf, 'intersects', 'keep')
    assert len(grid_sampler) == 1
    assert next(chip_iter) == BoundingBox(0, 5, 0, 5, 0, 10)

def test_set_worker_split(self) -> None:
    """Splitting chips across workers leaves each worker an equal share."""
    dataset = CustomGeoDataset()
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    grid_sampler = GridGeoSampler(
        dataset, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
    )
    # Four chips total; one of four workers gets exactly one of them.
    assert len(grid_sampler) == 4
    grid_sampler.set_worker_split(total_workers=4, worker_num=1)
    assert len(grid_sampler) == 1

def test_save_chips(self, tmpdir_factory) -> None:
    """Chips can be saved both without an extension and as .feather."""
    # NOTE(review): tmpdir_factory is pytest's TempdirFactory fixture.
    dataset = CustomGeoDataset()
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    grid_sampler = GridGeoSampler(dataset, 5, 5, units=Units.CRS)
    # Exercise both output formats accepted by GeoSampler.save.
    grid_sampler.save(str(tmpdir_factory.mktemp('out').join('chips')))
    grid_sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather')))

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -229,15 +139,6 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_return_as_ts(self) -> None:
    """Time-series datasets make every random query span the full time range."""
    dataset = CustomGeoDatasetSITS()
    # Two entries with disjoint temporal extents: [0, 10] and [15, 20].
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    dataset.index.insert(1, (0, 10, 0, 10, 15, 20))
    random_sampler = RandomGeoSampler(dataset, 1, 5)
    for query in random_sampler:
        # Each query covers the union of both temporal extents.
        assert query.mint == dataset.bounds.mint == 0
        assert query.maxt == dataset.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand All @@ -255,7 +156,7 @@ class TestGridGeoSampler:
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 500, 600))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
Expand Down Expand Up @@ -296,13 +197,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None:

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride)
length = (
rows * cols * 2
) # two spatially but not temporally overlapping items in dataset
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand Down Expand Up @@ -342,29 +243,6 @@ def test_float_multiple(self) -> None:
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_dataset_has_regex(self) -> None:
    """Named groups in the dataset's filename regex become chip columns."""
    dataset = CustomGeoDataset()
    dataset.filename_regex = r'.*(?P<my_key>test)'
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
    grid_sampler = GridGeoSampler(dataset, 1, 2, units=Units.CRS)
    assert 'my_key' in grid_sampler.chips.columns

def test_dataset_has_regex_no_match(self) -> None:
    """A filename regex that never matches adds no column to the chips."""
    dataset = CustomGeoDataset()
    dataset.filename_regex = r'(?P<my_key>test)'
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key')
    grid_sampler = GridGeoSampler(dataset, 1, 2, units=Units.CRS)
    assert 'my_key' not in grid_sampler.chips.columns

def test_return_as_ts(self) -> None:
    """Time-series datasets make every grid query span the full time range."""
    dataset = CustomGeoDatasetSITS()
    # Two entries with disjoint temporal extents: [0, 10] and [15, 20].
    dataset.index.insert(0, (0, 10, 0, 10, 0, 10))
    dataset.index.insert(1, (0, 10, 0, 10, 15, 20))
    grid_sampler = GridGeoSampler(dataset, 1, 1)
    for query in grid_sampler:
        # Each query covers the union of both temporal extents.
        assert query.mint == dataset.bounds.mint == 0
        assert query.maxt == dataset.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
5 changes: 0 additions & 5 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
#: a different file format than what it was originally downloaded as.
filename_glob = '*'

# Whether to return the dataset as a Timeseries, this will add another dimension to the dataset
return_as_ts = False

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand Down Expand Up @@ -986,7 +983,6 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('IntersectionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res

Expand Down Expand Up @@ -1147,7 +1143,6 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('UnionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res

Expand Down
Loading

0 comments on commit 21552c4

Please sign in to comment.