Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ffill and bfill along chunked dimensions #5187

Merged
merged 7 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions.
(:issue:`2699`). By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is
2d (:issue:`5097`, :pull:`5099`). By `John Omotani <https://github.com/johnomotani>`_.
- Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years) can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls <https://github.com/znicholls>`_.
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,7 +2515,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "DataArray":
The maximum number of consecutive NaN values to forward fill. In
other words, if there is a gap with more than this number of
consecutive NaNs, it will only be partially filled. Must be greater
than 0 or None for no limit.
than 0 or None for no limit. Must be None or greater than or equal
to axis length if filling along chunked axes (dimensions).

Returns
-------
Expand All @@ -2539,7 +2540,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "DataArray":
The maximum number of consecutive NaN values to backward fill. In
other words, if there is a gap with more than this number of
consecutive NaNs, it will only be partially filled. Must be greater
than 0 or None for no limit.
than 0 or None for no limit. Must be None or greater than or equal
to axis length if filling along chunked axes (dimensions).

Returns
-------
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4654,7 +4654,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset":
The maximum number of consecutive NaN values to forward fill. In
other words, if there is a gap with more than this number of
consecutive NaNs, it will only be partially filled. Must be greater
than 0 or None for no limit.
than 0 or None for no limit. Must be None or greater than or equal
to axis length if filling along chunked axes (dimensions).

Returns
-------
Expand All @@ -4679,7 +4680,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset":
The maximum number of consecutive NaN values to backward fill. In
other words, if there is a gap with more than this number of
consecutive NaNs, it will only be partially filled. Must be greater
than 0 or None for no limit.
than 0 or None for no limit. Must be None or greater than or equal
to axis length if filling along chunked axes (dimensions).

Returns
-------
Expand Down
24 changes: 24 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,27 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
else:
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)


def push(array, n, axis):
    """
    Dask-aware bottleneck.push

    Forward-fill missing values along ``axis``. Non-dask input is handed
    straight to ``bottleneck.push``. Dask input is filled block by block and,
    when the axis is split into more than one block, a second overlapping pass
    carries values across block boundaries.

    Parameters
    ----------
    array : array-like
        Data to fill; may be a dask array.
    n : int or None
        Fill limit forwarded to ``bottleneck.push``. For a dask array with
        more than one block along ``axis`` it must be None or at least the
        length of that axis.
    axis : int
        Axis along which to fill.

    Returns
    -------
    The filled array. For dask input this is the dask array produced by
    ``map_blocks`` / ``map_overlap``.

    Raises
    ------
    NotImplementedError
        If ``array`` is a dask array with several blocks along ``axis`` and
        ``n`` is smaller than the axis length.
    """
    # Local import keeps bottleneck optional; inside this function the name
    # ``push`` refers to bottleneck's implementation, not this wrapper.
    from bottleneck import push

    if is_duck_dask_array(array):
        if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]:
            # A finite limit would have to be tracked across block edges,
            # which the two-pass scheme below does not do.
            raise NotImplementedError(
                "Can only fill along a chunked axis when `limit` is None or at least axis length. "
                "Either rechunk to a single chunk along this axis or call .compute() or .load() first."
            )
        if all(c == 1 for c in array.chunks[axis]):
            # Every block has length 1 along the axis: merge into blocks of 2
            # before filling (presumably so the overlap pass below behaves
            # sensibly -- TODO confirm the exact motivation).
            array = array.rechunk({axis: 2})
        # First pass: fill within each block independently.
        pushed = array.map_blocks(push, axis=axis, n=n)
        if len(array.chunks[axis]) > 1:
            # Second pass: give each block the last element of the preceding
            # block so leading NaNs can be filled from it.
            # NOTE(review): the overlap is only one element deep and is taken
            # from the first-pass result, so a value may not propagate across
            # a block that is entirely NaN -- verify and add a test.
            pushed = pushed.map_overlap(
                push, axis=axis, n=n, depth={axis: (1, 0)}, boundary="none"
            )
        return pushed
    else:
        return push(array, n, axis)
14 changes: 5 additions & 9 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import utils
from .common import _contains_datetime_like_objects, ones_like
from .computation import apply_ufunc
from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric
from .duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
from .options import _get_keep_attrs
from .pycompat import is_duck_dask_array
from .utils import OrderedSet, is_scalar
Expand Down Expand Up @@ -390,30 +390,26 @@ def func_interpolate_na(interpolator, y, x, **kwargs):

