diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 5742134d14a..5abb7766aa4 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -223,7 +223,7 @@ def test_getitem(self, dataset: CustomVectorDataset) -> None: def test_empty_shapes(self, dataset: CustomVectorDataset) -> None: query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0) x = dataset[query] - assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8)) + assert torch.equal(x["mask"], torch.zeros(8, 8, dtype=torch.uint8)) def test_invalid_query(self, dataset: CustomVectorDataset) -> None: query = BoundingBox(3, 3, 3, 3, 0, 0) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index c74833d812e..0790036865d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -611,12 +611,12 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: ) if shapes: masks = rasterio.features.rasterize( - shapes, out_shape=(int(height), int(width)), transform=transform + shapes, out_shape=(round(height), round(width)), transform=transform ) else: # If no features are found in this query, return an empty mask # with the default fill value and dtype used by rasterize - masks = np.zeros((int(height), int(width)), dtype=np.uint8) + masks = np.zeros((round(height), round(width)), dtype=np.uint8) sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}