From 5496d23e81e905ae70779ec7b0a4349e3f009a65 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Wed, 13 Nov 2024 10:02:50 +0100 Subject: [PATCH 1/4] Xarray: add indexes options and better define band names --- CHANGES.md | 70 ++++++++++++++++++ rio_tiler/io/xarray.py | 89 +++++++++++++++++------ tests/test_io_xarray.py | 155 ++++++++++++++++++++++++++++++++-------- 3 files changed, 261 insertions(+), 53 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index ea253ff6..a318bc4b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,73 @@ +# 7.3.0 (TBD) + +* add `indexes` parameter for `XarrayReader` methods. As for Rasterio, the indexes values start at `1`. + + ```python + data = ... # DataArray of shape (2, x, y) + + # before + with XarrayReader(data) as dst: + img = dst.tile(0, 0, 0) + assert img.count == 2 + + # now + with XarrayReader(data) as dst: + # Select the first `band` within the data array + img = dst.tile(0, 0, 0, indexes=1) + assert img.count == 1 + ``` + +* better define `band names` for `XarrayReader` objects + + * band_name for `2D` dataset is extracted form the first `non-geo` coordinates value + + ```python + data = xarray.DataArray( + numpy.arange(0.0, 33 * 35 * 2).reshape(2, 33, 35), + dims=("time", "y", "x"), + coords={ + "x": numpy.arange(-170, 180, 10), + "y": numpy.arange(-80, 85, 5), + "time": [datetime(2022, 1, 1), datetime(2022, 1, 2)], + }, + ) + da = data[0] + + print(da.coords["time"].data) + >> array('2022-01-01T00:00:00.000000000', dtype='datetime64[ns]')) + + # before + with XarrayReader(data) as dst: + img = dst.info() + print(img.band_descriptions)[0] + >> ("b1", "value") + + # now + with XarrayReader(data) as dst: + img = dst.info() + print(img.band_descriptions)[0] + >> ("b1", "2022-01-01T00:00:00.000000000") + ``` + + * default `band_names` is changed to DataArray's name or `array` (when no available coordinates value) + + ```python + data = ... # DataArray of shape (x, y) + + # before + with XarrayReader(data) as dst: + img = dst.info() + print(img.band_descriptions)[0] + >> ("b1", "value") + + # now + with XarrayReader(data) as dst: + img = dst.info() + print(img.band_descriptions)[0] + >> ("b1", "array") + ``` + + # 7.2.0 (2024-11-05) * Ensure compatibility between XarrayReader and other Readers by adding `**kwargs` on class methods (https://github.com/cogeotiff/rio-tiler/pull/762) diff --git a/rio_tiler/io/xarray.py b/rio_tiler/io/xarray.py index d5ae0369..095d45e9 100644 --- a/rio_tiler/io/xarray.py +++ b/rio_tiler/io/xarray.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import attr import numpy @@ -28,8 +28,13 @@ from rio_tiler.io.base import BaseReader from rio_tiler.models import BandStatistics, ImageData, Info, PointData from rio_tiler.reader import _get_width_height -from rio_tiler.types import BBox, NoData, RIOResampling, WarpResampling -from rio_tiler.utils import CRS_to_uri, _validate_shape_input, get_array_statistics +from rio_tiler.types import BBox, Indexes, NoData, RIOResampling, WarpResampling +from rio_tiler.utils import ( + CRS_to_uri, + _validate_shape_input, + cast_to_sequence, + get_array_statistics, +) try: import xarray @@ -105,6 +110,7 @@ def __attrs_post_init__(self): for d in self.input.dims if d not in [self.input.rio.x_dim, self.input.rio.y_dim] ] + assert len(self._dims) in [0, 1], "Can't handle >=4D DataArray" @property def minzoom(self): @@ -118,29 +124,34 @@ def maxzoom(self): @property def band_names(self) -> List[str]: - """Return list of `band names` in DataArray.""" - return [str(band) for d in self._dims for band in self.input[d].values] or [ - "value" - ] + """Return list of `band descriptions` in DataArray.""" + if not self._dims: + coords_name = list(self.input.coords) + if len(coords_name) > 3 and (coord := coords_name[2]): + return [str(self.input.coords[coord].data)] + + return [self.input.name or "array"] + + return [str(band) for d in self._dims for band in self.input[d].values] def info(self) -> Info: """Return xarray.DataArray info.""" - bands = [str(band) for d in self._dims for band in self.input[d].values] or [ - "value" - ] metadata = [band.attrs for d in self._dims for band in self.input[d]] or [{}] meta = { "bounds": self.bounds, "crs": CRS_to_uri(self.crs) or self.crs.to_wkt(), "band_metadata": [(f"b{ix}", v) for ix, v in enumerate(metadata, 1)], - "band_descriptions": [(f"b{ix}", v) for ix, v in enumerate(bands, 1)], + "band_descriptions": [ + (f"b{ix}", v) for ix, v in enumerate(self.band_names, 1) + ], "dtype": str(self.input.dtype), "nodata_type": "Nodata" if self.input.rio.nodata is not None else "None", "name": self.input.name, "count": self.input.rio.count, "width": self.input.rio.width, "height": self.input.rio.height, + "dimensions": self.input.dims, "attrs": { k: (v.tolist() if isinstance(v, (numpy.ndarray, numpy.generic)) else v) for k, v in self.input.attrs.items() @@ -149,6 +160,28 @@ def info(self) -> Info: return Info(**meta) + def _sel_indexes( + self, indexes: Optional[Indexes] = None + ) -> Tuple[xarray.DataArray, List[str]]: + """Select `band` indexes in DataArray.""" + ds = self.input + band_names = self.band_names + if indexes := cast_to_sequence(indexes): + assert all(v > 0 for v in indexes), "Indexes value must be >= 1" + if ds.ndim == 2: + if indexes != (1,): + raise ValueError( + f"Invalid indexes {indexes} for array of shape {ds.shape}" + ) + + return ds, band_names + + indexes = [idx - 1 for idx in indexes] + ds = ds[indexes] + band_names = [self.band_names[idx] for idx in indexes] + + return ds, band_names + def statistics( self, categorical: bool = False, @@ -156,12 +189,14 @@ def statistics( percentiles: Optional[List[int]] = None, hist_options: Optional[Dict] = None, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, **kwargs: Any, ) -> Dict[str, BandStatistics]: """Return statistics from a dataset.""" hist_options = hist_options or {} - ds = self.input + ds, band_names = self._sel_indexes(indexes) + if nodata is not None: ds = ds.rio.write_nodata(nodata) @@ -176,9 +211,7 @@ def statistics( **hist_options, ) - return { - self.band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats) - } + return {band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats)} def tile( self, @@ -189,6 +222,7 @@ def tile( reproject_method: WarpResampling = "nearest", auto_expand: bool = True, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, **kwargs: Any, ) -> ImageData: """Read a Web Map tile from a dataset. @@ -211,7 +245,8 @@ def tile( f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds" ) - ds = self.input + ds, band_names = self._sel_indexes(indexes) + if nodata is not None: ds = ds.rio.write_nodata(nodata) @@ -251,7 +286,7 @@ def tile( bounds=tile_bounds, crs=dst_crs, dataset_statistics=stats, - band_names=self.band_names, + band_names=band_names, ) def part( @@ -262,6 +297,7 @@ def part( reproject_method: WarpResampling = "nearest", auto_expand: bool = True, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, max_size: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, @@ -294,7 +330,8 @@ def part( dst_crs = dst_crs or bounds_crs - ds = self.input + ds, band_names = self._sel_indexes(indexes) + if nodata is not None: ds = ds.rio.write_nodata(nodata) @@ -339,7 +376,7 @@ def part( bounds=ds.rio.bounds(), crs=ds.rio.crs, dataset_statistics=stats, - band_names=self.band_names, + band_names=band_names, ) output_height = height or img.height @@ -362,6 +399,7 @@ def preview( height: Optional[int] = None, width: Optional[int] = None, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, dst_crs: Optional[CRS] = None, reproject_method: WarpResampling = "nearest", resampling_method: RIOResampling = "nearest", @@ -388,7 +426,8 @@ def preview( UserWarning, ) - ds = self.input + ds, band_names = self._sel_indexes(indexes) + if nodata is not None: ds = ds.rio.write_nodata(nodata) @@ -427,7 +466,7 @@ def preview( bounds=ds.rio.bounds(), crs=ds.rio.crs, dataset_statistics=stats, - band_names=self.band_names, + band_names=band_names, ) output_height = height or img.height @@ -450,6 +489,7 @@ def point( lat: float, coord_crs: CRS = WGS84_CRS, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, **kwargs: Any, ) -> PointData: """Read a pixel value from a dataset. @@ -472,7 +512,8 @@ def point( ): raise PointOutsideBounds("Point is outside dataset bounds") - ds = self.input + ds, band_names = self._sel_indexes(indexes) + if nodata is not None: ds = ds.rio.write_nodata(nodata) @@ -489,7 +530,7 @@ def point( arr, coordinates=(lon, lat), crs=coord_crs, - band_names=self.band_names, + band_names=band_names, ) def feature( @@ -500,6 +541,7 @@ def feature( reproject_method: WarpResampling = "nearest", auto_expand: bool = True, nodata: Optional[NoData] = None, + indexes: Optional[Indexes] = None, max_size: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, @@ -537,6 +579,7 @@ def feature( dst_crs=dst_crs, bounds_crs=shape_crs, nodata=nodata, + indexes=indexes, max_size=max_size, width=width, height=height, diff --git a/tests/test_io_xarray.py b/tests/test_io_xarray.py index b7652f19..0a0d6f74 100644 --- a/tests/test_io_xarray.py +++ b/tests/test_io_xarray.py @@ -15,14 +15,14 @@ def test_xarray_reader(): """test XarrayReader.""" - arr = numpy.arange(0.0, 33 * 35).reshape(1, 33, 35) + arr = numpy.arange(0.0, 33 * 35 * 2).reshape(2, 33, 35) data = xarray.DataArray( arr, dims=("time", "y", "x"), coords={ "x": numpy.arange(-170, 180, 10), "y": numpy.arange(-80, 85, 5), - "time": [datetime(2022, 1, 1)], + "time": [datetime(2022, 1, 1), datetime(2022, 1, 2)], }, ) data.attrs.update({"valid_min": arr.min(), "valid_max": arr.max()}) @@ -34,23 +34,50 @@ def test_xarray_reader(): assert info.bounds == dst.bounds crs = info.crs assert rioCRS.from_user_input(crs) == dst.crs - assert info.band_metadata == [("b1", {})] - assert info.band_descriptions == [("b1", "2022-01-01T00:00:00.000000000")] + assert info.band_metadata == [("b1", {}), ("b2", {})] + assert info.band_descriptions == [ + ("b1", "2022-01-01T00:00:00.000000000"), + ("b2", "2022-01-02T00:00:00.000000000"), + ] assert info.height == 33 assert info.width == 35 - assert info.count == 1 + assert info.count == 2 assert info.attrs stats = dst.statistics() - assert stats["2022-01-01T00:00:00.000000000"] + assert list(stats) == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] assert stats["2022-01-01T00:00:00.000000000"].min == 0.0 + stats = dst.statistics(indexes=1) + assert list(stats) == ["2022-01-01T00:00:00.000000000"] + + stats = dst.statistics(indexes=2) + assert list(stats) == ["2022-01-02T00:00:00.000000000"] + + stats = dst.statistics(indexes=(1, 2)) + assert list(stats) == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] + + with pytest.raises(AssertionError): + stats = dst.statistics(indexes=(0,)) + + with pytest.raises(AssertionError): + stats = dst.statistics(indexes=0) + img = dst.tile(0, 0, 0) - assert img.count == 1 + assert img.count == 2 assert img.width == 256 assert img.height == 256 - assert img.band_names == ["2022-01-01T00:00:00.000000000"] - assert img.dataset_statistics == ((arr.min(), arr.max()),) + assert img.band_names == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] + assert img.dataset_statistics == ((arr.min(), arr.max()), (arr.min(), arr.max())) # Tests for auto_expand # Test that a high-zoom tile will error with auto_expand=False @@ -64,59 +91,86 @@ def test_xarray_reader(): assert "At least one of the clipped raster x,y coordinates" in str(error.value) # Test that a high-zoom tile will succeed with auto_expand=True (and that is the default) - img = dst.tile(tile.x, tile.y, zoom) + img = dst.tile(tile.x, tile.y, zoom, indexes=1) assert img.count == 1 assert img.width == 256 assert img.height == 256 assert img.bounds == bounds assert img.dataset_statistics == ((arr.min(), arr.max()),) - img = dst.part((-160, -80, 160, 80)) + img = dst.part((-160, -80, 160, 80), indexes=1) assert img.crs == "epsg:4326" assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.array.shape == (1, 33, 33) - img = dst.part((-160, -80, 160, 80), dst_crs="epsg:3857") + img = dst.part((-160, -80, 160, 80), dst_crs="epsg:3857", indexes=1) assert img.crs == "epsg:3857" assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.array.shape == (1, 32, 34) - img = dst.part((-160, -80, 160, 80), max_size=15) + img = dst.part((-160, -80, 160, 80), max_size=15, indexes=1) assert img.array.shape == (1, 15, 15) - img = dst.part((-160, -80, 160, 80), width=40, height=35) + img = dst.part((-160, -80, 160, 80), width=40, height=35, indexes=1) assert img.array.shape == (1, 35, 40) - img = dst.part((-160, -80, 160, 80), max_size=15, resampling_method="bilinear") + img = dst.part( + (-160, -80, 160, 80), max_size=15, resampling_method="bilinear", indexes=1 + ) assert img.array.shape == (1, 15, 15) img = dst.preview() assert img.crs == "epsg:4326" + assert img.count == 2 + assert img.band_names == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] + assert img.array.shape == (2, 33, 35) + + img = dst.preview(indexes=1) + assert img.crs == "epsg:4326" assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.array.shape == (1, 33, 35) - img = dst.preview(dst_crs="epsg:3857") + img = dst.preview(dst_crs="epsg:3857", indexes=1) assert img.crs == "epsg:3857" assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.array.shape == (1, 32, 36) - img = dst.preview(max_size=None) + img = dst.preview(max_size=None, indexes=1) assert img.array.shape == (1, 33, 35) - img = dst.preview(max_size=15) + img = dst.preview(max_size=15, indexes=1) assert img.array.shape == (1, 15, 15) - img = dst.preview(max_size=15, resampling_method="bilinear") + img = dst.preview(max_size=15, resampling_method="bilinear", indexes=1) assert img.array.shape == (1, 15, 15) - img = dst.preview(height=25, width=25, max_size=None) + img = dst.preview(height=25, width=25, max_size=None, indexes=1) assert img.array.shape == (1, 25, 25) pt = dst.point(0, 0) + assert pt.count == 2 + assert pt.band_names == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] + assert pt.coordinates + xys = [[0, 2.499], [0, 2.501], [-4.999, 0], [-5.001, 0], [-170, 80]] + for xy in xys: + x = xy[0] + y = xy[1] + pt = dst.point(x, y) + numpy.testing.assert_array_equal( + pt.data, data.sel(x=x, y=y, method="nearest").to_numpy() + ) + + pt = dst.point(0, 0, indexes=1) assert pt.count == 1 assert pt.band_names == ["2022-01-01T00:00:00.000000000"] assert pt.coordinates @@ -125,7 +179,9 @@ def test_xarray_reader(): x = xy[0] y = xy[1] pt = dst.point(x, y) - assert pt.data[0] == data.sel(x=x, y=y, method="nearest") + assert pt.data[0] == data.sel( + time="2022-01-01T00:00:00.000000000", x=x, y=y, method="nearest" + ) feat = { "type": "Feature", @@ -148,22 +204,49 @@ def test_xarray_reader(): }, } img = dst.feature(feat) + assert img.count == 2 + assert img.band_names == [ + "2022-01-01T00:00:00.000000000", + "2022-01-02T00:00:00.000000000", + ] + assert img.array.shape == (2, 25, 32) + + img = dst.feature(feat, indexes=1) assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.array.shape == (1, 25, 32) - img = dst.feature(feat, dst_crs="epsg:3857") + img = dst.feature(feat, dst_crs="epsg:3857", indexes=1) assert img.count == 1 assert img.band_names == ["2022-01-01T00:00:00.000000000"] assert img.crs == "epsg:3857" assert img.array.shape == (1, 20, 35) - img = dst.feature(feat, max_size=15) + img = dst.feature(feat, max_size=15, indexes=1) assert img.array.shape == (1, 12, 15) - img = dst.feature(feat, width=50, height=45) + img = dst.feature(feat, width=50, height=45, indexes=1) assert img.array.shape == (1, 45, 50) + # Select the first value + da = data[0] + assert da.ndim == 2 + with XarrayReader(da) as dst: + assert dst.band_names == ["2022-01-01T00:00:00.000000000"] + info = dst.info() + assert info.band_descriptions == [("b1", "2022-01-01T00:00:00.000000000")] + + stats = dst.statistics() + assert stats["2022-01-01T00:00:00.000000000"] + assert stats["2022-01-01T00:00:00.000000000"].min == 0.0 + + img = dst.tile(0, 0, 0) + assert img.count == 1 + assert img.width == 256 + assert img.height == 256 + assert img.band_names == ["2022-01-01T00:00:00.000000000"] + assert img.dataset_statistics == ((arr.min(), arr.max()),) + arr = numpy.zeros((1, 1000, 2000)) data = xarray.DataArray( arr, @@ -506,33 +589,45 @@ def test_xarray_reader_no_dims(): crs = info.crs assert rioCRS.from_user_input(crs) == dst.crs assert info.band_metadata == [("b1", {})] - assert info.band_descriptions == [("b1", "value")] + assert info.band_descriptions == [("b1", "array")] assert info.height == 33 assert info.width == 35 assert info.count == 1 assert info.attrs stats = dst.statistics() - assert stats["value"] - assert stats["value"].min == 0.0 + assert stats["array"] + assert stats["array"].min == 0.0 + + stats = dst.statistics(indexes=1) + assert stats["array"] + + stats = dst.statistics(indexes=(1,)) + assert stats["array"] + + with pytest.raises(ValueError): + stats = dst.statistics(indexes=2) + + with pytest.raises(ValueError): + stats = dst.statistics(indexes=(1, 2)) img = dst.tile(0, 0, 0) assert img.count == 1 assert img.width == 256 assert img.height == 256 - assert img.band_names == ["value"] + assert img.band_names == ["array"] assert img.dataset_statistics == ((arr.min(), arr.max()),) img = dst.part((-160, -80, 160, 80)) assert img.count == 1 assert img.width == 33 assert img.height == 33 - assert img.band_names == ["value"] + assert img.band_names == ["array"] assert img.dataset_statistics == ((arr.min(), arr.max()),) pt = dst.point(0, 0) assert pt.count == 1 - assert pt.band_names == ["value"] + assert pt.band_names == ["array"] assert pt.coordinates xys = [[0, 2.499], [0, 2.501], [-4.999, 0], [-5.001, 0], [-170, 80]] for xy in xys: From 4217434ddb2c257b8496fc2ef64a07b36da668eb Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Wed, 13 Nov 2024 10:07:32 +0100 Subject: [PATCH 2/4] fix --- tests/test_io_stac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_io_stac.py b/tests/test_io_stac.py index 7932dfcb..f99cea00 100644 --- a/tests/test_io_stac.py +++ b/tests/test_io_stac.py @@ -1060,7 +1060,7 @@ def _get_reader(self, asset_info: AssetInfo) -> Tuple[Type[BaseReader], Dict]: assert info["netcdf"].crs img = stac.preview(assets=["netcdf"]) - assert img.band_names == ["netcdf_value"] + assert img.band_names == ["netcdf_dataset"] @patch("rio_tiler.io.stac.STAC_ALTERNATE_KEY", "s3") From a75574c2efc88285d894ddce8f97aa7f54aaa1ed Mon Sep 17 00:00:00 2001 From: Vincent Sarago Date: Tue, 26 Nov 2024 14:00:09 +0100 Subject: [PATCH 3/4] Update rio_tiler/io/xarray.py Co-authored-by: Max Jones <14077947+maxrjones@users.noreply.github.com> --- rio_tiler/io/xarray.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rio_tiler/io/xarray.py b/rio_tiler/io/xarray.py index 095d45e9..4f3be31a 100644 --- a/rio_tiler/io/xarray.py +++ b/rio_tiler/io/xarray.py @@ -124,7 +124,11 @@ def maxzoom(self): @property def band_names(self) -> List[str]: - """Return list of `band descriptions` in DataArray.""" + """ + Return list of `band descriptions` in DataArray. + + `Bands` are all dimensions not defined as spatial dims by rioxarray. + """ if not self._dims: coords_name = list(self.input.coords) if len(coords_name) > 3 and (coord := coords_name[2]): From e78e30cf62781cc62eccef8ae3442096033a06f5 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 26 Nov 2024 14:05:47 +0100 Subject: [PATCH 4/4] ds -> da --- rio_tiler/io/xarray.py | 122 ++++++++++++++++++++--------------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/rio_tiler/io/xarray.py b/rio_tiler/io/xarray.py index 4f3be31a..1fb999da 100644 --- a/rio_tiler/io/xarray.py +++ b/rio_tiler/io/xarray.py @@ -126,8 +126,8 @@ def maxzoom(self): def band_names(self) -> List[str]: """ Return list of `band descriptions` in DataArray. - - `Bands` are all dimensions not defined as spatial dims by rioxarray. + + `Bands` are all dimensions not defined as spatial dims by rioxarray. """ if not self._dims: coords_name = list(self.input.coords) @@ -168,23 +168,23 @@ def _sel_indexes( self, indexes: Optional[Indexes] = None ) -> Tuple[xarray.DataArray, List[str]]: """Select `band` indexes in DataArray.""" - ds = self.input + da = self.input band_names = self.band_names if indexes := cast_to_sequence(indexes): assert all(v > 0 for v in indexes), "Indexes value must be >= 1" - if ds.ndim == 2: + if da.ndim == 2: if indexes != (1,): raise ValueError( - f"Invalid indexes {indexes} for array of shape {ds.shape}" + f"Invalid indexes {indexes} for array of shape {da.shape}" ) - return ds, band_names + return da, band_names indexes = [idx - 1 for idx in indexes] - ds = ds[indexes] + da = da[indexes] band_names = [self.band_names[idx] for idx in indexes] - return ds, band_names + return da, band_names def statistics( self, @@ -199,13 +199,13 @@ def statistics( """Return statistics from a dataset.""" hist_options = hist_options or {} - ds, band_names = self._sel_indexes(indexes) + da, band_names = self._sel_indexes(indexes) if nodata is not None: - ds = ds.rio.write_nodata(nodata) + da = da.rio.write_nodata(nodata) - data = ds.to_masked_array() - data.mask |= data.data == ds.rio.nodata + data = da.to_masked_array() + data.mask |= data.data == da.rio.nodata stats = get_array_statistics( data, @@ -249,21 +249,21 @@ def tile( f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds" ) - ds, band_names = self._sel_indexes(indexes) + da, band_names = self._sel_indexes(indexes) if nodata is not None: - ds = ds.rio.write_nodata(nodata) + da = da.rio.write_nodata(nodata) tile_bounds = tuple(self.tms.xy_bounds(Tile(x=tile_x, y=tile_y, z=tile_z))) dst_crs = self.tms.rasterio_crs # Create source array by clipping the xarray dataset to extent of the tile. - ds = ds.rio.clip_box( + da = da.rio.clip_box( *tile_bounds, crs=dst_crs, auto_expand=auto_expand, ) - ds = ds.rio.reproject( + da = da.rio.reproject( dst_crs, shape=(tilesize, tilesize), transform=from_bounds(*tile_bounds, height=tilesize, width=tilesize), @@ -272,16 +272,16 @@ def tile( ) # Forward valid_min/valid_max to the ImageData object - minv, maxv = ds.attrs.get("valid_min"), ds.attrs.get("valid_max") + minv, maxv = da.attrs.get("valid_min"), da.attrs.get("valid_max") stats = None if minv is not None and maxv is not None and nodata not in [minv, maxv]: - stats = ((minv, maxv),) * ds.rio.count + stats = ((minv, maxv),) * da.rio.count - arr = ds.to_masked_array() - arr.mask |= arr.data == ds.rio.nodata + arr = da.to_masked_array() + arr.mask |= arr.data == da.rio.nodata - output_bounds = ds.rio._unordered_bounds() - if output_bounds[1] > output_bounds[3] and ds.rio.transform().e > 0: + output_bounds = da.rio._unordered_bounds() + if output_bounds[1] > output_bounds[3] and da.rio.transform().e > 0: yaxis = self.input.dims.index(self.input.rio.y_dim) arr = numpy.flip(arr, axis=yaxis) @@ -334,12 +334,12 @@ def part( dst_crs = dst_crs or bounds_crs - ds, band_names = self._sel_indexes(indexes) + da, band_names = self._sel_indexes(indexes) if nodata is not None: - ds = ds.rio.write_nodata(nodata) + da = da.rio.write_nodata(nodata) - ds = ds.rio.clip_box( + da = da.rio.clip_box( *bbox, crs=bounds_crs, auto_expand=auto_expand, @@ -349,11 +349,11 @@ def part( dst_transform, w, h = calculate_default_transform( self.crs, dst_crs, - ds.rio.width, - ds.rio.height, - *ds.rio.bounds(), + da.rio.width, + da.rio.height, + *da.rio.bounds(), ) - ds = ds.rio.reproject( + da = da.rio.reproject( dst_crs, shape=(h, w), transform=dst_transform, @@ -362,23 +362,23 @@ def part( ) # Forward valid_min/valid_max to the ImageData object - minv, maxv = ds.attrs.get("valid_min"), ds.attrs.get("valid_max") + minv, maxv = da.attrs.get("valid_min"), da.attrs.get("valid_max") stats = None if minv is not None and maxv is not None: - stats = ((minv, maxv),) * ds.rio.count + stats = ((minv, maxv),) * da.rio.count - arr = ds.to_masked_array() - arr.mask |= arr.data == ds.rio.nodata + arr = da.to_masked_array() + arr.mask |= arr.data == da.rio.nodata - output_bounds = ds.rio._unordered_bounds() - if output_bounds[1] > output_bounds[3] and ds.rio.transform().e > 0: + output_bounds = da.rio._unordered_bounds() + if output_bounds[1] > output_bounds[3] and da.rio.transform().e > 0: yaxis = self.input.dims.index(self.input.rio.y_dim) arr = numpy.flip(arr, axis=yaxis) img = ImageData( arr, - bounds=ds.rio.bounds(), - crs=ds.rio.crs, + bounds=da.rio.bounds(), + crs=da.rio.crs, dataset_statistics=stats, band_names=band_names, ) @@ -430,20 +430,20 @@ def preview( UserWarning, ) - ds, band_names = self._sel_indexes(indexes) + da, band_names = self._sel_indexes(indexes) if nodata is not None: - ds = ds.rio.write_nodata(nodata) + da = da.rio.write_nodata(nodata) if dst_crs and dst_crs != self.crs: dst_transform, w, h = calculate_default_transform( self.crs, dst_crs, - ds.rio.width, - ds.rio.height, - *ds.rio.bounds(), + da.rio.width, + da.rio.height, + *da.rio.bounds(), ) - ds = ds.rio.reproject( + da = da.rio.reproject( dst_crs, shape=(h, w), transform=dst_transform, @@ -452,23 +452,23 @@ def preview( ) # Forward valid_min/valid_max to the ImageData object - minv, maxv = ds.attrs.get("valid_min"), ds.attrs.get("valid_max") + minv, maxv = da.attrs.get("valid_min"), da.attrs.get("valid_max") stats = None if minv is not None and maxv is not None: - stats = ((minv, maxv),) * ds.rio.count + stats = ((minv, maxv),) * da.rio.count - arr = ds.to_masked_array() - arr.mask |= arr.data == ds.rio.nodata + arr = da.to_masked_array() + arr.mask |= arr.data == da.rio.nodata - output_bounds = ds.rio._unordered_bounds() - if output_bounds[1] > output_bounds[3] and ds.rio.transform().e > 0: + output_bounds = da.rio._unordered_bounds() + if output_bounds[1] > output_bounds[3] and da.rio.transform().e > 0: yaxis = self.input.dims.index(self.input.rio.y_dim) arr = numpy.flip(arr, axis=yaxis) img = ImageData( arr, - bounds=ds.rio.bounds(), - crs=ds.rio.crs, + bounds=da.rio.bounds(), + crs=da.rio.crs, dataset_statistics=stats, band_names=band_names, ) @@ -508,27 +508,27 @@ def point( PointData """ - ds_lon, ds_lat = transform_coords(coord_crs, self.crs, [lon], [lat]) + da_lon, da_lat = transform_coords(coord_crs, self.crs, [lon], [lat]) if not ( - (self.bounds[0] < ds_lon[0] < self.bounds[2]) - and (self.bounds[1] < ds_lat[0] < self.bounds[3]) + (self.bounds[0] < da_lon[0] < self.bounds[2]) + and (self.bounds[1] < da_lat[0] < self.bounds[3]) ): raise PointOutsideBounds("Point is outside dataset bounds") - ds, band_names = self._sel_indexes(indexes) + da, band_names = self._sel_indexes(indexes) if nodata is not None: - ds = ds.rio.write_nodata(nodata) + da = da.rio.write_nodata(nodata) - y, x = rowcol(ds.rio.transform(), ds_lon, ds_lat) + y, x = rowcol(da.rio.transform(), da_lon, da_lat) - if ds.ndim == 2: - arr = numpy.expand_dims(ds[int(y[0]), int(x[0])].to_masked_array(), axis=0) + if da.ndim == 2: + arr = numpy.expand_dims(da[int(y[0]), int(x[0])].to_masked_array(), axis=0) else: - arr = ds[:, int(y[0]), int(x[0])].to_masked_array() + arr = da[:, int(y[0]), int(x[0])].to_masked_array() - arr.mask |= arr.data == ds.rio.nodata + arr.mask |= arr.data == da.rio.nodata return PointData( arr,