Implement interp for interpolating between chunks of data (dask) #4155

Merged 44 commits on Aug 11, 2020
Commits
62c6385
Implement interp for interpolating between chunks of data (dask)
Jun 15, 2020
f6f7dad
do not forget extra points at the end
Jun 15, 2020
b0d8a5f
add tests
Jun 15, 2020
1a31457
add whats-new comment
Jun 15, 2020
9933c73
fix isort / black
Jun 15, 2020
cea826b
typo
Jun 15, 2020
44bbedf
update pull number
Jun 15, 2020
067b7f3
fix github pep8 warnings
Jun 15, 2020
c47a1d5
fix isort
Jun 15, 2020
7d505a1
clearer arguments in _dask_aware_interpnd
Jul 17, 2020
423b36d
typo
Jul 17, 2020
85ff539
fix for datetimelike index
Jul 20, 2020
6e9b50e
chunked interpolation does not work for high order interpolation (qua…
Jul 20, 2020
c63636f
Merge branch 'upstream' into chunked_interp
Jul 20, 2020
86cb592
fix whats new
Jul 20, 2020
5e26a4e
remove a useless import
Jul 20, 2020
3ca6e6d
use Variable instead of IndexVariable
Jul 21, 2020
a131b21
avoid some list to tuple conversion
Jul 21, 2020
67d2b36
black fix
Jul 21, 2020
f485958
more comments to explain _compute_chunks
Jul 21, 2020
42f8a3b
For orthogonal linear- and nearest-neighbor interpolation, the scalar…
Jul 24, 2020
ec3c400
better detection of Advanced interpolation
Jul 24, 2020
e231954
implement support of unsorted interpolation destination
Jul 24, 2020
061f5a8
rework the tests
Jul 24, 2020
623cb0b
fix for datetime index (bug introduced with unsorted destination)
Jul 24, 2020
b66d123
Variable is cheaper than DataArray
Jul 24, 2020
e211127
add warning if unsorted
Jul 27, 2020
e610268
simplify _compute_chunks
Jul 27, 2020
7547d56
add ghosts point in order to make quadratic and cubic method work in…
Jul 27, 2020
fd936dd
black
Jul 27, 2020
24f9460
forgot to remove an exception in test_upsample_interpolate_dask
Jul 27, 2020
dd2f273
fix filtering out-of-order warning
Jul 28, 2020
49bdefa
use extrapolate to check external points
Jul 28, 2020
d280867
Revert "add ghosts point in order to make quadratic and cubic method …
Jul 29, 2020
aeb7be1
Complete rewrite using blockwise
Jul 29, 2020
3c7d8c6
Merge branch 'upstream' into chunked_interp
Jul 29, 2020
0bc35d2
update whats-new.rst
Jul 29, 2020
0d5f618
reduce the diff
Jul 29, 2020
290a075
more decomposition of orthogonal interpolation
Jul 29, 2020
3f8718e
simplify _dask_aware_interpnd a little
Jul 30, 2020
562d5aa
fix dask interp when chunks are not aligned
Jul 30, 2020
62f059c
continue simplifying _dask_aware_interpnd
Jul 31, 2020
3d4d45c
update whats-new.rst
Jul 31, 2020
b60cddf
clean tests
Jul 31, 2020
5 changes: 3 additions & 2 deletions doc/whats-new.rst
@@ -162,9 +162,10 @@ New Features
 Enhancements
 ~~~~~~~~~~~~
 - Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp`
-  For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially
-  rather than interpolating in multidimensional space. (:issue:`2223`)
+  We perform independent interpolations sequentially rather than interpolating in
+  one large multidimensional space. (:issue:`2223`)
   By `Keisuke Fujii <https://github.com/fujiisoup>`_.
+- :py:meth:`DataArray.interp` now supports interpolation over chunked dimensions (:pull:`4155`). By `Alexandre Poux <https://github.com/pums974>`_.
 - Major performance improvement for :py:meth:`Dataset.from_dataframe` when the
   dataframe has a MultiIndex (:pull:`4184`).
   By `Stephan Hoyer <https://github.com/shoyer>`_.
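The changelog entry says that independent interpolations are performed one after another rather than in one multidimensional pass. As a minimal illustration of why that is valid for linear interpolation on orthogonal axes, the pure-Python sketch below compares a sequential x-then-y interpolation with the direct bilinear formula. The helper names (`lerp_1d`, `interp_sequential`, `interp_bilinear`) and the sample grid are made up for this example and are not part of xarray.

```python
from bisect import bisect_right


def lerp_1d(xs, ys, x_new):
    """Linearly interpolate the samples (xs, ys) at x_new; xs must be sorted."""
    # index of the interval [xs[i], xs[i + 1]] that contains x_new
    i = min(max(bisect_right(xs, x_new) - 1, 0), len(xs) - 2)
    t = (x_new - xs[i]) / (xs[i + 1] - xs[i])
    return (1 - t) * ys[i] + t * ys[i + 1]


def interp_sequential(xs, ys, grid, x_new, y_new):
    """Interpolate grid[i][j] = f(xs[i], ys[j]) along x first, then along y."""
    # step 1: interpolate every y-column along x
    along_x = [lerp_1d(xs, [row[j] for row in grid], x_new) for j in range(len(ys))]
    # step 2: interpolate the remaining 1-D profile along y
    return lerp_1d(ys, along_x, y_new)


def interp_bilinear(xs, ys, grid, x_new, y_new):
    """Direct bilinear formula on the enclosing cell, for comparison."""
    i = min(max(bisect_right(xs, x_new) - 1, 0), len(xs) - 2)
    j = min(max(bisect_right(ys, y_new) - 1, 0), len(ys) - 2)
    tx = (x_new - xs[i]) / (xs[i + 1] - xs[i])
    ty = (y_new - ys[j]) / (ys[j + 1] - ys[j])
    return (
        (1 - tx) * (1 - ty) * grid[i][j]
        + tx * (1 - ty) * grid[i + 1][j]
        + (1 - tx) * ty * grid[i][j + 1]
        + tx * ty * grid[i + 1][j + 1]
    )


xs = [0.0, 1.0, 2.0]
ys = [0.0, 10.0]
grid = [[0.0, 1.0], [2.0, 5.0], [4.0, 3.0]]

sequential = interp_sequential(xs, ys, grid, 0.5, 2.5)
direct = interp_bilinear(xs, ys, grid, 0.5, 2.5)
```

Both calls give the same value on this grid. The sequential form only ever touches one axis at a time, which is what makes the per-dimension decomposition in this PR possible.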
198 changes: 138 additions & 60 deletions xarray/core/missing.py
@@ -544,15 +544,6 @@ def _get_valid_fill_mask(arr, dim, limit):
     ) <= limit


-def _assert_single_chunk(var, axes):
-    for axis in axes:
-        if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]:
-            raise NotImplementedError(
-                "Chunking along the dimension to be interpolated "
-                "({}) is not yet supported.".format(axis)
-            )
-
-
 def _localize(var, indexes_coords):
     """ Speed up for linear and nearest neighbor method.
     Only consider a subspace that is needed for the interpolation
@@ -617,49 +608,42 @@ def interp(var, indexes_coords, method, **kwargs):
     if not indexes_coords:
         return var.copy()

-    # simple speed up for the local interpolation
-    if method in ["linear", "nearest"]:
-        var, indexes_coords = _localize(var, indexes_coords)
-
     # default behavior
     kwargs["bounds_error"] = kwargs.get("bounds_error", False)

-    # check if the interpolation can be done in orthogonal manner
-    if (
-        len(indexes_coords) > 1
-        and method in ["linear", "nearest"]
-        and all(dest[1].ndim == 1 for dest in indexes_coords.values())
-        and len(set([d[1].dims[0] for d in indexes_coords.values()]))
-        == len(indexes_coords)
-    ):
-        # interpolate sequentially
-        for dim, dest in indexes_coords.items():
-            var = interp(var, {dim: dest}, method, **kwargs)
-        return var
-
-    # target dimensions
-    dims = list(indexes_coords)
-    x, new_x = zip(*[indexes_coords[d] for d in dims])
-    destination = broadcast_variables(*new_x)
-
-    # transpose to make the interpolated axis to the last position
-    broadcast_dims = [d for d in var.dims if d not in dims]
-    original_dims = broadcast_dims + dims
-    new_dims = broadcast_dims + list(destination[0].dims)
-    interped = interp_func(
-        var.transpose(*original_dims).data, x, destination, method, kwargs
-    )
+    result = var
+    # decompose the interpolation into a succession of independent interpolations
+    for indexes_coords in decompose_interp(indexes_coords):
+        var = result
+
+        # simple speed up for the local interpolation
+        if method in ["linear", "nearest"]:
+            var, indexes_coords = _localize(var, indexes_coords)
+
+        # target dimensions
+        dims = list(indexes_coords)
+        x, new_x = zip(*[indexes_coords[d] for d in dims])
+        destination = broadcast_variables(*new_x)
+
+        # transpose to make the interpolated axis to the last position
+        broadcast_dims = [d for d in var.dims if d not in dims]
+        original_dims = broadcast_dims + dims
+        new_dims = broadcast_dims + list(destination[0].dims)
+        interped = interp_func(
+            var.transpose(*original_dims).data, x, destination, method, kwargs
+        )

-    result = Variable(new_dims, interped, attrs=var.attrs)
+        result = Variable(new_dims, interped, attrs=var.attrs)

-    # dimension of the output array
-    out_dims = OrderedSet()
-    for d in var.dims:
-        if d in dims:
-            out_dims.update(indexes_coords[d][1].dims)
-        else:
-            out_dims.add(d)
-    return result.transpose(*tuple(out_dims))
+        # dimension of the output array
+        out_dims = OrderedSet()
+        for d in var.dims:
+            if d in dims:
+                out_dims.update(indexes_coords[d][1].dims)
+            else:
+                out_dims.add(d)
+        result = result.transpose(*tuple(out_dims))
+    return result


 def interp_func(var, x, new_x, method, kwargs):
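The "dimension of the output array" step in the hunk above replaces each interpolated dimension, in place, with the dimensions of its destination coordinate while dropping repeats. A minimal stand-alone sketch of that ordering rule follows. It uses a plain `dict` as an insertion-ordered set instead of xarray's `OrderedSet`, and the function name `output_dims` and its arguments are assumptions made for illustration only.

```python
def output_dims(var_dims, dest_dims):
    """Order the output dimensions.

    var_dims:  tuple of dimension names of the input variable
    dest_dims: {interpolated dim: tuple of destination dimension names}
    """
    out = {}  # a dict keeps insertion order and ignores repeated keys
    for d in var_dims:
        # an interpolated dim is replaced by its destination dims,
        # any other dim is kept as it is
        for new_d in dest_dims.get(d, (d,)):
            out[new_d] = None
    return tuple(out)


# 1-D destination along "time": the new dimension takes the place of "time"
orthogonal = output_dims(("x", "y", "time"), {"time": ("new_time",)})

# "x" and "z" both map onto a shared "points" dimension: it appears once,
# at the position of the first interpolated dimension
shared = output_dims(("x", "y", "z"), {"x": ("points",), "z": ("points",)})
```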
@@ -706,21 +690,51 @@ def interp_func(var, x, new_x, method, kwargs):
     if isinstance(var, dask_array_type):
         import dask.array as da

-        _assert_single_chunk(var, range(var.ndim - len(x), var.ndim))
-        chunks = var.chunks[: -len(x)] + new_x[0].shape
-        drop_axis = range(var.ndim - len(x), var.ndim)
-        new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim)
-        return da.map_blocks(
-            _interpnd,
+        nconst = var.ndim - len(x)
+
+        out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim))
+
+        # blockwise args format
+        x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
+        x_arginds = [item for pair in x_arginds for item in pair]
+        new_x_arginds = [
+            [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x
+        ]
+        new_x_arginds = [item for pair in new_x_arginds for item in pair]
+
+        args = (
             var,
-            x,
-            new_x,
-            func,
-            kwargs,
+            range(var.ndim),
+            *x_arginds,
+            *new_x_arginds,
+        )
+
+        _, rechunked = da.unify_chunks(*args)
+
+        args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair])
+
+        new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
+
+        new_axes = {
+            var.ndim + i: new_x[0].chunks[i]
+            if new_x[0].chunks is not None
+            else new_x[0].shape[i]
+            for i in range(new_x[0].ndim)
+        }
+
+        # if useful, re-use localize for each chunk of new_x
+        localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
+
+        return da.blockwise(
+            _dask_aware_interpnd,
+            out_ind,
+            *args,
+            interp_func=func,
+            interp_kwargs=kwargs,
+            localize=localize,
+            concatenate=True,
             dtype=var.dtype,
-            chunks=chunks,
-            new_axis=new_axis,
-            drop_axis=drop_axis,
+            new_axes=new_axes,
         )

     return _interpnd(var, x, new_x, func, kwargs)
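The index bookkeeping in the hunk above is easier to follow on a small case. The sketch below rebuilds only the label layout, with placeholder strings standing in for the dask arrays; it does not call dask at all. The function name `blockwise_layout` is hypothetical, and it assumes every destination coordinate has already been broadcast to the same `dest_ndim` dimensions, as `interp` does before calling `interp_func`.

```python
def blockwise_layout(var_ndim, n_interp, dest_ndim):
    """Return (out_ind, flat_args), using placeholder strings instead of arrays."""
    nconst = var_ndim - n_interp  # leading axes that are not interpolated
    # fresh labels for the axes created by the destination coordinates
    dest_ind = tuple(range(var_ndim, var_ndim + dest_ndim))

    # output keeps the constant axes and gains the destination axes
    out_ind = list(range(nconst)) + list(dest_ind)

    pairs = [("var", tuple(range(var_ndim)))]
    # each original coordinate is 1-D and labels one interpolated axis
    pairs += [(f"x{k}", (nconst + k,)) for k in range(n_interp)]
    # each destination coordinate carries the fresh labels
    pairs += [(f"new_x{k}", dest_ind) for k in range(n_interp)]

    # flatten [(array, index), ...] into (array, index, array, index, ...)
    flat_args = [item for pair in pairs for item in pair]
    return out_ind, flat_args


out_ind, flat_args = blockwise_layout(var_ndim=3, n_interp=1, dest_ndim=1)

# arrays sit at even positions, index tuples at odd positions
arrays, indices = flat_args[0::2], flat_args[1::2]

# after `var`, the first half are original coordinates, the second half destinations
new_x = arrays[1 + (len(arrays) - 1) // 2:]
```

In this case label `2` (the interpolated axis) appears among the inputs but not in `out_ind`, so it is contracted away. The diff passes `concatenate=True`, which in `dask.array.blockwise` means the blocks along such a contracted label are concatenated before the wrapped function is called, so each call sees the full interpolated axis.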
@@ -751,3 +765,67 @@ def _interpnd(var, x, new_x, func, kwargs):
     # move back the interpolation axes to the last position
     rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
     return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)
+
+
+def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
+    """Wrapper for `_interpnd` through `blockwise`
+
+    The first half of the arrays in `coords` are the original coordinates,
+    the other half are the destination coordinates
+    """
+    n_x = len(coords) // 2
+    nconst = len(var.shape) - n_x
+
+    # _interpnd expects coords to be Variables
+    x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
+    new_x = [
+        Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x)
+        for _x in coords[n_x:]
+    ]
+
+    if localize:
+        # _localize expects var to be a Variable
+        var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var)
+
+        indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)}
+
+        # simple speed up for the local interpolation
+        var, indexes_coords = _localize(var, indexes_coords)
+        x, new_x = zip(*[indexes_coords[d] for d in indexes_coords])
+
+        # put var back as an ndarray
+        var = var.data
+
+    return _interpnd(var, x, new_x, interp_func, interp_kwargs)
+
+
+def decompose_interp(indexes_coords):
Review comment (Member): Nice improvement!
+    """Decompose the interpolation into a succession of independent interpolations, keeping the order"""
+
+    dest_dims = [
+        dest[1].dims if dest[1].ndim > 0 else [dim]
+        for dim, dest in indexes_coords.items()
+    ]
+    partial_dest_dims = []
+    partial_indexes_coords = {}
+    for i, index_coords in enumerate(indexes_coords.items()):
+        partial_indexes_coords.update([index_coords])
+
+        if i == len(dest_dims) - 1:
+            break
+
+        partial_dest_dims += [dest_dims[i]]
+        other_dims = dest_dims[i + 1 :]
+
+        s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims}
+        s_other_dims = {dim for dims in other_dims for dim in dims}
+
+        if not s_partial_dest_dims.intersection(s_other_dims):
+            # this interpolation is orthogonal to the rest
+
+            yield partial_indexes_coords
+
+            partial_dest_dims = []
+            partial_indexes_coords = {}
+
+    yield partial_indexes_coords
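To see how `decompose_interp` groups dimensions, it can help to strip away the coordinate data and keep only the dimension names. The sketch below is a simplified re-statement of the same grouping rule, not the function from the diff: it takes a plain mapping from source dimension to destination dimension names, and it assumes the caller has already mapped a scalar destination to `(dim,)`. The name `decompose` is chosen for this example.

```python
def decompose(dest_dims_by_dim):
    """Yield groups of dimensions that must be interpolated together.

    dest_dims_by_dim: {source dim: tuple of destination dim names},
    in the order the interpolations were requested.
    """
    items = list(dest_dims_by_dim.items())
    group, seen = {}, set()
    for i, (dim, dims) in enumerate(items):
        group[dim] = dims
        seen.update(dims)
        # destination dims still to come after this entry
        remaining = {d for _, ds in items[i + 1:] for d in ds}
        if i < len(items) - 1 and not seen & remaining:
            # nothing later shares a destination dim with this group:
            # it can be interpolated on its own
            yield group
            group, seen = {}, set()
    yield group


# orthogonal destinations: one interpolation per dimension
orthogonal = list(decompose({"x": ("x",), "y": ("y",)}))

# both dimensions are sampled along a shared "points" dimension:
# they have to be interpolated together
pointwise = list(decompose({"x": ("points",), "y": ("points",)}))

# "z" sits between two entries that share "points", so order-keeping
# forces all three into a single group
mixed = list(decompose({"x": ("points",), "z": ("z",), "y": ("points",)}))
```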
12 changes: 4 additions & 8 deletions xarray/tests/test_dataarray.py
@@ -3147,7 +3147,8 @@ def test_upsample_interpolate_regression_1605(self):

     @requires_dask
     @requires_scipy
-    def test_upsample_interpolate_dask(self):
+    @pytest.mark.parametrize("chunked_time", [True, False])
+    def test_upsample_interpolate_dask(self, chunked_time):
         from scipy.interpolate import interp1d

         xs = np.arange(6)
@@ -3158,6 +3159,8 @@ def test_upsample_interpolate_dask(self):
         data = np.tile(z, (6, 3, 1))
         array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time"))
         chunks = {"x": 2, "y": 1}
+        if chunked_time:
+            chunks["time"] = 3

         expected_times = times.to_series().resample("1H").asfreq().index
         # Split the times into equal sub-intervals to simulate the 6 hour
@@ -3185,13 +3188,6 @@ def test_upsample_interpolate_dask(self):
         # done here due to floating point arithmetic
         assert_allclose(expected, actual, rtol=1e-16)

-        # Check that an error is raised if an attempt is made to interpolate
-        # over a chunked dimension
-        with raises_regex(
-            NotImplementedError, "Chunking along the dimension to be interpolated"
-        ):
-            array.chunk({"time": 1}).resample(time="1H").interpolate("linear")
-
     def test_align(self):
         array = DataArray(
             np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"]