From 101bde6584323f920af77d51dfbc34af22138bec Mon Sep 17 00:00:00 2001
From: snowman2
Date: Thu, 31 Oct 2024 13:35:51 -0500
Subject: [PATCH] BUG:merge: Fix merging masked & scaled data

---
 docs/history.rst                           |  1 +
 rioxarray/merge.py                         | 17 +++++++++++++++--
 test/integration/test_integration_merge.py | 21 ++++++++++++++++++++-
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/docs/history.rst b/docs/history.rst
index 7c458ab7..5aa1959b 100644
--- a/docs/history.rst
+++ b/docs/history.rst
@@ -3,6 +3,7 @@ History
 
 Latest
 ------
+- BUG:merge: Fix merging masked and scaled data (issue #814)
 
 0.17.0
 ------
diff --git a/rioxarray/merge.py b/rioxarray/merge.py
index 397a5185..21c993f6 100644
--- a/rioxarray/merge.py
+++ b/rioxarray/merge.py
@@ -52,8 +52,21 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
         This method is meant to be used by the rasterio.merge.merge function.
         """
         with MemoryFile() as memfile:
-            self._xds.rio.to_raster(memfile.name)
-            with memfile.open() as dataset:
+            with memfile.open(
+                driver="GTiff",
+                height=int(self._xds.rio.height),
+                width=int(self._xds.rio.width),
+                count=self.count,
+                dtype=self.dtypes[0],
+                crs=self.crs,
+                transform=self.transform,
+                nodata=self.nodatavals[0],
+            ) as dataset:
+                data = self._xds.values
+                if data.ndim == 2:
+                    dataset.write(data, 1)
+                else:
+                    dataset.write(data)
                 return dataset.read(*args, **kwargs)
 
 
diff --git a/test/integration/test_integration_merge.py b/test/integration/test_integration_merge.py
index 2ee2a48e..c9e64fdb 100644
--- a/test/integration/test_integration_merge.py
+++ b/test/integration/test_integration_merge.py
@@ -188,7 +188,6 @@ def test_merge_datasets():
         (-4447802.078667, -10007554.677, -3335851.559, -8895604.157333),
     )
     assert merged.rio.shape == (2400, 2400)
-    assert_almost_equal(merged[data_var].sum(), 4539666606551516)
     assert_almost_equal(
         tuple(merged[data_var].rio.transform()),
         (
@@ -208,6 +207,7 @@ def test_merge_datasets():
     assert merged.rio.crs == rds.rio.crs
     assert merged.attrs == rds.attrs
     assert merged.encoding["grid_mapping"] == "spatial_ref"
+    assert_almost_equal(merged[data_var].sum(), 4539666606551516)
 
 
 @pytest.mark.xfail(os.name == "nt", reason="On windows the merged data is different.")
@@ -263,3 +263,22 @@ def test_merge_datasets__res():
     assert merged.attrs == rds.attrs
     assert merged.encoding["grid_mapping"] == "spatial_ref"
     assert_almost_equal(merged[data_var].sum(), 974566547463955)
+
+
+@pytest.mark.parametrize("mask_and_scale", [True, False])
+def test_merge_datasets__mask_and_scale(mask_and_scale):
+    test_file = os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc")
+    with open_rasterio(test_file, mask_and_scale=mask_and_scale) as rds:
+        rds = rds.to_dataset()
+        datasets = [
+            rds.isel(x=slice(100), y=slice(100)),
+            rds.isel(x=slice(100, None), y=slice(100, None)),
+            rds.isel(x=slice(100), y=slice(100, None)),
+            rds.isel(x=slice(100, None), y=slice(100)),
+        ]
+        merged = merge_datasets(datasets)
+    total = merged.air_temperature.sum()
+    if mask_and_scale:
+        assert_almost_equal(total, 133376696)
+    else:
+        assert_almost_equal(total, 10981781386)