Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "TimeSeries Support" #3

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .vscode/settings.json

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4.feather
Binary file not shown.
1 change: 0 additions & 1 deletion tests/data/samplers/filtering_4x4/filtering_4x4.cpg

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.dbf
Binary file not shown.
1 change: 0 additions & 1 deletion tests/data/samplers/filtering_4x4/filtering_4x4.prj

This file was deleted.

Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.shp
Binary file not shown.
Binary file removed tests/data/samplers/filtering_4x4/filtering_4x4.shx
Binary file not shown.
140 changes: 9 additions & 131 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
# Licensed under the MIT License.

import math
import os
from collections.abc import Iterator
from itertools import product

import geopandas as gpd
import pytest
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
from shapely.geometry import box
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
Expand All @@ -26,23 +23,11 @@

class CustomGeoSampler(GeoSampler):
def __init__(self) -> None:
self.chips = self.get_chips()
pass

def get_chips(self) -> GeoDataFrame:
chips = []
def __iter__(self) -> Iterator[BoundingBox]:
for i in range(len(self)):
chips.append(
{
'geometry': box(i, i, i, i),
'minx': i,
'miny': i,
'maxx': i,
'maxy': i,
'mint': i,
'maxt': i,
}
)
return GeoDataFrame(chips, crs='3005')
yield BoundingBox(i, i, i, i, i, i)

def __len__(self) -> int:
return 2
Expand All @@ -58,17 +43,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class CustomGeoDatasetSITS(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
self.return_as_ts = True

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class TestGeoSampler:
@pytest.fixture(scope='class')
def dataset(self) -> CustomGeoDataset:
Expand All @@ -80,14 +54,6 @@ def dataset(self) -> CustomGeoDataset:
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()

@pytest.fixture(scope='class')
def datadir(self) -> str:
return os.path.join('tests', 'data', 'samplers')

def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError):
GeoSampler(dataset)

def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)

Expand All @@ -98,62 +64,6 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.parametrize(
'filtering_file', ['filtering_4x4', 'filtering_4x4.feather']
)
def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

assert len(sampler) == 4
filtering_path = os.path.join(datadir, filtering_file)
sampler.filter_chips(filtering_path, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_filtering_from_gdf(self, datadir: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

# Dropping first chip
assert len(sampler) == 4
filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4'))
sampler.filter_chips(filtering_gdf, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

# Keeping only first chip
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
sampler.filter_chips(filtering_gdf, 'intersects', 'keep')
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)

def test_set_worker_split(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
assert len(sampler) == 4
sampler.set_worker_split(total_workers=4, worker_num=1)
assert len(sampler) == 1

def test_save_chips(self, tmpdir_factory) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
sampler.save(str(tmpdir_factory.mktemp('out').join('chips')))
sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather')))

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -229,15 +139,6 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = RandomGeoSampler(ds, 1, 5)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand All @@ -255,7 +156,7 @@ class TestGridGeoSampler:
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 500, 600))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
Expand Down Expand Up @@ -296,13 +197,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None:

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride)
length = (
rows * cols * 2
) # two spatially but not temporally overlapping items in dataset
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand Down Expand Up @@ -342,29 +243,6 @@ def test_float_multiple(self) -> None:
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_dataset_has_regex(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'.*(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert 'my_key' in sampler.chips.columns

def test_dataset_has_regex_no_match(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert 'my_key' not in sampler.chips.columns

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = GridGeoSampler(ds, 1, 1)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
5 changes: 0 additions & 5 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
#: a different file format than what it was originally downloaded as.
filename_glob = '*'

# Whether to return the dataset as a Timeseries, this will add another dimension to the dataset
return_as_ts = False

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand Down Expand Up @@ -986,7 +983,6 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('IntersectionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res

Expand Down Expand Up @@ -1147,7 +1143,6 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('UnionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res

Expand Down
Loading