Generalize cumulative reduction (scan) to non-dask types #8019

Merged
merged 12 commits into from
Dec 18, 2023
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -589,6 +589,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
  raised (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_.
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
  potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
  use non-dask chunked array types
  (:pull:`8019`). By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
  `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
  By `Ian Carroll <https://github.com/itcarroll>`_.
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
@@ -97,6 +97,28 @@ def reduction(
            keepdims=keepdims,
        )

    def scan(
        self,
        func: Callable,
        binop: Callable,
        ident: float,
        arr: T_ChunkedArray,
        axis: int | None = None,
        dtype: np.dtype | None = None,
        **kwargs,
    ) -> DaskArray:
        from dask.array.reductions import cumreduction

        return cumreduction(
            func,
            binop,
            ident,
            arr,
            axis=axis,
            dtype=dtype,
            **kwargs,
        )

    def apply_gufunc(
        self,
        func: Callable,
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
@@ -403,6 +403,43 @@ def reduction(
        """
        raise NotImplementedError()

    def scan(
        self,
        func: Callable,
        binop: Callable,
        ident: float,
        arr: T_ChunkedArray,
        axis: int | None = None,
        dtype: np.dtype | None = None,
        **kwargs,
    ) -> T_ChunkedArray:
        """
        General version of a 1D scan, also known as a cumulative array reduction.

        Used in ``ffill`` and ``bfill`` in xarray.

        Parameters
        ----------
        func : callable
            Cumulative function like ``np.cumsum`` or ``np.cumprod``
        binop : callable
            Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
        ident : Number
            Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
        arr : chunked array
        axis : int, optional
        dtype : dtype, optional

        Returns
        -------
        Chunked array

        See Also
        --------
        dask.array.cumreduction
        """
        raise NotImplementedError()

    @abstractmethod
    def apply_gufunc(
        self,
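
To illustrate how the `func`, `binop`, and `ident` arguments of `scan` fit together, here is a minimal pure-Python sketch of a sequential blockwise scan. It is not the dask or xarray implementation: `blockwise_scan`, `cumsum`, and `cumprod` are hypothetical helpers written for this example, and the chunked array is assumed to be a plain list of lists.

```python
import operator
from itertools import accumulate
from typing import Callable, List, Sequence


def blockwise_scan(
    func: Callable[[Sequence[float]], List[float]],
    binop: Callable[[float, float], float],
    ident: float,
    chunks: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Hypothetical helper: scan a 1D array stored as a list of chunks."""
    carry = ident  # reduction of every element before the current chunk
    out = []
    for chunk in chunks:
        local = func(chunk)  # cumulative result within this chunk only
        # Combine the carried-in total with each local cumulative value.
        out.append([binop(carry, x) for x in local])
        if local:
            # The last local value is the reduction of the whole chunk.
            carry = binop(carry, local[-1])
    return out


def cumsum(chunk: Sequence[float]) -> List[float]:
    # Stand-in for np.cumsum on a single chunk.
    return list(accumulate(chunk, operator.add))


def cumprod(chunk: Sequence[float]) -> List[float]:
    # Stand-in for np.cumprod on a single chunk.
    return list(accumulate(chunk, operator.mul))


chunks = [[1, 2, 3], [4, 5], [6]]
sums = blockwise_scan(cumsum, operator.add, 0, chunks)
prods = blockwise_scan(cumprod, operator.mul, 1, chunks)
print(sums)   # [[1, 3, 6], [10, 15], [21]]
print(prods)  # [[1, 2, 6], [24, 120], [720]]
```

In this sketch `func` handles each chunk independently, `ident` seeds the running total before the first chunk, and `binop` stitches the per-chunk results together. That is the role the docstring assigns to each argument.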