Skip to content

Commit

Permalink
IntersectionDataset: better error message when no overlap (#1192)
Browse files Browse the repository at this point in the history
* IntersectionDataset: better error message when no overlap

* Update split tests

* Document error
  • Loading branch information
adamjstewart authored Mar 29, 2023
1 parent e60c1c4 commit 0a5e1d9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 43 deletions.
21 changes: 16 additions & 5 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,20 @@ def test_nongeo_dataset(self) -> None:
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]

def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(
BoundingBox(
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
),
crs=CRS.from_epsg(32616),
)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
assert len(ds) == 1

def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
Expand All @@ -419,8 +429,9 @@ def test_different_res(self) -> None:
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
msg = "Datasets have no spatiotemporal intersection"
with pytest.raises(RuntimeError, match=msg):
IntersectionDataset(ds1, ds2)

def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
Expand Down
77 changes: 39 additions & 38 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
)


def total_area(dataset: GeoDataset) -> float:
total_area = 0.0
for hit in dataset.index.intersection(dataset.index.bounds, objects=True):
total_area += BoundingBox(*hit.bounds).area

return total_area


def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
try:
ds = ds1 & ds2
except RuntimeError:
return True
else:
return isclose(total_area(ds), 0)


class CustomGeoDataset(GeoDataset):
def __init__(
self,
Expand Down Expand Up @@ -66,11 +83,9 @@ def test_random_bbox_assignment(
assert len(test_ds) == expected_lengths[2]

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
Expand All @@ -93,14 +108,6 @@ def test_random_bbox_assignment_invalid_inputs() -> None:
random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4])


def _get_total_area(dataset: GeoDataset) -> float:
total_area = 0.0
for hit in dataset.index.intersection(dataset.index.bounds, objects=True):
total_area += BoundingBox(*hit.bounds).area

return total_area


def test_random_bbox_splitting() -> None:
ds = CustomGeoDataset(
[
Expand All @@ -111,30 +118,28 @@ def test_random_bbox_splitting() -> None:
]
)

ds_area = _get_total_area(ds)
ds_area = total_area(ds)

train_ds, val_ds, test_ds = random_bbox_splitting(
ds, fractions=[1 / 2, 1 / 4, 1 / 4]
)
train_ds_area = _get_total_area(train_ds)
val_ds_area = _get_total_area(val_ds)
test_ds_area = _get_total_area(test_ds)
train_ds_area = total_area(train_ds)
val_ds_area = total_area(val_ds)
test_ds_area = total_area(test_ds)

# Check datasets areas
assert train_ds_area == ds_area / 2
assert val_ds_area == ds_area / 4
assert test_ds_area == ds_area / 4

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area)
assert isclose(total_area(train_ds | val_ds | test_ds), ds_area)

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -168,15 +173,13 @@ def test_random_grid_cell_assignment() -> None:
assert len(test_ds) == floor(1 / 4 * 2 * 5**2)

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -219,15 +222,13 @@ def test_roi_split() -> None:
assert len(test_ds) == 1

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -273,9 +274,9 @@ def test_time_series_split(
assert len(test_ds) == expected_lengths[2]

# No overlap
assert len(train_ds & val_ds) == 0
assert len(val_ds & test_ds) == 0
assert len(test_ds & train_ds) == 0
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
Expand Down
4 changes: 4 additions & 0 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def __init__(
entry and returns a transformed version
Raises:
RuntimeError: if datasets have no spatiotemporal intersection
ValueError: if either dataset is not a :class:`GeoDataset`
.. versionadded:: 0.4
Expand Down Expand Up @@ -855,6 +856,9 @@ def _merge_dataset_indices(self) -> None:
self.index.insert(i, tuple(box1 & box2))
i += 1

if i == 0:
raise RuntimeError("Datasets have no spatiotemporal intersection")

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Expand Down

0 comments on commit 0a5e1d9

Please sign in to comment.