From 74829ccd4b709c012310744b4b3198637becad3a Mon Sep 17 00:00:00 2001 From: Bart Schilperoort Date: Tue, 10 Sep 2024 14:02:14 +0200 Subject: [PATCH] Implement new "most common" regridder. --- README.md | 2 +- docs/getting_started.rst | 2 +- docs/notebooks/demos/demo_most_common.ipynb | 785 +++++++++++++++++++- src/xarray_regrid/methods/_shared.py | 93 +++ src/xarray_regrid/methods/flox_reduce.py | 145 ++++ src/xarray_regrid/methods/most_common.py | 255 ------- src/xarray_regrid/regrid.py | 36 +- tests/test_most_common.py | 42 +- 8 files changed, 1058 insertions(+), 302 deletions(-) create mode 100644 src/xarray_regrid/methods/_shared.py create mode 100644 src/xarray_regrid/methods/flox_reduce.py delete mode 100644 src/xarray_regrid/methods/most_common.py diff --git a/README.md b/README.md index 4b5b41e..1efe67b 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ With xarray-regrid it is possible to regrid between two rectilinear grids. The f - Cubic - "Most common value" (zonal statistics) -All regridding methods, except for the "most common value" can operate lazily on [Dask arrays](https://docs.xarray.dev/en/latest/user-guide/dask.html). +All regridding methods can operate lazily on [Dask arrays](https://docs.xarray.dev/en/latest/user-guide/dask.html). Note that "Most common value" is designed to regrid categorical data to a coarse resolution. For regridding categorical data to a finer resolution, please use "nearest-neighbor" regridder. diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 33b0f25..c5299f9 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -34,5 +34,5 @@ Multiple regridding methods are available: * `conservative regridding `_ (``.regrid.conservative``) Additionally, a zonal statistics `method to compute the most common value `_ -is available (``.regrid.most_common``). +is available for DataArrays (``.regrid.most_common``). 
This can be used to upscale very fine categorical data to a more course resolution. diff --git a/docs/notebooks/demos/demo_most_common.ipynb b/docs/notebooks/demos/demo_most_common.ipynb index a322d1a..6533f14 100644 --- a/docs/notebooks/demos/demo_most_common.ipynb +++ b/docs/notebooks/demos/demo_most_common.ipynb @@ -39,84 +39,817 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next twe need a high resolution dataset to regrid. We used the LCCS land cover data which is available from the [Climate Data Store](https://cds.climate.copernicus.eu/cdsapp#!/dataset/satellite-land-cover).\n", + "Next we need a high resolution dataset to regrid. We used the LCCS land cover data which is available from the [Climate Data Store](https://cds.climate.copernicus.eu/cdsapp#!/dataset/satellite-land-cover).\n", "\n", - "We will also define our target grid:" + "Note the data is loaded in as dask arrays, allowing for lazy computation." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'lccs_class' (time: 1, latitude: 64800, longitude: 129600)> Size: 8GB\n",
+       "dask.array<getitem, shape=(1, 64800, 129600), dtype=uint8, chunksize=(1, 9257, 10125), chunktype=numpy.ndarray>\n",
+       "Coordinates:\n",
+       "  * latitude   (latitude) float64 518kB -90.0 -90.0 -89.99 ... 89.99 90.0 90.0\n",
+       "  * longitude  (longitude) float64 1MB -180.0 -180.0 -180.0 ... 180.0 180.0\n",
+       "  * time       (time) datetime64[ns] 8B 2020-01-01\n",
+       "Attributes:\n",
+       "    standard_name:        land_cover_lccs\n",
+       "    flag_colors:          #ffff64 #ffff64 #ffff00 #aaf0f0 #dcf064 #c8c864 #00...\n",
+       "    long_name:            Land cover class defined in LCCS\n",
+       "    valid_min:            1\n",
+       "    valid_max:            220\n",
+       "    ancillary_variables:  processed_flag current_pixel_state observation_coun...\n",
+       "    flag_meanings:        no_data cropland_rainfed cropland_rainfed_herbaceou...\n",
+       "    flag_values:          [  0  10  11  12  20  30  40  50  60  61  62  70  7...
" + ], + "text/plain": [ + " Size: 8GB\n", + "dask.array\n", + "Coordinates:\n", + " * latitude (latitude) float64 518kB -90.0 -90.0 -89.99 ... 89.99 90.0 90.0\n", + " * longitude (longitude) float64 1MB -180.0 -180.0 -180.0 ... 180.0 180.0\n", + " * time (time) datetime64[ns] 8B 2020-01-01\n", + "Attributes:\n", + " standard_name: land_cover_lccs\n", + " flag_colors: #ffff64 #ffff64 #ffff00 #aaf0f0 #dcf064 #c8c864 #00...\n", + " long_name: Land cover class defined in LCCS\n", + " valid_min: 1\n", + " valid_max: 220\n", + " ancillary_variables: processed_flag current_pixel_state observation_coun...\n", + " flag_meanings: no_data cropland_rainfed cropland_rainfed_herbaceou...\n", + " flag_values: [ 0 10 11 12 20 30 40 50 60 61 62 70 7..." + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "ds = xr.open_dataset(\n", - " \"../ESACCI-LC-L4-LCCS-Map-300m-P1Y-2013-v2.0.7cds.nc\",\n", - " chunks={\"lat\": 2000, \"lon\": 2000},\n", + " \"/data/C3S-LC-L4-LCCS-Map-300m-P1Y-2020-v2.1.1.nc\",\n", + " chunks=\"auto\",\n", ")\n", "\n", - "ds = ds[[\"lccs_class\"]] # Only take the class variable.\n", - "ds = ds.sortby([\"lat\", \"lon\"])\n", - "ds = ds.rename({\"lat\": \"latitude\", \"lon\": \"longitude\"})\n", - "\n", - "from xarray_regrid import Grid, create_regridding_dataset\n", + "da = ds[\"lccs_class\"] # Only take the class variable.\n", + "da = da.sortby([\"lat\", \"lon\"])\n", + "da = da.rename({\"lat\": \"latitude\", \"lon\": \"longitude\"})\n", + "da" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also define our target grid:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from xarray_regrid import Grid\n", "\n", - "new_grid = Grid(\n", + "target_dataset = Grid(\n", " north=90,\n", " east=90,\n", " south=0,\n", " west=0,\n", " resolution_lat=1,\n", " resolution_lon=1,\n", - ")\n", - "target_dataset = 
create_regridding_dataset(new_grid)" + ").create_regridding_dataset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Using `regrid.most_common` you can regrid the data.\n", - "\n", - "Currently the computation can not be done fully lazily, however a workaround that splits the problem into chunks and combines the solution is available. This is enabled using the \"max_mem\" keyword argument.\n", + "The default chunks are a bit large for this regridding operation, so we need to rechunk before continuing to avoid memory issues: " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Array Chunk
Bytes 7.82 GiB 15.64 MiB
Shape (1, 64800, 129600) (1, 4050, 4050)
Dask graph 512 chunks in 5 graph layers
Data type uint8 numpy.ndarray
\n", + "
\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " 129600\n", + " 64800\n", + " 1\n", + "\n", + "
" + ], + "text/plain": [ + "dask.array" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da = da.chunk({\"time\": -1, \"latitude\": 4050, \"longitude\": 4050})\n", + "da.data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `regrid.most_common` you can now regrid the data. This is currently only implemented for `DataArray`s, not `xr.Dataset`.\n", "\n", - "Note that the maximum memory limits the size of the regridding routine (in bytes), not of the input/output data, so total memory use can be higher." + "Note that we have to provide the expected groups (i.e. unique labels) in the data. This dataset already conveniently stores these in the attributes." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "ds_regrid = ds.regrid.most_common(target_dataset, time_dim=\"time\", max_mem=1e9)" + "da_regrid = da.regrid.most_common(\n", + " target_dataset, expected_groups=da.attrs[\"flag_values\"], time_dim=\"time\"\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "After computation, we can plot the solution:" + "When we call `.plot` on the DataArray, computation will begin." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -126,7 +859,7 @@ } ], "source": [ - "ds_regrid[\"lccs_class\"].plot(x=\"longitude\")" + "da_regrid.plot(x=\"longitude\")" ] }, { @@ -153,7 +886,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.0" + "version": "3.12.0" }, "orig_nbformat": 4 }, diff --git a/src/xarray_regrid/methods/_shared.py b/src/xarray_regrid/methods/_shared.py new file mode 100644 index 0000000..cc629a0 --- /dev/null +++ b/src/xarray_regrid/methods/_shared.py @@ -0,0 +1,93 @@ +"""Utility functions shared between methods.""" + +from typing import overload + +import numpy as np +import pandas as pd +import xarray as xr + + +def construct_intervals(coord: np.ndarray) -> pd.IntervalIndex: + """Create pandas.intervals with given coordinates.""" + step_size = np.median(np.diff(coord, n=1)) + breaks = np.append(coord, coord[-1] + step_size) - step_size / 2 + + # Note: closed="both" triggers an `NotImplementedError` + return pd.IntervalIndex.from_breaks(breaks, closed="left") + + +@overload +def restore_properties( + result: xr.DataArray, + original_data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.DataArray: ... + + +@overload +def restore_properties( + result: xr.Dataset, + original_data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.Dataset: ... 
+ + +def restore_properties( + result: xr.DataArray | xr.Dataset, + original_data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.DataArray | xr.Dataset: + """Restore coord names, copy values and attributes of target, & add NaN padding.""" + result.attrs = original_data.attrs + + result = result.rename({f"{coord}_bins": coord for coord in coords}) + for coord in coords: + result[coord] = target_ds[coord] + result[coord].attrs = target_ds[coord].attrs + + # Replace zeros outside of original data grid with NaNs + uncovered_target_grid = (target_ds[coord] <= original_data[coord].max()) & ( + target_ds[coord] >= original_data[coord].min() + ) + result = result.where(uncovered_target_grid) + + return result.transpose(*original_data.dims) + + +@overload +def reduce_data_to_new_domain( + data: xr.DataArray, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.DataArray: ... + + +@overload +def reduce_data_to_new_domain( + data: xr.Dataset, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.Dataset: ... 
+ + +def reduce_data_to_new_domain( + data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + coords: list[str], +) -> xr.DataArray | xr.Dataset: + """Slice the input data to bounds of the target dataset, to reduce computations.""" + data = data.sortby(list(coords)) + for coord in coords: + coord_res = np.median(np.diff(target_ds[coord].to_numpy(), 1)) + data = data.sel( + { + coord: slice( + target_ds[coord].min().to_numpy() - coord_res, + target_ds[coord].max().to_numpy() + coord_res, + ) + } + ) + return data diff --git a/src/xarray_regrid/methods/flox_reduce.py b/src/xarray_regrid/methods/flox_reduce.py new file mode 100644 index 0000000..2f484bc --- /dev/null +++ b/src/xarray_regrid/methods/flox_reduce.py @@ -0,0 +1,145 @@ +"""Implementation of flox reduction based regridding methods.""" + +import flox.xarray +import numpy as np +import pandas as pd +import xarray as xr + +from xarray_regrid import utils +from xarray_regrid.methods._shared import ( + construct_intervals, + reduce_data_to_new_domain, + restore_properties, +) + + +def statistic_reduce( + data: xr.Dataset, + target_ds: xr.Dataset, + time_dim: str, + method: str, + skipna: bool = False, +) -> xr.Dataset: + """Upsampling of data using statistical methods (e.g. the mean or variance). + + We use flox Aggregations to perform a "groupby" over multiple dimensions, which we + reduce using the specified method. + https://flox.readthedocs.io/en/latest/aggregations.html + + Args: + data: Input dataset. + target_ds: Dataset which coordinates the input dataset should be regrid to. + time_dim: Name of the time dimension. Defaults to "time". Use `None` to force + regridding over the time dimension. + method: One of the following reduction methods: "sum", "mean", "var", "std". + skipna: If NaN values should be ignored. + + Returns: + xarray.dataset with regridded land cover categorical data. + """ + valid_methods = ["sum", "mean", "var", "std"] + if method not in valid_methods: + msg = f"Invalid method. 
Please choose from '{valid_methods}'." + raise ValueError(msg) + + if skipna: + method = "nan" + method + + coords = utils.common_coords(data, target_ds, remove_coord=time_dim) + + bounds = tuple(construct_intervals(target_ds[coord].to_numpy()) for coord in coords) + + data = reduce_data_to_new_domain(data, target_ds, coords) + + result: xr.Dataset = flox.xarray.xarray_reduce( + data.compute(), + *coords, + func=method, + expected_groups=bounds, + ) + + return restore_properties(result, data, target_ds, coords) + + +def find_matching_int_dtype( + a: np.ndarray, +) -> type[np.signedinteger] | type[np.unsignedinteger]: + """Find the smallest integer datatype that can cover the given array.""" + # Integer types in increasing memory use + int_types: list[type[np.signedinteger] | type[np.unsignedinteger]] = [ + np.int8, + np.uint8, + np.int16, + np.uint16, + np.int32, + np.uint32, + ] + for dtype in int_types: + if (a.max() <= np.iinfo(dtype).max) and (a.min() >= np.iinfo(dtype).min): + return dtype + return np.int64 + + +def get_most_common_value( + data: xr.DataArray, + target_ds: xr.Dataset, + expected_groups: np.ndarray, + time_dim: str | None = "time", + inverse: bool = False, +) -> xr.DataArray: + """Upsample the input data using a "most common label" (mode) approach. + + Args: + data: Input DataArray, with an integer data type. If your data does not consist + of integer type values, you will have to encode them to integer types. + target_ds: Dataset which coordinates the input dataset should be regrid to. + expected_groups: Numpy array containing all labels expected to be in the input + data. For example, `np.array([0, 2, 4])`, if the data only contains the + values 0, 2 and 4. + time_dim: Name of the time dimension. Defaults to "time". Use `None` to force + regridding over the time dimension. + inverse: Find the least-common-value (anti-mode). + + Raises: + ValueError: if the input data is not of an integer dtype. 
+ + Returns: + xarray.DataArray with regridded categorical data. + """ + array_name = data.name if data.name is not None else "DATA_NAME" + + # Must be categorical data (integers) + if not np.issubdtype(data.dtype, np.integer): + msg = ( + "Your input data has to be of an integer datatype for this method.\n" + f" instead, your data is of type '{data.dtype}'." + "You can convert the data with:\n `dataset.astype(int)`." + ) + raise ValueError(msg) + + coords = utils.common_coords(data, target_ds, remove_coord=time_dim) + target_ds_sorted = target_ds.sortby(list(coords)) + + bounds = tuple( + construct_intervals(target_ds_sorted[coord].to_numpy()) for coord in coords + ) + + data = reduce_data_to_new_domain(data, target_ds_sorted, coords) + + # Reduce memory usage by picking the most minimal integer type + dtype = find_matching_int_dtype(expected_groups) + + result: xr.DataArray = flox.xarray.xarray_reduce( + xr.ones_like(data, dtype=bool), + data.astype(dtype), # important, needs to be int + *coords, + dim=coords, + func="count", + expected_groups=(pd.Index(expected_groups.astype(dtype)), *bounds), + fill_value=-1, + ) + result = result.idxmax(array_name) if not inverse else result.idxmin(array_name) + + result = restore_properties(result, data, target_ds_sorted, coords) + result = result.reindex_like(target_ds, copy=False) + return result diff --git a/src/xarray_regrid/methods/most_common.py b/src/xarray_regrid/methods/most_common.py deleted file mode 100644 index e0407f7..0000000 --- a/src/xarray_regrid/methods/most_common.py +++ /dev/null @@ -1,255 +0,0 @@ -"""Implementation of the "most common value" regridding method.""" - -from itertools import product -from typing import Any, overload - -import flox.xarray -import numpy as np -import numpy_groupies as npg # type: ignore -import pandas as pd -import xarray as xr -from flox import Aggregation - -from xarray_regrid import utils - - -@overload -def most_common_wrapper( - data: xr.DataArray, - target_ds: 
xr.Dataset, - time_dim: str = "", - max_mem: int | None = None, -) -> xr.DataArray: ... - - -@overload -def most_common_wrapper( - data: xr.Dataset, - target_ds: xr.Dataset, - time_dim: str = "", - max_mem: int | None = None, -) -> xr.Dataset: ... - - -def most_common_wrapper( - data: xr.DataArray | xr.Dataset, - target_ds: xr.Dataset, - time_dim: str = "", - max_mem: int | None = None, -) -> xr.DataArray | xr.Dataset: - """Wrapper for the most common regridder, allowing for analyzing larger datasets. - - Args: - data: Input dataset. - target_ds: Dataset which coordinates the input dataset should be regrid to. - time_dim: Name of the time dimension, as the regridders do not regrid over time. - Defaults to "time". - max_mem: (Approximate) maximum memory in bytes that the regridding routines can - use. Note that this is not the total memory consumption and does not include - the size of the final dataset. - If this kwargs is used, the regridding will be split up into more manageable - chunks, and combined for the final dataset. - - Returns: - xarray.dataset with regridded categorical data. 
- """ - da_name = None - if isinstance(data, xr.DataArray): - da_name = "da" if data.name is None else data.name - data = data.to_dataset(name=da_name) - - coords = utils.common_coords(data, target_ds) - target_ds_sorted = target_ds.sortby(list(coords)) - coord_size = [data[coord].size for coord in coords] - mem_usage = np.prod(coord_size) * np.zeros((1,), dtype=np.int64).itemsize - - if max_mem is not None and mem_usage > max_mem: - result = split_combine_most_common( - data=data, target_ds=target_ds_sorted, time_dim=time_dim, max_mem=max_mem - ) - else: - result = most_common(data=data, target_ds=target_ds_sorted, time_dim=time_dim) - - result = result.reindex_like(target_ds, copy=False) - - if da_name is not None: - return result[da_name] - else: - return result - - -def split_combine_most_common( - data: xr.Dataset, target_ds: xr.Dataset, time_dim: str, max_mem: int = int(1e9) -) -> xr.Dataset: - """Use a split-combine strategy to reduce the memory use of the most_common regrid. - - Args: - data: Input dataset. - target_ds: Dataset which coordinates the input dataset should be regrid to. - time_dim: Name of the time dimension, as the regridders do not regrid over time. - Defaults to "time". - max_mem: (Approximate) maximum memory in bytes that the regridding routines can - use. Note that this is not the total memory consumption and does not include - the size of the final dataset. Defaults to 1e9 (1 GB). - - Returns: - xarray.dataset with regridded categorical data. - """ - coords = utils.common_coords(data, target_ds, remove_coord=time_dim) - max_datapoints = max_mem // 8 # ~8 bytes per item. 
- max_source_coord_size = max_datapoints ** (1 / len(coords)) - size_ratios = { - coord: ( - np.median(np.diff(data[coord].to_numpy(), 1)) - / np.median(np.diff(target_ds[coord].to_numpy(), 1)) - ) - for coord in coords - } - max_coord_size = { - coord: int(size_ratios[coord] * max_source_coord_size) for coord in coords - } - - blocks = { - coord: np.arange(0, target_ds[coord].size, max_coord_size[coord]) - for coord in coords - } - - subsets = [] - for vals in product(*blocks.values()): - isel = {} - for coord, val in zip(blocks.keys(), vals, strict=True): - isel[coord] = slice(val, val + max_coord_size[coord]) - subsets.append(most_common(data, target_ds.isel(isel), time_dim=time_dim)) - - return xr.merge(subsets) - - -def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Dataset: - """Upsampling of data with a "most common label" approach. - - The implementation includes two steps: - - "groupby" coordinates - - select most common label - - We use flox to perform "groupby" multiple dimensions. Here is an example: - https://flox.readthedocs.io/en/latest/intro.html#histogramming-binning-by-multiple-variables - - To embed our customized function for most common label selection, we need to - create our `flox.Aggregation`, for instance: - https://flox.readthedocs.io/en/latest/aggregations.html - - `flox.Aggregation` function works with `numpy_groupies.aggregate_numpy.aggregate - API. Therefore this function also depends on `numpy_groupies`. For more information, - check the following example: - https://flox.readthedocs.io/en/latest/user-stories/custom-aggregations.html - - Args: - data: Input dataset. - target_ds: Dataset which coordinates the input dataset should be regrid to. - - Returns: - xarray.dataset with regridded land cover categorical data. 
- """ - dim_order = data.dims - coords = utils.common_coords(data, target_ds, remove_coord=time_dim) - coord_attrs = {coord: data[coord].attrs for coord in target_ds.coords} - - bounds = tuple( - _construct_intervals(target_ds[coord].to_numpy()) for coord in coords - ) - - # Slice the input data to the bounds of the target dataset - data = data.sortby(list(coords)) - for coord in coords: - coord_res = np.median(np.diff(target_ds[coord].to_numpy(), 1)) - data = data.sel( - { - coord: slice( - target_ds[coord].min().to_numpy() - coord_res, - target_ds[coord].max().to_numpy() + coord_res, - ) - } - ) - - most_common = Aggregation( - name="most_common", - numpy=_custom_grouped_reduction, # type: ignore - chunk=None, - combine=None, - ) - - ds_regrid: xr.Dataset = flox.xarray.xarray_reduce( - data.compute(), - *coords, - func=most_common, - expected_groups=bounds, - ) - - ds_regrid = ds_regrid.rename({f"{coord}_bins": coord for coord in coords}) - for coord in coords: - ds_regrid[coord] = target_ds[coord] - - # Replace zeros outside of original data grid with NaNs - uncovered_target_grid = (target_ds[coord] <= data[coord].max()) & ( - target_ds[coord] >= data[coord].min() - ) - ds_regrid = ds_regrid.where(uncovered_target_grid) - - ds_regrid[coord].attrs = coord_attrs[coord] - - return ds_regrid.transpose(*dim_order) - - -def _construct_intervals(coord: np.ndarray) -> pd.IntervalIndex: - """Create pandas.intervals with given coordinates.""" - step_size = np.median(np.diff(coord, n=1)) - breaks = np.append(coord, coord[-1] + step_size) - step_size / 2 - - # Note: closed="both" triggers an `NotImplementedError` - return pd.IntervalIndex.from_breaks(breaks, closed="left") - - -def _most_common_label(neighbors: np.ndarray) -> np.ndarray: - """Find the most common label in a neighborhood. - - Note that if more than one labels have the same frequency which is the highest, - then the first label in the list will be picked. 
- """ - unique_labels, counts = np.unique(neighbors, return_counts=True) - return unique_labels[np.argmax(counts)] # type: ignore - - -def _custom_grouped_reduction( - group_idx: np.ndarray, - array: np.ndarray, - *, - axis: int = -1, - size: int | None = None, - fill_value: Any = None, - dtype: Any = None, -) -> np.ndarray: - """Custom grouped reduction for flox.Aggregation to get most common label. - - Args: - group_idx : integer codes for group labels (1D) - array : values to reduce (nD) - axis : axis of array along which to reduce. - Requires array.shape[axis] == len(group_idx) - size : expected number of groups. If none, - output.shape[-1] == number of uniques in group_idx - fill_value : fill_value for when number groups in group_idx is less than size - dtype : dtype of output - - Returns: - np.ndarray with array.shape[-1] == size, containing a single value per group - """ - agg: np.ndarray = npg.aggregate_numpy.aggregate( - group_idx, - array, - func=_most_common_label, - axis=axis, - size=size, - fill_value=fill_value, - dtype=dtype, - ) - return agg diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index e91348a..b6709ae 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,6 +1,7 @@ +import numpy as np import xarray as xr -from xarray_regrid.methods import conservative, interp, most_common +from xarray_regrid.methods import conservative, flox_reduce, interp @xr.register_dataarray_accessor("regrid") @@ -111,9 +112,10 @@ def conservative( def most_common( self, ds_target_grid: xr.Dataset, + expected_groups: np.ndarray, time_dim: str = "time", - max_mem: int = int(1e9), - ) -> xr.DataArray | xr.Dataset: + inverse: bool = False, + ) -> xr.DataArray: """Regrid by taking the most common value within the new grid cells. 
To be used for regridding data to a much coarser resolution, not for regridding @@ -125,17 +127,33 @@ def most_common( Args: ds_target_grid: Target grid dataset - time_dim: Name of the time dimension. Defaults to "time". - max_mem: (Approximate) maximum memory in bytes that the regridding routine - can use. Note that this is not the total memory consumption and does not - include the size of the final dataset. Defaults to 1e9 (1 GB). + expected_groups: Numpy array containing all labels expected to be in the + input data. For example, `np.array([0, 2, 4])`, if the data only + contains the values 0, 2 and 4. + time_dim: Name of the time dimension. Defaults to "time". Use `None` to + force regridding over the time dimension. + inverse: Find the least-common-value (anti-mode). Returns: Regridded data. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return most_common.most_common_wrapper( - self._obj, ds_target_grid, time_dim, max_mem + + if isinstance(self._obj, xr.Dataset): + msg = ( + "The 'most common value' regridder is not implemented for\n", + "xarray.Dataset, as it requires specifying the expected labels.\n" + "Please select only a single variable (as DataArray),\n" + " and regrid it separately.", + ) + raise ValueError(msg) + + return flox_reduce.get_most_common_value( + self._obj, + ds_target_grid, + expected_groups, + time_dim, + inverse, ) diff --git a/tests/test_most_common.py b/tests/test_most_common.py index ec05221..6d275e9 100644 --- a/tests/test_most_common.py +++ b/tests/test_most_common.py @@ -5,6 +5,8 @@ from xarray_regrid import Grid, create_regridding_dataset +EXP_LABELS = np.array([0, 1, 2, 3]) # labels that are in the dummy data + @pytest.fixture def dummy_lc_data(): @@ -26,7 +28,7 @@ def dummy_lc_data(): lat_coords = np.linspace(0, 40, num=11) lon_coords = np.linspace(0, 40, num=11) - return xr.Dataset( + ds = xr.Dataset( data_vars={ "lc": (["longitude", "latitude"], data), }, @@ -36,6 +38,9 @@ def dummy_lc_data(): }, 
attrs={"test": "not empty"}, ) + ds["longitude"].attrs = {"units": "degrees_east"} + ds["latitude"].attrs = {"units": "degrees_north"} + return ds @pytest.fixture @@ -89,7 +94,10 @@ def test_most_common(dummy_lc_data, dummy_target_grid): }, ) xr.testing.assert_equal( - dummy_lc_data.regrid.most_common(dummy_target_grid)["lc"], + dummy_lc_data["lc"].regrid.most_common( + dummy_target_grid, + expected_groups=EXP_LABELS, + ), expected["lc"], ) @@ -121,41 +129,55 @@ def test_oversized_most_common(dummy_lc_data, oversized_dummy_target_grid): }, ) xr.testing.assert_equal( - dummy_lc_data.regrid.most_common(oversized_dummy_target_grid)["lc"], + dummy_lc_data["lc"].regrid.most_common( + oversized_dummy_target_grid, + expected_groups=EXP_LABELS, + ), expected["lc"], ) def test_attrs_dataarray(dummy_lc_data, dummy_target_grid): dummy_lc_data["lc"].attrs = {"test": "testing"} - da_regrid = dummy_lc_data["lc"].regrid.most_common(dummy_target_grid) + da_regrid = dummy_lc_data["lc"].regrid.most_common( + dummy_target_grid, + expected_groups=EXP_LABELS, + ) assert da_regrid.attrs != {} assert da_regrid.attrs == dummy_lc_data["lc"].attrs - assert da_regrid["longitude"].attrs == dummy_lc_data["longitude"].attrs + assert da_regrid["longitude"].attrs == dummy_target_grid["longitude"].attrs +@pytest.mark.xfail # most common currently does not work for datasets def test_attrs_dataset(dummy_lc_data, dummy_target_grid): ds_regrid = dummy_lc_data.regrid.most_common( dummy_target_grid, + expected_groups=EXP_LABELS, ) assert ds_regrid.attrs != {} assert ds_regrid.attrs == dummy_lc_data.attrs - assert ds_regrid["longitude"].attrs == dummy_lc_data["longitude"].attrs + assert ds_regrid["longitude"].attrs == dummy_target_grid["longitude"].attrs -@pytest.mark.parametrize("dataarray", [True, False]) +@pytest.mark.parametrize("dataarray", [True]) # most common does not work for datasets def test_coord_order_original(dummy_lc_data, dummy_target_grid, dataarray): input_data = dummy_lc_data["lc"] if 
dataarray else dummy_lc_data - ds_regrid = input_data.regrid.most_common(dummy_target_grid) + ds_regrid = input_data.regrid.most_common( + dummy_target_grid, + expected_groups=EXP_LABELS, + ) assert_array_equal(ds_regrid["latitude"], dummy_target_grid["latitude"]) assert_array_equal(ds_regrid["longitude"], dummy_target_grid["longitude"]) @pytest.mark.parametrize("coord", ["latitude", "longitude"]) -@pytest.mark.parametrize("dataarray", [True, False]) +@pytest.mark.parametrize("dataarray", [True]) # most common does not work for datasets def test_coord_order_reversed(dummy_lc_data, dummy_target_grid, coord, dataarray): input_data = dummy_lc_data["lc"] if dataarray else dummy_lc_data dummy_target_grid[coord] = list(reversed(dummy_target_grid[coord])) - ds_regrid = input_data.regrid.most_common(dummy_target_grid) + ds_regrid = input_data.regrid.most_common( + dummy_target_grid, + expected_groups=EXP_LABELS, + ) assert_array_equal(ds_regrid["latitude"], dummy_target_grid["latitude"]) assert_array_equal(ds_regrid["longitude"], dummy_target_grid["longitude"])