Skip to content

Commit

Permalink
Fixes VectorDataset rounding bug causing sample mask size mismatch (m…
Browse files Browse the repository at this point in the history
…icrosoft#675)

* Fixes VectorDataset rounding bug causing sample mask size mismatch

* removes unnecessary casting to int
  • Loading branch information
TCherici authored Jul 18, 2022
1 parent 4f82d43 commit 927986a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_getitem(self, dataset: CustomVectorDataset) -> None:
def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0)
x = dataset[query]
assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8))
assert torch.equal(x["mask"], torch.zeros(8, 8, dtype=torch.uint8))

def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(3, 3, 3, 3, 0, 0)
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,12 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
)
if shapes:
masks = rasterio.features.rasterize(
shapes, out_shape=(int(height), int(width)), transform=transform
shapes, out_shape=(round(height), round(width)), transform=transform
)
else:
# If no features are found in this query, return an empty mask
# with the default fill value and dtype used by rasterize
masks = np.zeros((int(height), int(width)), dtype=np.uint8)
masks = np.zeros((round(height), round(width)), dtype=np.uint8)

sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}

Expand Down

0 comments on commit 927986a

Please sign in to comment.