Skip to content

Commit

Permalink
Fix forced int32 type conversion in RasterDataset (#384)
Browse files Browse the repository at this point in the history
* fix forced int32 type conversion

* add fix for numpy dtypes which are not supported by tensors

* delete whitespace

* Adding custom data to test the dtype transform

* Fixed formatting

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
tritolol and calebrob6 authored Feb 17, 2022
1 parent 2c6e7ee commit daecc90
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 1 deletion.
42 changes: 42 additions & 0 deletions tests/data/raster/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio
import rasterio.transform
from torchvision.datasets.utils import calculate_md5


def generate_test_data(fn: str) -> str:
"""Creates test data with uint32 datatype.
Args:
fn (str): Filename to write
Returns:
str: md5 hash of created archive
"""
profile = {
"driver": "GTiff",
"dtype": "uint32",
"count": 1,
"crs": "epsg:4326",
"transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
"height": 4,
"width": 4,
"compress": "lzw",
"predictor": 2,
}

with rasterio.open(fn, "w", **profile) as f:
f.write(np.random.randint(0, 2**32 - 1, size=(1, 4, 4)))

md5: str = calculate_md5(fn)
return md5


if __name__ == "__main__":
md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif"))
print(md5_hash)
Binary file added tests/data/raster/test0.tif
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ def sentinel(self, request: SubRequest) -> Sentinel2:
cache = request.param
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)

@pytest.fixture()
def custom_dtype_ds(self) -> RasterDataset:
root = os.path.join("tests", "data", "raster")
return RasterDataset(root)

def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds]
assert isinstance(x, dict)
Expand All @@ -171,6 +176,12 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None:
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)

def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None:
x = custom_dtype_ds[custom_dtype_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].dtype == torch.int64 # type: ignore[attr-defined]

def test_invalid_query(self, sentinel: Sentinel2) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,12 @@ def _merge_files(self, filepaths: Sequence[str], query: BoundingBox) -> Tensor:
)
else:
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
dest = dest.astype(np.int32)

# fix numpy dtypes which are not supported by pytorch tensors
if dest.dtype == np.uint16:
dest = dest.astype(np.int32)
elif dest.dtype == np.uint32:
dest = dest.astype(np.int64)

tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
return tensor
Expand Down

0 comments on commit daecc90

Please sign in to comment.