diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4cd20d5e95f..5659cb029da 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,8 @@ New Features - New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). By `Jimmy Westling `_. +- Enable the limit option for dask array in the following methods :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`) + By `Joseph Nowak `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -42,6 +44,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` along chunked dimensions (:issue:`6112`). + By `Joseph Nowak `_. + - Subclasses of ``byte`` and ``str`` (e.g. ``np.str_`` and ``np.bytes_``) will now serialise to disk rather than raising a ``ValueError: unsupported dtype for netCDF4 variable: object`` as they did previously (:pull:`5264`). By `Zeb Nicholls `_. diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 5eeb22767c8..fa497dbca20 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -57,24 +57,36 @@ def push(array, n, axis): """ Dask-aware bottleneck.push """ - from bottleneck import push + import bottleneck + import dask.array as da + import numpy as np - if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]: - raise NotImplementedError( - "Cannot fill along a chunked axis when limit is not None." - "Either rechunk to a single chunk along this axis or call .compute() or .load() first." - ) - if all(c == 1 for c in array.chunks[axis]): - array = array.rechunk({axis: 2}) - pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta) - if len(array.chunks[axis]) > 1: - pushed = pushed.map_overlap( - push, - axis=axis, - n=n, - depth={axis: (1, 0)}, - boundary="none", - dtype=array.dtype, - meta=array._meta, + def _fill_with_last_one(a, b): + # cumreduction apply the push func over all the blocks first so, the only missing part is filling + # the missing values using the last data of the previous chunk + return np.where(~np.isnan(b), b, a) + + if n is not None and 0 < n < array.shape[axis] - 1: + arange = da.broadcast_to( + da.arange( + array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype + ).reshape( + tuple(size if i == axis else 1 for i, size in enumerate(array.shape)) + ), + array.shape, + array.chunks, ) - return pushed + valid_arange = da.where(da.notnull(array), arange, np.nan) + valid_limits = (arange - push(valid_arange, None, axis)) <= n + # omit the forward fill that violate the limit + return da.where(valid_limits, push(array, None, axis), np.nan) + + # The method parameter makes that the tests for python 3.7 fails. + return da.reductions.cumreduction( + func=bottleneck.push, + binop=_fill_with_last_one, + ident=np.nan, + x=array, + axis=axis, + dtype=array.dtype, + ) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index c032a781e47..e12798b70c9 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -884,16 +884,18 @@ def test_push_dask(): import bottleneck import dask.array - array = np.array([np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6]) - expected = bottleneck.push(array, axis=0) - for c in range(1, 11): + array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6]) + + for n in [None, 1, 2, 3, 4, 5, 11]: + expected = bottleneck.push(array, axis=0, n=n) + for c in range(1, 11): + with raise_if_dask_computes(): + actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n) + np.testing.assert_equal(actual, expected) + + # some chunks of size-1 with NaN with raise_if_dask_computes(): - actual = push(dask.array.from_array(array, chunks=c), axis=0, n=None) + actual = push( + dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n + ) np.testing.assert_equal(actual, expected) - - # some chunks of size-1 with NaN - with raise_if_dask_computes(): - actual = push( - dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=None - ) - np.testing.assert_equal(actual, expected) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 69b59a7418c..4121b62a9e8 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -452,8 +452,10 @@ def test_ffill_bfill_dask(method): assert_equal(actual, expected) # limit < axis size - with pytest.raises(NotImplementedError): + with raise_if_dask_computes(): actual = dask_method("x", limit=2) + expected = numpy_method("x", limit=2) + assert_equal(actual, expected) # limit > axis size with raise_if_dask_computes():