Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IntersectionDataset: better error message when no overlap #1192

Merged
merged 3 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,20 @@ def test_nongeo_dataset(self) -> None:
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]

def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(
BoundingBox(
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
),
crs=CRS.from_epsg(32616),
)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
assert len(ds) == 1

def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
Expand All @@ -419,8 +429,9 @@ def test_different_res(self) -> None:
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
msg = "Datasets have no spatiotemporal intersection"
with pytest.raises(RuntimeError, match=msg):
IntersectionDataset(ds1, ds2)

def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
Expand Down
77 changes: 39 additions & 38 deletions tests/datasets/test_splits.py
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
)


def total_area(dataset: GeoDataset) -> float:
    """Add up the areas of all bounding boxes found in a dataset's index.

    The index is queried with its own overall bounds, and each returned hit's
    bounds are converted to a ``BoundingBox`` whose ``area`` is accumulated.

    Args:
        dataset: dataset whose ``index`` is inspected

    Returns:
        the accumulated area (``0.0`` if the query yields no hits)
    """
    hits = dataset.index.intersection(dataset.index.bounds, objects=True)

    # Accumulate in a plain loop (rather than ``sum``) so the floating-point
    # additions happen in exactly the order the hits are yielded.
    area_sum = 0.0
    for item in hits:
        box = BoundingBox(*item.bounds)
        area_sum += box.area

    return area_sum


def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
    """Report whether two datasets share no area.

    Args:
        ds1: first dataset
        ds2: second dataset

    Returns:
        True if ``ds1 & ds2`` raises a ``RuntimeError``, otherwise whether the
        total area of the resulting dataset is close to zero
    """
    try:
        intersection = ds1 & ds2
    except RuntimeError:
        # Building the intersection failed outright: treat as "no overlap".
        return True

    # NOTE(review): if ``isclose`` is ``math.isclose``, the default
    # ``abs_tol=0`` makes a comparison against 0 an exact equality check --
    # confirm that is the intended strictness.
    return isclose(total_area(intersection), 0)


class CustomGeoDataset(GeoDataset):
def __init__(
self,
Expand Down Expand Up @@ -66,11 +83,9 @@ def test_random_bbox_assignment(
assert len(test_ds) == expected_lengths[2]

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
Expand All @@ -93,14 +108,6 @@ def test_random_bbox_assignment_invalid_inputs() -> None:
random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4])


def _get_total_area(dataset: GeoDataset) -> float:
    """Compute the combined area of every bounding box in a dataset's index.

    Args:
        dataset: dataset whose ``index`` is queried over its own full bounds

    Returns:
        the running total of each hit's ``BoundingBox`` area, starting at 0.0
    """
    accumulated = 0.0
    matches = dataset.index.intersection(dataset.index.bounds, objects=True)
    for match in matches:
        # Each hit exposes its extent as ``bounds``; unpack it into a box.
        accumulated = accumulated + BoundingBox(*match.bounds).area

    return accumulated


def test_random_bbox_splitting() -> None:
ds = CustomGeoDataset(
[
Expand All @@ -111,30 +118,28 @@ def test_random_bbox_splitting() -> None:
]
)

ds_area = _get_total_area(ds)
ds_area = total_area(ds)

train_ds, val_ds, test_ds = random_bbox_splitting(
ds, fractions=[1 / 2, 1 / 4, 1 / 4]
)
train_ds_area = _get_total_area(train_ds)
val_ds_area = _get_total_area(val_ds)
test_ds_area = _get_total_area(test_ds)
train_ds_area = total_area(train_ds)
val_ds_area = total_area(val_ds)
test_ds_area = total_area(test_ds)

# Check datasets areas
assert train_ds_area == ds_area / 2
assert val_ds_area == ds_area / 4
assert test_ds_area == ds_area / 4

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area)
assert isclose(total_area(train_ds | val_ds | test_ds), ds_area)

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -168,15 +173,13 @@ def test_random_grid_cell_assignment() -> None:
assert len(test_ds) == floor(1 / 4 * 2 * 5**2)

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -219,15 +222,13 @@ def test_roi_split() -> None:
assert len(test_ds) == 1

# No overlap
assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
assert len(test_ds & train_ds) == 0 or isclose(
_get_total_area(test_ds & train_ds), 0
)
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
Expand Down Expand Up @@ -273,9 +274,9 @@ def test_time_series_split(
assert len(test_ds) == expected_lengths[2]

# No overlap
assert len(train_ds & val_ds) == 0
assert len(val_ds & test_ds) == 0
assert len(test_ds & train_ds) == 0
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
Expand Down
4 changes: 4 additions & 0 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def __init__(
entry and returns a transformed version

Raises:
RuntimeError: if datasets have no spatiotemporal intersection
ValueError: if either dataset is not a :class:`GeoDataset`

.. versionadded:: 0.4
Expand Down Expand Up @@ -855,6 +856,9 @@ def _merge_dataset_indices(self) -> None:
self.index.insert(i, tuple(box1 & box2))
i += 1

if i == 0:
raise RuntimeError("Datasets have no spatiotemporal intersection")

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.

Expand Down