Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorDataset: fix issue with empty query #467

Merged
merged 2 commits into from
Mar 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/data/vector/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json

# Create an L shape:
#
# +--+
# | |
# +--+--+
# | | |
# +--+--+
#
# This allows us to test queries:
#
# * within the L
# * within the dataset bounding box but with no features
# * outside the dataset bounding box

geojson = {
"type": "FeatureCollection",
"crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}},
"features": [
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
],
},
},
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]
],
},
},
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]
],
},
},
],
}

with open("vector.geojson", "w") as f:
json.dump(geojson, f)
1 change: 1 addition & 0 deletions tests/data/vector/vector.geojson
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]}
24 changes: 16 additions & 8 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from torchgeo.datasets import (
NAIP,
BoundingBox,
CanadianBuildingFootprints,
GeoDataset,
IntersectionDataset,
RasterDataset,
Expand Down Expand Up @@ -44,6 +43,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
return {"index": query}


class CustomVectorDataset(VectorDataset):
filename_glob = "*.geojson"


class CustomVisionDataset(VisionDataset):
def __getitem__(self, index: int) -> Dict[str, int]:
return {"index": index}
Expand Down Expand Up @@ -201,20 +204,25 @@ def test_plot_with_cmap(self, custom_dtype_ds: RasterDataset) -> None:


class TestVectorDataset:
@pytest.fixture
def dataset(self) -> CanadianBuildingFootprints:
root = os.path.join("tests", "data", "cbf")
@pytest.fixture(scope="class")
def dataset(self) -> CustomVectorDataset:
root = os.path.join("tests", "data", "vector")
transforms = nn.Identity() # type: ignore[no-untyped-call]
return CanadianBuildingFootprints(root, res=0.1, transforms=transforms)
return CustomVectorDataset(root, res=0.1, transforms=transforms)

def test_getitem(self, dataset: CanadianBuildingFootprints) -> None:
def test_getitem(self, dataset: CustomVectorDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)

def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2)
def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0)
x = dataset[query]
assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8))

def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(3, 3, 3, 3, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
Expand Down
11 changes: 8 additions & 3 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,9 +661,14 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
transform = rasterio.transform.from_bounds(
query.minx, query.miny, query.maxx, query.maxy, width, height
)
masks = rasterio.features.rasterize(
shapes, out_shape=(int(height), int(width)), transform=transform
)
if shapes:
masks = rasterio.features.rasterize(
shapes, out_shape=(int(height), int(width)), transform=transform
)
else:
# If no features are found in this query, return an empty mask
# with the default fill value and dtype used by rasterize
masks = np.zeros((int(height), int(width)), dtype=np.uint8)

sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}

Expand Down