def _bfill(arr, n=None, axis=-1):
    """Backward-fill missing values along ``axis``; the inverse of ffill.

    Parameters
    ----------
    arr : array-like
        Data to fill.
    n : int or None, optional
        Fill limit forwarded to ``push``.
    axis : int, optional
        Axis along which to fill. Defaults to the last axis.

    Returns
    -------
    The filled array, in the original orientation.
    """
    # A backward fill is a forward fill on the reversed axis.
    arr = np.flip(arr, axis=axis)

    # Fill exactly once, through the module's ``push`` wrapper; a second fill
    # pass with a finite ``n`` would fill more than ``n`` consecutive NaNs.
    arr = push(arr, axis=axis, n=n)

    # reverse back to original
    return np.flip(arr, axis=axis)


def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
import bottleneck as bn

axis = arr.get_axis_num(dim)

# work around for bottleneck 178
_limit = limit if limit is not None else arr.shape[axis]

return apply_ufunc(
bn.push,
push,
arr,
dask="parallelized",
dask="allowed",
keep_attrs=True,
output_dtypes=[arr.dtype],
kwargs=dict(n=_limit, axis=axis),
Expand All @@ -430,7 +426,7 @@ def bfill(arr, dim=None, limit=None):
return apply_ufunc(
_bfill,
arr,
dask="parallelized",
dask="allowed",
keep_attrs=True,
output_dtypes=[arr.dtype],
kwargs=dict(n=_limit, axis=axis),
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
mean,
np_timedelta64_to_float,
pd_timedelta_to_float,
push,
py_timedelta_to_float,
stack,
timedelta_to_numeric,
Expand All @@ -34,6 +35,7 @@
has_dask,
has_scipy,
raise_if_dask_computes,
requires_bottleneck,
requires_cftime,
requires_dask,
)
Expand Down Expand Up @@ -858,3 +860,26 @@ def test_least_squares(use_dask, skipna):

np.testing.assert_allclose(coeffs, [1.5, 1.25])
np.testing.assert_allclose(residuals, [2.0])


@requires_dask
@requires_bottleneck
def test_push_dask():
    """Check that ``push`` on dask input matches ``bottleneck.push`` without computing."""
    import bottleneck
    import dask.array

    values = np.array(
        [np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6]
    )
    expected = bottleneck.push(values, axis=0)

    # Uniform chunk sizes 1..10 first, then an irregular layout that includes
    # some chunks of size-1 with NaN.
    chunk_layouts = [*range(1, 11), (1, 2, 3, 2, 2, 1, 1)]

    for chunks in chunk_layouts:
        lazy_values = dask.array.from_array(values, chunks=chunks)
        # Building the filled array must not trigger a dask compute.
        with raise_if_dask_computes():
            actual = push(lazy_values, axis=0, n=None)
        np.testing.assert_equal(actual, expected)
49 changes: 26 additions & 23 deletions xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
assert_allclose,
assert_array_equal,
assert_equal,
raise_if_dask_computes,
requires_bottleneck,
requires_cftime,
requires_dask,
Expand Down Expand Up @@ -393,37 +394,39 @@ def test_ffill():

@requires_bottleneck
@requires_dask
def test_ffill_dask():
@pytest.mark.parametrize("method", ["ffill", "bfill"])
def test_ffill_bfill_dask(method):
da, _ = make_interpolate_example_data((40, 40), 0.5)
da = da.chunk({"x": 5})
actual = da.ffill("time")
expected = da.load().ffill("time")
assert isinstance(actual.data, dask_array_type)
assert_equal(actual, expected)

# with limit
da = da.chunk({"x": 5})
actual = da.ffill("time", limit=3)
expected = da.load().ffill("time", limit=3)
assert isinstance(actual.data, dask_array_type)
dask_method = getattr(da, method)
numpy_method = getattr(da.compute(), method)
# unchunked axis
with raise_if_dask_computes():
actual = dask_method("time")
expected = numpy_method("time")
assert_equal(actual, expected)


@requires_bottleneck
@requires_dask
def test_bfill_dask():
da, _ = make_interpolate_example_data((40, 40), 0.5)
da = da.chunk({"x": 5})
actual = da.bfill("time")
expected = da.load().bfill("time")
assert isinstance(actual.data, dask_array_type)
# chunked axis
with raise_if_dask_computes():
actual = dask_method("x")
expected = numpy_method("x")
assert_equal(actual, expected)

# with limit
da = da.chunk({"x": 5})
actual = da.bfill("time", limit=3)
expected = da.load().bfill("time", limit=3)
assert isinstance(actual.data, dask_array_type)
with raise_if_dask_computes():
actual = dask_method("time", limit=3)
expected = numpy_method("time", limit=3)
assert_equal(actual, expected)

# limit < axis size
with pytest.raises(NotImplementedError):
actual = dask_method("x", limit=2)

# limit > axis size
with raise_if_dask_computes():
actual = dask_method("x", limit=41)
expected = numpy_method("x", limit=41)
assert_equal(actual, expected)


Expand Down