Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rolling.construct: Add sliding_window_view_kwargs to pipe arguments down to sliding_window_view #9720

Merged
merged 12 commits into from
Nov 18, 2024
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ New Features
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Add new ``automatic_rechunk`` kwarg to :py:meth:`DataArrayRolling.construct` and
:py:meth:`DatasetRolling.construct`. This is only useful on ``dask>=2024.11.0``
(:issue:`9550`). By `Deepak Cherian <https://github.com/dcherian>`_.
- Optimize ffill, bfill with dask when limit is specified
(:pull:`9771`).
By `Joseph Nowak <https://github.com/josephnowak>`_, and
Expand Down
16 changes: 16 additions & 0 deletions xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,19 @@ def reshape_blockwise(
return reshape_blockwise(x, shape=shape, chunks=chunks)
else:
return x.reshape(shape)


def sliding_window_view(
x, window_shape, axis=None, *, automatic_rechunk=True, **kwargs
):
# Backcompat for handling `automatic_rechunk`, delete when dask>=2024.11.0
# Note that subok, writeable are unsupported by dask, so we ignore those in kwargs
from dask.array.lib.stride_tricks import sliding_window_view

if module_available("dask", "2024.11.0"):
return sliding_window_view(
x, window_shape=window_shape, axis=axis, automatic_rechunk=automatic_rechunk
)
else:
# automatic_rechunk is not supported
return sliding_window_view(x, window_shape=window_shape, axis=axis)
33 changes: 24 additions & 9 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
transpose,
unravel_index,
)
from numpy.lib.stride_tricks import sliding_window_view # noqa: F401
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray import pycompat
Expand Down Expand Up @@ -92,19 +91,25 @@ def _dask_or_eager_func(
name,
eager_module=np,
dask_module="dask.array",
dask_only_kwargs=tuple(),
numpy_only_kwargs=tuple(),
):
"""Create a function that dispatches to dask for dask array inputs."""

def f(*args, **kwargs):
if any(is_duck_dask_array(a) for a in args):
if dask_available and any(is_duck_dask_array(a) for a in args):
mod = (
import_module(dask_module)
if isinstance(dask_module, str)
else dask_module
)
wrapped = getattr(mod, name)
for kwarg in numpy_only_kwargs:
kwargs.pop(kwarg, None)
else:
wrapped = getattr(eager_module, name)
for kwarg in dask_only_kwargs:
kwargs.pop(kwarg, None)
return wrapped(*args, **kwargs)

return f
Expand All @@ -122,6 +127,22 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")

# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
# TODO: replacing breaks iris + dask tests
masked_invalid = _dask_or_eager_func(
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
)

# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
# so we need to hand-code this.
sliding_window_view = _dask_or_eager_func(
"sliding_window_view",
eager_module=np.lib.stride_tricks,
dask_module=dask_array_compat,
dask_only_kwargs=("automatic_rechunk",),
numpy_only_kwargs=("subok", "writeable"),
)


def round(array):
xp = get_array_namespace(array)
Expand Down Expand Up @@ -170,12 +191,6 @@ def notnull(data):
return ~isnull(data)


# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
masked_invalid = _dask_or_eager_func(
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
)


def trapz(y, x, axis):
if axis < 0:
axis = y.ndim + axis
Expand Down
Loading
Loading