From daecc9070926551d0991f04fb238242e9f126be6 Mon Sep 17 00:00:00 2001 From: tritolol <61182488+tritolol@users.noreply.github.com> Date: Thu, 17 Feb 2022 15:25:33 +0100 Subject: [PATCH] Fix forced int32 type conversion in RasterDataset (#384) * fix forced int32 type conversion * add fix for numpy dtypes which are not supported by tensors * delete whitespace * Adding custom data to test the dtype transform * Fixed formatting Co-authored-by: Caleb Robinson --- tests/data/raster/data.py | 42 ++++++++++++++++++++++++++++++++++++ tests/data/raster/test0.tif | Bin 0 -> 453 bytes tests/datasets/test_geo.py | 11 ++++++++++ torchgeo/datasets/geo.py | 7 +++++- 4 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/data/raster/data.py create mode 100644 tests/data/raster/test0.tif diff --git a/tests/data/raster/data.py b/tests/data/raster/data.py new file mode 100644 index 00000000000..c6328b20470 --- /dev/null +++ b/tests/data/raster/data.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import rasterio +import rasterio.transform +from torchvision.datasets.utils import calculate_md5 + + +def generate_test_data(fn: str) -> str: + """Creates test data with uint32 datatype. + + Args: + fn (str): Filename to write + + Returns: + str: md5 hash of created archive + """ + profile = { + "driver": "GTiff", + "dtype": "uint32", + "count": 1, + "crs": "epsg:4326", + "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1), + "height": 4, + "width": 4, + "compress": "lzw", + "predictor": 2, + } + + with rasterio.open(fn, "w", **profile) as f: + f.write(np.random.randint(0, 2**32 - 1, size=(1, 4, 4))) + + md5: str = calculate_md5(fn) + return md5 + + +if __name__ == "__main__": + md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif")) + print(md5_hash) diff --git a/tests/data/raster/test0.tif b/tests/data/raster/test0.tif new file mode 100644 index 0000000000000000000000000000000000000000..83d75bbdc20a51402023702c51bb49fecb38ba9e GIT binary patch literal 453 zcmebD)MDUZU|-pD0L7W1Y*rwf4ax@T5oBZm>#YKEM3KbB zplpzt;!ri-K(-8$8e1rv2`C8C9uWx5&Pyo_OK)W`y z^Duz)B>~wR+nK;VDFU();DbGc4Wkh<_*n?`j9{O!g9C;UXdBQQKoOuNz+hu!cqYKe zv9TQ}%fJTa_bGF7Y-k6`bAxFHhK}5lX|hX-9Dqt3mW4z(hPwwVSeVo>G^oodnh4%K z$85Xj)Jx}wTTQIx8MqW3Hf>P*sXB31mkV>ksXGO%*OU}*yo|UL+TzbJ!FtB4>9QtQ XPlPATI3s3Y5&FJ)h9rxu;{gW%KfXKW literal 0 HcmV?d00001 diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 491b7294574..0cc94966471 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -159,6 +159,11 @@ def sentinel(self, request: SubRequest) -> Sentinel2: cache = request.param return Sentinel2(root, bands=bands, transforms=transforms, cache=cache) + @pytest.fixture() + def custom_dtype_ds(self) -> RasterDataset: + root = os.path.join("tests", "data", "raster") + return RasterDataset(root) + def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] assert isinstance(x, dict) @@ -171,6 +176,12 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None: assert isinstance(x["crs"], CRS) assert isinstance(x["image"], torch.Tensor) + def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None: + x = custom_dtype_ds[custom_dtype_ds.bounds] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert x["image"].dtype == torch.int64 # type: ignore[attr-defined] + def test_invalid_query(self, sentinel: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index a6c0d073199..109bd50daba 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -461,7 +461,12 @@ def _merge_files(self, filepaths: Sequence[str], query: BoundingBox) -> Tensor: ) else: dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res) - dest = dest.astype(np.int32) + + # fix numpy dtypes which are not supported by pytorch tensors + if dest.dtype == np.uint16: + dest = dest.astype(np.int32) + elif dest.dtype == np.uint32: + dest = dest.astype(np.int64) tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined] return tensor