diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f81b018668f..0790036865d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -611,12 +611,12 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: ) if shapes: masks = rasterio.features.rasterize( - shapes, out_shape=(int(round(height)), int(round(width))), transform=transform + shapes, out_shape=(round(height), round(width)), transform=transform ) else: # If no features are found in this query, return an empty mask # with the default fill value and dtype used by rasterize - masks = np.zeros((int(round(height)), int(round(width))), dtype=np.uint8) + masks = np.zeros((round(height), round(width)), dtype=np.uint8) sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}