From 62c6385c661e0fa91d97068ff24f0781a74a28b9 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 14:55:57 +0200 Subject: [PATCH 01/42] Implement interp for interpolating between chunks of data (dask) --- xarray/core/missing.py | 149 +++++++++++++++++++++++++++++++++++------ 1 file changed, 128 insertions(+), 21 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 59d4f777c73..f8e9955b846 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -544,13 +544,11 @@ def _get_valid_fill_mask(arr, dim, limit): ) <= limit -def _assert_single_chunk(var, axes): +def _single_chunk(var, axes): for axis in axes: if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: - raise NotImplementedError( - "Chunking along the dimension to be interpolated " - "({}) is not yet supported.".format(axis) - ) + return False + return True def _localize(var, indexes_coords): @@ -706,23 +704,64 @@ def interp_func(var, x, new_x, method, kwargs): if isinstance(var, dask_array_type): import dask.array as da - _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) - chunks = var.chunks[: -len(x)] + new_x[0].shape - drop_axis = range(var.ndim - len(x), var.ndim) - new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) - return da.map_blocks( - _interpnd, - var, - x, - new_x, - func, - kwargs, - dtype=var.dtype, - chunks=chunks, - new_axis=new_axis, - drop_axis=drop_axis, - ) + # easyer, and allows advanced interpolation + if _single_chunk(var, range(var.ndim - len(x), var.ndim)): + chunks = var.chunks[: -len(x)] + new_x[0].shape + drop_axis = range(var.ndim - len(x), var.ndim) + new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) + return da.map_blocks( + _interpnd, + var, + x, + new_x, + func, + kwargs, + dtype=var.dtype, + chunks=chunks, + new_axis=new_axis, + drop_axis=drop_axis, + ) + + current_dims = [_x.name for _x in x] + + # number of non interpolated dimensions + nconst = 
var.ndim - len(x) + # chunks x + x = tuple(da.from_array(_x, chunks=chunks) for _x, chunks in zip(x, var.chunks[nconst:])) + + # duplicate the ghost cells of the array in the interpolated dimensions + var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) + + # compute final chunks + target_dims = set.union(*[set(_x.dims) for _x in new_x]) + if target_dims - set(current_dims): + raise NotImplementedError( + "Advanced interpolation is not implemented with chunked dimension") + new_x = tuple([_x.set_dims(current_dims) for _x in new_x]) + total_chunks = _compute_chunks(x, x_with_ghost, new_x) + final_chunks = var.chunks[:-len(x)] + tuple(total_chunks) + + # chunks new_x + new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x) + + # reshape x_with_ghost + # TODO: remove it (see _dask_aware_interpnd) + x_with_ghost = da.meshgrid(*x_with_ghost, indexing='ij') + + # compute on chunks + res = da.map_blocks(_dask_aware_interpnd, + var_with_ghost, func, kwargs, len(x_with_ghost), + *x_with_ghost, *new_x, + dtype=var.dtype, chunks=final_chunks) + + # reshape res and remove empty chunks + # TODO: remove it by using drop_axis and new_axis in map_blocks + res = res.squeeze() + new_chunks = tuple([tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks]) + res = res.rechunk(new_chunks) + return res + return _interpnd(var, x, new_x, func, kwargs) @@ -751,3 +790,71 @@ def _interpnd(var, x, new_x, func, kwargs): # move back the interpolation axes to the last position rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) + + +def _dask_aware_interpnd(var, + func: Callable[..., Any], + kwargs: Any, + nx: int, + *arrs): + """Wrapper for `_interpnd` allowing dask array to be used in `map_blocks` + + The first `nx` arrays in `arrs` are orginal coordinates, the rest are destination coordinate + Currently this need original coordinate to be full arrays (meshgrid) + + TODO: find a way to use 1d 
coordinates + """ + from .dataarray import DataArray + _old_x, _new_x = arrs[:nx], arrs[nx:] + + # reshape x (TODO REMOVE) + old_x = tuple([np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] + for dim, tmp in enumerate(_old_x)]) + + new_x = tuple([DataArray(_x) for _x in _new_x]) + + return _interpnd(var, old_x, new_x, func, kwargs) + + +def _add_interp_ghost(var, x, nconst: int): + """ Duplicate the ghost cells of the array (values and coordinates)""" + import dask.array as da + bnd = {i: "none" for i in range(len(var.shape))} + depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))} + + var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd) + + x_with_ghost = tuple(da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) + for _x in x) + return var_with_ghost, x_with_ghost + + +def _compute_chunks(x, x_with_ghost, new_x): + """Compute equilibrated chunks of new_x + + TODO: This only works if new_x is a set of 1d coordinate + more general function is needed for advanced interpolation with chunked dimension + """ + chunks_end = [np.cumsum(sizes) - 1 for _x in x + for sizes in _x.chunks] + chunks_end_with_ghost = [np.cumsum(sizes) - 1 for _x in x_with_ghost + for sizes in _x.chunks] + total_chunks = [] + for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)): + l_new_x_ends: List[np.ndarray] = [] + for iend, iend_with_ghost in zip(*ce): + + arr = np.moveaxis(new_x[dim].data, dim, -1) + arr = arr[tuple([0] * (len(arr.shape) - 1))] + + n_no_ghost = (arr <= x[dim][iend]).sum() + n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum() + + equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int) + + l_new_x_ends.append(equil) + + new_x_ends = np.array(l_new_x_ends) + chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1]) + total_chunks.append(tuple(chunks)) + return total_chunks From f6f7dad535239e0ec299c9e8d6fba2d1aa51e285 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:30:12 +0200 Subject: 
[PATCH 02/42] do not forget extra points at the end --- xarray/core/missing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f8e9955b846..ef32235fbde 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -855,6 +855,8 @@ def _compute_chunks(x, x_with_ghost, new_x): l_new_x_ends.append(equil) new_x_ends = np.array(l_new_x_ends) + # do not forget extra points at the end + new_x_ends[-1] = len(arr) chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1]) total_chunks.append(tuple(chunks)) return total_chunks From b0d8a5fe76bd006d97404a7a4562ef7e80e94cca Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:30:40 +0200 Subject: [PATCH 03/42] add tests --- xarray/tests/test_interp.py | 53 +++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 7a0dda216e2..bca71db8b2a 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -64,11 +64,6 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) - if dim == "y" and case == 1: - with pytest.raises(NotImplementedError): - actual = da.interp(method=method, **{dim: xdest}) - pytest.skip("interpolation along chunked dimension is " "not yet supported") - actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference @@ -717,3 +712,51 @@ def test_decompose(method): actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y")) expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y")) assert_allclose(actual, expected) + + +def test_interpolate_chunk_1d(): + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(1) + ydest = np.linspace(-0.1, 0.2, 80) + + actual = da.interp(method="linear", 
y=ydest) + expected = da.compute().interp(method="linear", y=ydest) + + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("scalar_nx", [True, False]) +def test_interpolate_chunk_nd(scalar_nx): + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(1).chunk({"x": 50}) + + if scalar_nx: + # 0.5 is between chunks + xdest = 0.5 + dims=["y"] + else: + # -0.5 is before data + # 0.5 is between chunks + # 1.5 is after data + xdest = [-0.5, 0.25, 0.5, 0.75, 1.5] + dims=["x", "y"] + # -0.1 is before data + # 0.05 is between chunks + # 0.15 is after data + ydest = [-0.1, 0.025, 0.05, 0.075, 0.15] + + actual = da.interp(method="linear", x=xdest, y=ydest) + expected = da.compute().interp(method="linear", x=xdest, y=ydest) + + assert_allclose(actual, expected) + From 1a3145760f82a91c0928b5b40ed405cdffdf96ec Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:31:04 +0200 Subject: [PATCH 04/42] add whats-new comment --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bcff60ce4df..c988c0fda3f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,8 @@ Enhancements By `Keisuke Fujii `_. - :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep coordinate attributes (:pull:`4103`). By `Oriol Abril `_. +- :py:meth:`DataArray.interp` now support simple interpolation in a chunked dimension + (but not advanced interpolation) (:pull:`??`). By `Alexandre Poux `_. 
New Features ~~~~~~~~~~~~ From 9933c73a7a2083d36cd0efbe1a178441e4ef8958 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:32:45 +0200 Subject: [PATCH 05/42] fix isort / black --- xarray/core/missing.py | 66 +++++++++++++++++++++++-------------- xarray/tests/test_interp.py | 15 ++++----- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index ef32235fbde..c5c15ece182 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -728,7 +728,10 @@ def interp_func(var, x, new_x, method, kwargs): nconst = var.ndim - len(x) # chunks x - x = tuple(da.from_array(_x, chunks=chunks) for _x, chunks in zip(x, var.chunks[nconst:])) + x = tuple( + da.from_array(_x, chunks=chunks) + for _x, chunks in zip(x, var.chunks[nconst:]) + ) # duplicate the ghost cells of the array in the interpolated dimensions var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) @@ -737,31 +740,41 @@ def interp_func(var, x, new_x, method, kwargs): target_dims = set.union(*[set(_x.dims) for _x in new_x]) if target_dims - set(current_dims): raise NotImplementedError( - "Advanced interpolation is not implemented with chunked dimension") + "Advanced interpolation is not implemented with chunked dimension" + ) new_x = tuple([_x.set_dims(current_dims) for _x in new_x]) total_chunks = _compute_chunks(x, x_with_ghost, new_x) - final_chunks = var.chunks[:-len(x)] + tuple(total_chunks) + final_chunks = var.chunks[: -len(x)] + tuple(total_chunks) # chunks new_x new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x) # reshape x_with_ghost # TODO: remove it (see _dask_aware_interpnd) - x_with_ghost = da.meshgrid(*x_with_ghost, indexing='ij') + x_with_ghost = da.meshgrid(*x_with_ghost, indexing="ij") # compute on chunks - res = da.map_blocks(_dask_aware_interpnd, - var_with_ghost, func, kwargs, len(x_with_ghost), - *x_with_ghost, *new_x, - dtype=var.dtype, chunks=final_chunks) + res = da.map_blocks( + 
_dask_aware_interpnd, + var_with_ghost, + func, + kwargs, + len(x_with_ghost), + *x_with_ghost, + *new_x, + dtype=var.dtype, + chunks=final_chunks, + ) # reshape res and remove empty chunks # TODO: remove it by using drop_axis and new_axis in map_blocks res = res.squeeze() - new_chunks = tuple([tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks]) + new_chunks = tuple( + [tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks] + ) res = res.rechunk(new_chunks) return res - + return _interpnd(var, x, new_x, func, kwargs) @@ -790,13 +803,9 @@ def _interpnd(var, x, new_x, func, kwargs): # move back the interpolation axes to the last position rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) - - -def _dask_aware_interpnd(var, - func: Callable[..., Any], - kwargs: Any, - nx: int, - *arrs): + + +def _dask_aware_interpnd(var, func: Callable[..., Any], kwargs: Any, nx: int, *arrs): """Wrapper for `_interpnd` allowing dask array to be used in `map_blocks` The first `nx` arrays in `arrs` are orginal coordinates, the rest are destination coordinate @@ -805,11 +814,16 @@ def _dask_aware_interpnd(var, TODO: find a way to use 1d coordinates """ from .dataarray import DataArray + _old_x, _new_x = arrs[:nx], arrs[nx:] # reshape x (TODO REMOVE) - old_x = tuple([np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] - for dim, tmp in enumerate(_old_x)]) + old_x = tuple( + [ + np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] + for dim, tmp in enumerate(_old_x) + ] + ) new_x = tuple([DataArray(_x) for _x in _new_x]) @@ -819,13 +833,15 @@ def _dask_aware_interpnd(var, def _add_interp_ghost(var, x, nconst: int): """ Duplicate the ghost cells of the array (values and coordinates)""" import dask.array as da + bnd = {i: "none" for i in range(len(var.shape))} depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))} var_with_ghost = da.overlap.overlap(var, depth=depth, 
boundary=bnd) - x_with_ghost = tuple(da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) - for _x in x) + x_with_ghost = tuple( + da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) for _x in x + ) return var_with_ghost, x_with_ghost @@ -835,10 +851,10 @@ def _compute_chunks(x, x_with_ghost, new_x): TODO: This only works if new_x is a set of 1d coordinate more general function is needed for advanced interpolation with chunked dimension """ - chunks_end = [np.cumsum(sizes) - 1 for _x in x - for sizes in _x.chunks] - chunks_end_with_ghost = [np.cumsum(sizes) - 1 for _x in x_with_ghost - for sizes in _x.chunks] + chunks_end = [np.cumsum(sizes) - 1 for _x in x for sizes in _x.chunks] + chunks_end_with_ghost = [ + np.cumsum(sizes) - 1 for _x in x_with_ghost for sizes in _x.chunks + ] total_chunks = [] for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)): l_new_x_ends: List[np.ndarray] = [] diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index bca71db8b2a..787c46c8279 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -712,8 +712,8 @@ def test_decompose(method): actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y")) expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y")) assert_allclose(actual, expected) - - + + def test_interpolate_chunk_1d(): if not has_scipy: pytest.skip("scipy is not installed.") @@ -726,7 +726,7 @@ def test_interpolate_chunk_1d(): actual = da.interp(method="linear", y=ydest) expected = da.compute().interp(method="linear", y=ydest) - + assert_allclose(actual, expected) @@ -739,17 +739,17 @@ def test_interpolate_chunk_nd(scalar_nx): pytest.skip("dask is not installed in the environment.") da = get_example_data(1).chunk({"x": 50}) - + if scalar_nx: # 0.5 is between chunks xdest = 0.5 - dims=["y"] + dims = ["y"] else: # -0.5 is before data # 0.5 is between chunks # 1.5 is after data xdest = [-0.5, 0.25, 0.5, 0.75, 1.5] - dims=["x", "y"] + 
dims = ["x", "y"] # -0.1 is before data # 0.05 is between chunks # 0.15 is after data @@ -757,6 +757,5 @@ def test_interpolate_chunk_nd(scalar_nx): actual = da.interp(method="linear", x=xdest, y=ydest) expected = da.compute().interp(method="linear", x=xdest, y=ydest) - + assert_allclose(actual, expected) - From cea826b53e6e68e22bddf28975601d7efb140045 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:41:40 +0200 Subject: [PATCH 06/42] typo --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index c5c15ece182..75f5123b5ca 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -808,7 +808,7 @@ def _interpnd(var, x, new_x, func, kwargs): def _dask_aware_interpnd(var, func: Callable[..., Any], kwargs: Any, nx: int, *arrs): """Wrapper for `_interpnd` allowing dask array to be used in `map_blocks` - The first `nx` arrays in `arrs` are orginal coordinates, the rest are destination coordinate + The first `nx` arrays in `arrs` are original coordinates, the rest are destination coordinate Currently this need original coordinate to be full arrays (meshgrid) TODO: find a way to use 1d coordinates From 44bbedf1f1fdfe9788bcbe0fb9d495c85e1cf3d9 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:43:27 +0200 Subject: [PATCH 07/42] update pull number --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c988c0fda3f..3938369b061 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -43,7 +43,7 @@ Enhancements - :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep coordinate attributes (:pull:`4103`). By `Oriol Abril `_. - :py:meth:`DataArray.interp` now support simple interpolation in a chunked dimension - (but not advanced interpolation) (:pull:`??`). By `Alexandre Poux `_. + (but not advanced interpolation) (:pull:`4155`). 
By `Alexandre Poux `_. New Features ~~~~~~~~~~~~ From 067b7f35e82329248a4419e1995da67818b6d433 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 16:45:46 +0200 Subject: [PATCH 08/42] fix github pep8 warnigns --- xarray/core/missing.py | 2 +- xarray/tests/test_interp.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 75f5123b5ca..7d2e1056ae6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,7 +2,7 @@ import warnings from functools import partial from numbers import Number -from typing import Any, Callable, Dict, Hashable, Sequence, Union +from typing import Any, Callable, Dict, Hashable, Sequence, Union, List import numpy as np import pandas as pd diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 787c46c8279..573d28c7ab3 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -743,13 +743,11 @@ def test_interpolate_chunk_nd(scalar_nx): if scalar_nx: # 0.5 is between chunks xdest = 0.5 - dims = ["y"] else: # -0.5 is before data # 0.5 is between chunks # 1.5 is after data xdest = [-0.5, 0.25, 0.5, 0.75, 1.5] - dims = ["x", "y"] # -0.1 is before data # 0.05 is between chunks # 0.15 is after data From c47a1d5d8fd7ca401a0dddea67574af00c4d8e3b Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 15 Jun 2020 17:10:51 +0200 Subject: [PATCH 09/42] fix isort --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 7d2e1056ae6..d0a01f2681c 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,7 +2,7 @@ import warnings from functools import partial from numbers import Number -from typing import Any, Callable, Dict, Hashable, Sequence, Union, List +from typing import Any, Callable, Dict, Hashable, List, Sequence, Union import numpy as np import pandas as pd From 7d505a1a3574db86ee23d7a80bd6b4f8802a8758 Mon Sep 
17 00:00:00 2001 From: Alexandre Poux Date: Fri, 17 Jul 2020 17:54:38 +0200 Subject: [PATCH 10/42] clearer arguments in _dask_aware_interpnd --- xarray/core/missing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index d0a01f2681c..52e005c3a83 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -757,11 +757,11 @@ def interp_func(var, x, new_x, method, kwargs): res = da.map_blocks( _dask_aware_interpnd, var_with_ghost, - func, - kwargs, - len(x_with_ghost), *x_with_ghost, *new_x, + interp_func=func, + interp_kwargs=kwargs, + n_coord=len(x_with_ghost), dtype=var.dtype, chunks=final_chunks, ) @@ -805,17 +805,17 @@ def _interpnd(var, x, new_x, func, kwargs): return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) -def _dask_aware_interpnd(var, func: Callable[..., Any], kwargs: Any, nx: int, *arrs): +def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs): """Wrapper for `_interpnd` allowing dask array to be used in `map_blocks` - The first `nx` arrays in `arrs` are original coordinates, the rest are destination coordinate + The first `n_coords` arrays in `coords` are original coordinates, the rest are destination coordinate Currently this need original coordinate to be full arrays (meshgrid) TODO: find a way to use 1d coordinates """ from .dataarray import DataArray - _old_x, _new_x = arrs[:nx], arrs[nx:] + _old_x, _new_x = coords[:n_coords], coords[n_coords:] # reshape x (TODO REMOVE) old_x = tuple( @@ -827,7 +827,7 @@ def _dask_aware_interpnd(var, func: Callable[..., Any], kwargs: Any, nx: int, *a new_x = tuple([DataArray(_x) for _x in _new_x]) - return _interpnd(var, old_x, new_x, func, kwargs) + return _interpnd(var, old_x, new_x, interp_func, interp_kwargs) def _add_interp_ghost(var, x, nconst: int): From 423b36d25d82ca28d3c0ee693e564bd092dfcb0f Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 17 Jul 2020 17:55:17 +0200 Subject: 
[PATCH 11/42] typo --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 52e005c3a83..920312ce0b6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -761,7 +761,7 @@ def interp_func(var, x, new_x, method, kwargs): *new_x, interp_func=func, interp_kwargs=kwargs, - n_coord=len(x_with_ghost), + n_coords=len(x_with_ghost), dtype=var.dtype, chunks=final_chunks, ) From 85ff5394f550e580d3e1f5a3cab19cbee885a10f Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 20 Jul 2020 14:46:21 +0200 Subject: [PATCH 12/42] fix for datetimelike index --- xarray/core/missing.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 920312ce0b6..6716d1be2e9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -13,7 +13,7 @@ from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric from .options import _get_keep_attrs from .utils import OrderedSet, is_scalar -from .variable import Variable, broadcast_variables +from .variable import IndexVariable, Variable, broadcast_variables def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): @@ -820,12 +820,21 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs # reshape x (TODO REMOVE) old_x = tuple( [ - np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] + IndexVariable( + str(dim), np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] + ) for dim, tmp in enumerate(_old_x) ] ) - new_x = tuple([DataArray(_x) for _x in _new_x]) + new_x = tuple( + [ + Variable( + [f"{outer_dim}{inner_dim}" for inner_dim in range(len(_x.shape))], _x + ) + for outer_dim, _x in enumerate(_new_x) + ] + ) return _interpnd(var, old_x, new_x, interp_func, interp_kwargs) From 6e9b50e04e514896ef5a881fcbf7a1ef2418e505 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 20 Jul 
2020 14:47:34 +0200 Subject: [PATCH 13/42] chunked interpolation does not work for high order interpolation (quadratic or cubic) --- xarray/core/missing.py | 5 +++++ xarray/tests/test_dataarray.py | 36 +++++++++++++++++++++------------- xarray/tests/test_interp.py | 18 ++++++++++++++++- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 6716d1be2e9..6f2a2d3bb76 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -722,6 +722,11 @@ def interp_func(var, x, new_x, method, kwargs): drop_axis=drop_axis, ) + if method in ["quadratic", "cubic"]: + raise NotImplementedError( + "Only constant or linear interpolation are available in a chunked direction" + ) + current_dims = [_x.name for _x in x] # number of non interpolated dimensions diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 95f0ad9f612..25a727be3b2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3147,7 +3147,8 @@ def test_upsample_interpolate_regression_1605(self): @requires_dask @requires_scipy - def test_upsample_interpolate_dask(self): + @pytest.mark.parametrize("chunked_time", [True, False]) + def test_upsample_interpolate_dask(self, chunked_time): from scipy.interpolate import interp1d xs = np.arange(6) @@ -3158,13 +3159,27 @@ def test_upsample_interpolate_dask(self): data = np.tile(z, (6, 3, 1)) array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) chunks = {"x": 2, "y": 1} + if chunked_time: + chunks["time"] = 3 expected_times = times.to_series().resample("1H").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.chunk(chunks).resample(time="1H").interpolate(kind) + actual = array.chunk(chunks).resample(time="1H") + + if 
chunked_time and (kind in ["quadratic", "cubic"]): + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + NotImplementedError, + "Only constant or linear interpolation are available in a chunked direction", + ): + actual.interpolate(kind) + continue + + actual = actual.interpolate(kind) actual = actual.compute() f = interp1d( np.arange(len(times)), @@ -3185,13 +3200,6 @@ def test_upsample_interpolate_dask(self): # done here due to floating point arithmetic assert_allclose(expected, actual, rtol=1e-16) - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension - with raises_regex( - NotImplementedError, "Chunking along the dimension to be interpolated" - ): - array.chunk({"time": 1}).resample(time="1H").interpolate("linear") - def test_align(self): array = DataArray( np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] @@ -5575,8 +5583,8 @@ def test_name_in_masking(): class TestIrisConversion: @requires_iris def test_to_and_from_iris(self): - import iris import cf_units # iris requirement + import iris # to iris coord_dict = {} @@ -5646,9 +5654,9 @@ def test_to_and_from_iris(self): @requires_iris @requires_dask def test_to_and_from_iris_dask(self): + import cf_units # iris requirement import dask.array as da import iris - import cf_units # iris requirement coord_dict = {} coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) @@ -5781,8 +5789,8 @@ def test_da_name_from_cube(self, std_name, long_name, var_name, name, attrs): ], ) def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attrs): - from iris.cube import Cube from iris.coords import DimCoord + from iris.cube import Cube latitude = DimCoord( [-90, 0, 90], standard_name=std_name, var_name=var_name, long_name=long_name @@ -5795,8 +5803,8 @@ def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attr 
@requires_iris def test_prevent_duplicate_coord_names(self): - from iris.cube import Cube from iris.coords import DimCoord + from iris.cube import Cube # Iris enforces unique coordinate names. Because we use a different # name resolution order a valid iris Cube with coords that have the @@ -5817,8 +5825,8 @@ def test_prevent_duplicate_coord_names(self): [["IA", "IL", "IN"], [0, 2, 1]], # non-numeric values # non-monotonic values ) def test_fallback_to_iris_AuxCoord(self, coord_values): - from iris.cube import Cube from iris.coords import AuxCoord + from iris.cube import Cube data = [0, 0, 0] da = xr.DataArray(data, coords=[coord_values], dims=["space"]) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 573d28c7ab3..b86b536a4d5 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -3,7 +3,13 @@ import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_equal, requires_cftime, requires_scipy +from xarray.tests import ( + assert_allclose, + assert_equal, + raises_regex, + requires_cftime, + requires_scipy, +) from ..coding.cftimeindex import _parse_array_of_cftime_strings from . 
import has_dask, has_scipy @@ -64,6 +70,16 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) + if method == "cubic" and dim == "y" and case == 1: + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + NotImplementedError, + "Only constant or linear interpolation are available in a chunked direction", + ): + da.interp(method=method, **{dim: xdest}) + return + actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference From 86cb592505512c786236ba494c450583d8574e10 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 20 Jul 2020 15:02:10 +0200 Subject: [PATCH 14/42] fix whats new --- doc/whats-new.rst | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 98f99f939da..e11981481cd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -77,17 +77,6 @@ Breaking changes `_. (:pull:`3274`) By `Elliott Sales de Andrade `_ - -Enhancements -~~~~~~~~~~~~ -- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` - For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially - rather than interpolating in multidimensional space. (:issue:`2223`) - By `Keisuke Fujii `_. -- :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep - coordinate attributes (:pull:`4103`). By `Oriol Abril `_. -- :py:meth:`DataArray.interp` now support simple interpolation in a chunked dimension - (but not advanced interpolation) (:pull:`4155`). By `Alexandre Poux `_. - The old :py:func:`auto_combine` function has now been removed in favour of the :py:func:`combine_by_coords` and :py:func:`combine_nested` functions. 
This also means that @@ -168,6 +157,8 @@ Enhancements For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially rather than interpolating in multidimensional space. (:issue:`2223`) By `Keisuke Fujii `_. +- :py:meth:`DataArray.interp` now support some interpolations over a chunked dimension + (low order and not advanced interpolation) (:pull:`4155`). By `Alexandre Poux `_. - Major performance improvement for :py:meth:`Dataset.from_dataframe` when the dataframe has a MultiIndex (:pull:`4184`). By `Stephan Hoyer `_. From 5e26a4e034c93e0ad9353a719e0542d51dbd44ea Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 20 Jul 2020 15:04:52 +0200 Subject: [PATCH 15/42] remove a useless import --- xarray/core/missing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 6f2a2d3bb76..cab1bd52d69 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -818,8 +818,6 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs TODO: find a way to use 1d coordinates """ - from .dataarray import DataArray - _old_x, _new_x = coords[:n_coords], coords[n_coords:] # reshape x (TODO REMOVE) From 3ca6e6d6e2cbfbe13ac2020b72647bb349208e83 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 21 Jul 2020 14:05:19 +0200 Subject: [PATCH 16/42] use Variable instead of InexVariable --- xarray/core/missing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index cab1bd52d69..8f501d8750a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -13,7 +13,7 @@ from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric from .options import _get_keep_attrs from .utils import OrderedSet, is_scalar -from .variable import IndexVariable, Variable, broadcast_variables +from .variable import Variable, broadcast_variables def _get_nan_block_lengths(obj, dim: Hashable, 
index: Variable): @@ -823,7 +823,7 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs # reshape x (TODO REMOVE) old_x = tuple( [ - IndexVariable( + Variable( str(dim), np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] ) for dim, tmp in enumerate(_old_x) From a131b21ee2800045a2991172459a25c3b093f1c5 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 21 Jul 2020 14:11:03 +0200 Subject: [PATCH 17/42] avoid some list to tuple conversion --- xarray/core/missing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 8f501d8750a..c1437c2da51 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -824,7 +824,7 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs old_x = tuple( [ Variable( - str(dim), np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))] + str(dim), np.moveaxis(tmp, dim, -1)[(0, ) * (len(tmp.shape) - 1)] ) for dim, tmp in enumerate(_old_x) ] @@ -873,7 +873,7 @@ def _compute_chunks(x, x_with_ghost, new_x): for iend, iend_with_ghost in zip(*ce): arr = np.moveaxis(new_x[dim].data, dim, -1) - arr = arr[tuple([0] * (len(arr.shape) - 1))] + arr = arr[(0, ) * (len(arr.shape) - 1)] n_no_ghost = (arr <= x[dim][iend]).sum() n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum() From 67d2b360173e08c084ac84b063de87c6e5cd0b60 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 21 Jul 2020 14:12:31 +0200 Subject: [PATCH 18/42] black fix --- xarray/core/missing.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index c1437c2da51..338e563ffd9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -823,9 +823,7 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs # reshape x (TODO REMOVE) old_x = tuple( [ - Variable( - str(dim), np.moveaxis(tmp, dim, -1)[(0, ) * (len(tmp.shape) - 
1)] - ) + Variable(str(dim), np.moveaxis(tmp, dim, -1)[(0,) * (len(tmp.shape) - 1)]) for dim, tmp in enumerate(_old_x) ] ) @@ -873,7 +871,7 @@ def _compute_chunks(x, x_with_ghost, new_x): for iend, iend_with_ghost in zip(*ce): arr = np.moveaxis(new_x[dim].data, dim, -1) - arr = arr[(0, ) * (len(arr.shape) - 1)] + arr = arr[(0,) * (len(arr.shape) - 1)] n_no_ghost = (arr <= x[dim][iend]).sum() n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum() From f48595897d29faede9946a23c50b5d88d934f3a5 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 21 Jul 2020 16:09:40 +0200 Subject: [PATCH 19/42] more comments to explain _compute_chunks --- xarray/core/missing.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 338e563ffd9..01f01cb50ec 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,7 +2,7 @@ import warnings from functools import partial from numbers import Number -from typing import Any, Callable, Dict, Hashable, List, Sequence, Union +from typing import Any, Callable, Dict, Hashable, Sequence, Union import numpy as np import pandas as pd @@ -858,6 +858,8 @@ def _add_interp_ghost(var, x, nconst: int): def _compute_chunks(x, x_with_ghost, new_x): """Compute equilibrated chunks of new_x + This routine assumes that x, x_with_ghost and new_x are sorted + TODO: This only works if new_x is a set of 1d coordinate more general function is needed for advanced interpolation with chunked dimension """ @@ -867,22 +869,29 @@ def _compute_chunks(x, x_with_ghost, new_x): ] total_chunks = [] for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)): - l_new_x_ends: List[np.ndarray] = [] - for iend, iend_with_ghost in zip(*ce): - arr = np.moveaxis(new_x[dim].data, dim, -1) - arr = arr[(0,) * (len(arr.shape) - 1)] + # select one line along dim + line_x = new_x[dim].data[ + (0,) * dim + (slice(None),) + (0,) * (len(new_x) - dim - 1) + ] - 
n_no_ghost = (arr <= x[dim][iend]).sum() - n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum() + # the number of chunk of the output must be the same as the input (map_blocks) + new_x_ends = np.copy(ce[0]) + for i, (iend, iend_with_ghost) in enumerate(list(zip(*ce))[:-1]): + # number of points in line_x before the end of the current chunck + # with and without overlap + n_ghost = (line_x <= x_with_ghost[dim][iend_with_ghost]).sum() + n_no_ghost = (line_x <= x[dim][iend]).sum() - equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int) + # put half of the points inside the overlap on the left + # and the other half on the right + n_plus_half = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int) - l_new_x_ends.append(equil) + new_x_ends[i] = n_plus_half - new_x_ends = np.array(l_new_x_ends) # do not forget extra points at the end - new_x_ends[-1] = len(arr) + new_x_ends[-1] = len(line_x) + chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1]) - total_chunks.append(tuple(chunks)) + total_chunks.append(chunks) return total_chunks From 42f8a3b41b1563a4baa6ad29960442c1fbd8d912 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 09:16:29 +0200 Subject: [PATCH 20/42] For orthogonal linear- and nearest-neighbor interpolation, the scalar interpolation can also be done sequentially --- xarray/core/missing.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 01f01cb50ec..0dfc5ebf6c1 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -626,8 +626,15 @@ def interp(var, indexes_coords, method, **kwargs): if ( len(indexes_coords) > 1 and method in ["linear", "nearest"] - and all(dest[1].ndim == 1 for dest in indexes_coords.values()) - and len(set([d[1].dims[0] for d in indexes_coords.values()])) + and all(dest[1].ndim <= 1 for dest in indexes_coords.values()) + and len( + set( + [ + dest[1].dims[0] if dest[1].ndim == 1 else d + for d, dest in 
indexes_coords.items() + ] + ) + ) == len(indexes_coords) ): # interpolate sequentially From ec3c400745014819ce796e3990d43b965444b9f8 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 16:25:54 +0200 Subject: [PATCH 21/42] better detection of Advanced interpolation --- xarray/core/missing.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 0dfc5ebf6c1..3cd7618359b 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -734,6 +734,12 @@ def interp_func(var, x, new_x, method, kwargs): "Only constant or linear interpolation are available in a chunked direction" ) + for _x in new_x: + if sum([s > 1 for s in _x.shape]) > 1: + raise NotImplementedError( + "Advanced interpolation is not implemented with chunked dimension" + ) + current_dims = [_x.name for _x in x] # number of non interpolated dimensions @@ -750,10 +756,6 @@ def interp_func(var, x, new_x, method, kwargs): # compute final chunks target_dims = set.union(*[set(_x.dims) for _x in new_x]) - if target_dims - set(current_dims): - raise NotImplementedError( - "Advanced interpolation is not implemented with chunked dimension" - ) new_x = tuple([_x.set_dims(current_dims) for _x in new_x]) total_chunks = _compute_chunks(x, x_with_ghost, new_x) final_chunks = var.chunks[: -len(x)] + tuple(total_chunks) From e231954f2e994f3c403ce09a274914ddea73ec50 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 16:27:20 +0200 Subject: [PATCH 22/42] implement support of unsorted interpolation destination --- xarray/core/missing.py | 60 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 3cd7618359b..b44c15533c7 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -13,7 +13,7 @@ from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric from .options import 
_get_keep_attrs from .utils import OrderedSet, is_scalar -from .variable import Variable, broadcast_variables +from .variable import IndexVariable, Variable, broadcast_variables def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): @@ -740,10 +740,49 @@ def interp_func(var, x, new_x, method, kwargs): "Advanced interpolation is not implemented with chunked dimension" ) - current_dims = [_x.name for _x in x] - # number of non interpolated dimensions nconst = var.ndim - len(x) + const_dims = [f"dim_{i}" for i in range(nconst)] + + # list of interpolated dimensions source + interp_dims = [_x.name for _x in x] + + # mapping to the interpolated dimensions destination + # (if interpolating on a point, removes de dim) + final_dims = dict( + **{d: d for d in const_dims}, + **{dim: _x.name for dim, _x in zip(interp_dims, new_x) if _x.size > 1}, + ) + + # rename new_x to correspond to source dimension + def rename_index(index, dim): + if index.size == 1: + # a scalar has no name + return index + return IndexVariable( + dims=[dim], + data=index, + attrs=index.attrs, + encoding=index.encoding, + fastpath=True, + ) + + new_x = [rename_index(_x, dim) for dim, _x in zip(interp_dims, new_x)] + + unsorted = any((np.any(np.diff(_x.data) < 0) for _x in new_x if _x.size > 1)) + if unsorted: + sorted_idx = { + dim: _x.data.argsort() + for dim, _x in zip(interp_dims, new_x) + if _x.size > 1 + } + + new_x = [ + _x[sorted_idx[dim]] if dim in sorted_idx else _x + for dim, _x in zip(interp_dims, new_x) + ] + # add missing dimensions to new_x + new_x = tuple([_x.set_dims(interp_dims) for _x in new_x]) # chunks x x = tuple( @@ -755,8 +794,6 @@ def interp_func(var, x, new_x, method, kwargs): var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) # compute final chunks - target_dims = set.union(*[set(_x.dims) for _x in new_x]) - new_x = tuple([_x.set_dims(current_dims) for _x in new_x]) total_chunks = _compute_chunks(x, x_with_ghost, new_x) final_chunks = var.chunks[: 
-len(x)] + tuple(total_chunks) @@ -787,6 +824,19 @@ def interp_func(var, x, new_x, method, kwargs): [tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks] ) res = res.rechunk(new_chunks) + + if unsorted: + # Reorder the output + # use DataArray for isel + from .dataarray import DataArray + + res = DataArray(data=res, dims=final_dims.values()) + for dim, idx in sorted_idx.items(): + res = res.isel({final_dims[dim]: np.argsort(idx)}) + + # rechunk because out-of-order isel generates a lot of chunks + res = res.data.rechunk() + return res return _interpnd(var, x, new_x, func, kwargs) From 061f5a8fcb000e102306f6a84d5ea44309adfdd9 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 16:50:56 +0200 Subject: [PATCH 23/42] rework the tests --- xarray/tests/test_interp.py | 137 +++++++++++++++++++++++++++--------- 1 file changed, 102 insertions(+), 35 deletions(-) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index b86b536a4d5..b50bc3fe00f 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -1,3 +1,5 @@ +from itertools import combinations, permutations + import numpy as np import pandas as pd import pytest @@ -8,6 +10,7 @@ assert_equal, raises_regex, requires_cftime, + requires_dask, requires_scipy, ) @@ -730,46 +733,110 @@ def test_decompose(method): assert_allclose(actual, expected) -def test_interpolate_chunk_1d(): - if not has_scipy: - pytest.skip("scipy is not installed.") - - if not has_dask: - pytest.skip("dask is not installed in the environment.") - - da = get_example_data(1) - ydest = np.linspace(-0.1, 0.2, 80) - - actual = da.interp(method="linear", y=ydest) - expected = da.compute().interp(method="linear", y=ydest) +@requires_scipy +@requires_dask +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) +@pytest.mark.parametrize("sorted", [True, False]) +@pytest.mark.parametrize( + "data_ndim,interp_ndim,nscalar", + [ + 
(data_ndim, interp_ndim, nscalar) + for data_ndim in range(1, 4) + for interp_ndim in range(1, data_ndim + 1) + for nscalar in range(0, interp_ndim + 1) + ], +) +def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): + # 3d non chunked data + x = np.linspace(0, 1, 5) + y = np.linspace(2, 4, 7) + z = np.linspace(-0.5, 0.5, 11) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis]) + * np.exp(z), + coords=[("x", x), ("y", y), ("z", z)], + ) - assert_allclose(actual, expected) + # choose the data dimensions + for data_dims in permutations(da.dims, data_ndim): + # select only data_ndim dim + da = da.isel( + {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims} + ) -@pytest.mark.parametrize("scalar_nx", [True, False]) -def test_interpolate_chunk_nd(scalar_nx): - if not has_scipy: - pytest.skip("scipy is not installed.") + # chunk data + da = da.chunk(chunks={dim: len(da.coords[dim]) // 3 for dim in da.dims}) + + # choose the interpolation dimensions + for interp_dims in permutations(da.dims, interp_ndim): + # choose the scalar interpolation dimensions + for scalar_dims in combinations(interp_dims, nscalar): + dest = {} + for dim in interp_dims: + # choose a point between chunks + first_chunks = da.chunks[da.get_axis_num(dim)][0] + middle_point = 0.5 * ( + da.coords[dim][first_chunks - 1] + da.coords[dim][first_chunks] + ) + if dim in scalar_dims: + # choose a point between chunks + dest[dim] = middle_point + else: + # pick some points, including outside the domain and bewteen chunks + before = 2 * da.coords[dim][0] - da.coords[dim][1] + after = 2 * da.coords[dim][-1] - da.coords[dim][-2] + inside = da.coords[dim][first_chunks // 2] + + xdest = np.linspace( + inside, middle_point, 2 * (first_chunks // 2), + ) + xdest = np.concatenate([[before], xdest, [after]]) + if not sorted: + xdest = xdest.reshape((-1, 2)) + xdest[:, 1] = xdest[::-1, 1] + xdest = xdest.flatten() + dest[dim] = xdest + + 
if interp_ndim > 1 and method not in ["linear", "nearest"]: + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + ValueError, + f"{method} is not a valid interpolator for interpolating over multiple dimensions.", + ): + da.interp(method=method, **dest) + return + + if method in ["quadratic", "cubic"]: + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + NotImplementedError, + "Only constant or linear interpolation are available in a chunked direction", + ): + da.interp(method=method, **dest) + return + + actual = da.interp(method=method, **dest) + expected = da.compute().interp(method=method, **dest) + + assert_allclose(actual, expected) + break + break - if not has_dask: - pytest.skip("dask is not installed in the environment.") - da = get_example_data(1).chunk({"x": 50}) +@requires_scipy +@requires_dask +def test_interpolate_chunk_rename(): + da = get_example_data(0).chunk({"x": 5}) - if scalar_nx: - # 0.5 is between chunks - xdest = 0.5 - else: - # -0.5 is before data - # 0.5 is between chunks - # 1.5 is after data - xdest = [-0.5, 0.25, 0.5, 0.75, 1.5] - # -0.1 is before data - # 0.05 is between chunks - # 0.15 is after data - ydest = [-0.1, 0.025, 0.05, 0.075, 0.15] - - actual = da.interp(method="linear", x=xdest, y=ydest) - expected = da.compute().interp(method="linear", x=xdest, y=ydest) + # grid -> 1d-sample + xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims="renamed") + actual = da.interp(x=xdest, method="linear") + expected = da.compute().interp(x=xdest, method="linear") assert_allclose(actual, expected) From 623cb0b00aca3d478cb885acc9b1316379ebd7a1 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 17:37:40 +0200 Subject: [PATCH 24/42] fix for datetime index (bug introduced with unsorted destination) --- xarray/core/missing.py | 17 +++++++++++++++-- 1 
file changed, 15 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index b44c15533c7..248c679d5cd 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -566,6 +566,16 @@ def _localize(var, indexes_coords): return var.isel(**indexes), indexes_coords +def _floatize_one_x(x): + """ Make x float. + This is particulary useful for datetime dtype. + x: np.ndarray + """ + if _contains_datetime_like_objects(x): + return x._to_numeric(dtype=np.float64) + return x + + def _floatize_x(x, new_x): """ Make x and new_x float. This is particulary useful for datetime dtype. @@ -769,11 +779,14 @@ def rename_index(index, dim): new_x = [rename_index(_x, dim) for dim, _x in zip(interp_dims, new_x)] - unsorted = any((np.any(np.diff(_x.data) < 0) for _x in new_x if _x.size > 1)) + new_x_float = [_floatize_one_x(_x) for _x in new_x] + unsorted = any( + (np.any(np.diff(_x.data) < 0) for _x in new_x_float if _x.size > 1) + ) if unsorted: sorted_idx = { dim: _x.data.argsort() - for dim, _x in zip(interp_dims, new_x) + for dim, _x in zip(interp_dims, new_x_float) if _x.size > 1 } From b66d12353a9247195d9b8ae095e54aef7bcc35b2 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 24 Jul 2020 17:54:42 +0200 Subject: [PATCH 25/42] Variable is cheaber that DataArray --- xarray/core/missing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 248c679d5cd..108c3cc95b3 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -840,10 +840,9 @@ def rename_index(index, dim): if unsorted: # Reorder the output - # use DataArray for isel - from .dataarray import DataArray + # use Variable for isel - res = DataArray(data=res, dims=final_dims.values()) + res = Variable(data=res, dims=final_dims.values()) for dim, idx in sorted_idx.items(): res = res.isel({final_dims[dim]: np.argsort(idx)}) From e211127bde26ea83eb7d4e2f5e7b073a38cebf2a Mon Sep 17 00:00:00 2001 From: 
Alexandre Poux Date: Mon, 27 Jul 2020 11:38:16 +0200 Subject: [PATCH 26/42] add warning if unsorted --- xarray/core/missing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 108c3cc95b3..5f275a312bf 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -784,6 +784,10 @@ def rename_index(index, dim): (np.any(np.diff(_x.data) < 0) for _x in new_x_float if _x.size > 1) ) if unsorted: + warnings.warn( + "Interpolating to unsorted destination will rechunk the result", + da.PerformanceWarning) + sorted_idx = { dim: _x.data.argsort() for dim, _x in zip(interp_dims, new_x_float) @@ -843,8 +847,10 @@ def rename_index(index, dim): # use Variable for isel res = Variable(data=res, dims=final_dims.values()) - for dim, idx in sorted_idx.items(): - res = res.isel({final_dims[dim]: np.argsort(idx)}) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "out-of-order", da.PerformanceWarning) + for dim, idx in sorted_idx.items(): + res = res.isel({final_dims[dim]: np.argsort(idx)}) # rechunk because out-of-order isel generates a lot of chunks res = res.data.rechunk() From e610268eb7bf935f56541a9e2a53e2fc0eb4bf70 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 27 Jul 2020 11:39:21 +0200 Subject: [PATCH 27/42] simplify _compute_chunks --- xarray/core/missing.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 5f275a312bf..10de05b32d1 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -811,7 +811,7 @@ def rename_index(index, dim): var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) # compute final chunks - total_chunks = _compute_chunks(x, x_with_ghost, new_x) + total_chunks = _compute_chunks(x, new_x) final_chunks = var.chunks[: -len(x)] + tuple(total_chunks) # chunks new_x @@ -932,7 +932,7 @@ def _add_interp_ghost(var, x, nconst: 
int): return var_with_ghost, x_with_ghost -def _compute_chunks(x, x_with_ghost, new_x): +def _compute_chunks(x, new_x): """Compute equilibrated chunks of new_x This routine assumes that x, x_with_ghost and new_x are sorted @@ -940,12 +940,9 @@ def _compute_chunks(x, x_with_ghost, new_x): TODO: This only works if new_x is a set of 1d coordinate more general function is needed for advanced interpolation with chunked dimension """ - chunks_end = [np.cumsum(sizes) - 1 for _x in x for sizes in _x.chunks] - chunks_end_with_ghost = [ - np.cumsum(sizes) - 1 for _x in x_with_ghost for sizes in _x.chunks - ] + chunks_ends = [np.cumsum(sizes) - 1 for _x in x for sizes in _x.chunks] total_chunks = [] - for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)): + for dim, chunk_ends in enumerate(chunks_ends): # select one line along dim line_x = new_x[dim].data[ @@ -953,18 +950,16 @@ def _compute_chunks(x, x_with_ghost, new_x): ] # the number of chunk of the output must be the same as the input (map_blocks) - new_x_ends = np.copy(ce[0]) - for i, (iend, iend_with_ghost) in enumerate(list(zip(*ce))[:-1]): + new_x_ends = np.copy(chunk_ends) + for i, chunk_end in enumerate(chunk_ends[:-1]): # number of points in line_x before the end of the current chunck - # with and without overlap - n_ghost = (line_x <= x_with_ghost[dim][iend_with_ghost]).sum() - n_no_ghost = (line_x <= x[dim][iend]).sum() - - # put half of the points inside the overlap on the left - # and the other half on the right - n_plus_half = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int) + n_end = (line_x <= x[dim][chunk_end]).sum() + # number of points in line_x before the start of the next chunck + n_start = (line_x <= x[dim][chunk_end+1]).sum() - new_x_ends[i] = n_plus_half + # put half of the points between two consecutive chunk + # on the left and the other half on the right + new_x_ends[i] = np.ceil(0.5 * (n_end + n_start)).astype(int) # do not forget extra points at the end new_x_ends[-1] = len(line_x) 
From 7547d56482fef45429606714b7cac1567b9c4388 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 27 Jul 2020 11:41:22 +0200 Subject: [PATCH 28/42] add ghosts point in order to make quadratic and cubic method work in a chunked direction --- xarray/core/missing.py | 30 +++++++++++++++++++----------- xarray/tests/test_interp.py | 20 -------------------- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 10de05b32d1..e4b7dda495a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -739,11 +739,6 @@ def interp_func(var, x, new_x, method, kwargs): drop_axis=drop_axis, ) - if method in ["quadratic", "cubic"]: - raise NotImplementedError( - "Only constant or linear interpolation are available in a chunked direction" - ) - for _x in new_x: if sum([s > 1 for s in _x.shape]) > 1: raise NotImplementedError( @@ -808,7 +803,8 @@ def rename_index(index, dim): ) # duplicate the ghost cells of the array in the interpolated dimensions - var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) + depth = {"quadratic": 2, "cubic": 3}.get(method, 1) + var, x, var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst, depth=depth) # compute final chunks total_chunks = _compute_chunks(x, new_x) @@ -917,19 +913,31 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs return _interpnd(var, old_x, new_x, interp_func, interp_kwargs) -def _add_interp_ghost(var, x, nconst: int): +def _add_interp_ghost(var, x, nconst: int, depth=1): """ Duplicate the ghost cells of the array (values and coordinates)""" import dask.array as da bnd = {i: "none" for i in range(len(var.shape))} - depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))} + depths = {i: 0 if i < nconst else depth for i in range(len(var.shape))} + + minchunk = min((min(chunks) for chunks in var.chunks[nconst:])) + if minchunk < depth: + warnings.warn( + "Chunks are too small to interpolate, 
rechunking.", + da.PerformanceWarning) + var = var.rechunk() + # rechunks x + x = tuple( + _x.rechunk(chunks) + for _x, chunks in zip(x, var.chunks[nconst:]) + ) - var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd) + var_with_ghost = da.overlap.overlap(var, depth=depths, boundary=bnd) x_with_ghost = tuple( - da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) for _x in x + da.overlap.overlap(_x, depth={0: depth}, boundary={0: "none"}) for _x in x ) - return var_with_ghost, x_with_ghost + return var, x, var_with_ghost, x_with_ghost def _compute_chunks(x, new_x): diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index b50bc3fe00f..5e6982ce6e5 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -73,16 +73,6 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) - if method == "cubic" and dim == "y" and case == 1: - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - NotImplementedError, - "Only constant or linear interpolation are available in a chunked direction", - ): - da.interp(method=method, **{dim: xdest}) - return - actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference @@ -811,16 +801,6 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): da.interp(method=method, **dest) return - if method in ["quadratic", "cubic"]: - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - NotImplementedError, - "Only constant or linear interpolation are available in a chunked direction", - ): - da.interp(method=method, **dest) - return - actual = da.interp(method=method, **dest) expected = da.compute().interp(method=method, **dest) From fd936dd14066f1c7b0d7c1f72c4a6bb93b8619b2 Mon Sep 17 00:00:00 2001 From: 
Alexandre Poux Date: Mon, 27 Jul 2020 11:43:20 +0200 Subject: [PATCH 29/42] black --- xarray/core/missing.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index e4b7dda495a..070c7d49dd0 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -780,8 +780,9 @@ def rename_index(index, dim): ) if unsorted: warnings.warn( - "Interpolating to unsorted destination will rechunk the result", - da.PerformanceWarning) + "Interpolating to unsorted destination will rechunk the result", + da.PerformanceWarning, + ) sorted_idx = { dim: _x.data.argsort() @@ -804,7 +805,9 @@ def rename_index(index, dim): # duplicate the ghost cells of the array in the interpolated dimensions depth = {"quadratic": 2, "cubic": 3}.get(method, 1) - var, x, var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst, depth=depth) + var, x, var_with_ghost, x_with_ghost = _add_interp_ghost( + var, x, nconst, depth=depth + ) # compute final chunks total_chunks = _compute_chunks(x, new_x) @@ -923,14 +926,11 @@ def _add_interp_ghost(var, x, nconst: int, depth=1): minchunk = min((min(chunks) for chunks in var.chunks[nconst:])) if minchunk < depth: warnings.warn( - "Chunks are too small to interpolate, rechunking.", - da.PerformanceWarning) + "Chunks are too small to interpolate, rechunking.", da.PerformanceWarning + ) var = var.rechunk() # rechunks x - x = tuple( - _x.rechunk(chunks) - for _x, chunks in zip(x, var.chunks[nconst:]) - ) + x = tuple(_x.rechunk(chunks) for _x, chunks in zip(x, var.chunks[nconst:])) var_with_ghost = da.overlap.overlap(var, depth=depths, boundary=bnd) @@ -963,7 +963,7 @@ def _compute_chunks(x, new_x): # number of points in line_x before the end of the current chunck n_end = (line_x <= x[dim][chunk_end]).sum() # number of points in line_x before the start of the next chunck - n_start = (line_x <= x[dim][chunk_end+1]).sum() + n_start = (line_x <= x[dim][chunk_end + 1]).sum() # 
put half of the points between two consecutive chunk # on the left and the other half on the right From 24f9460eb5fd09f6fe51eb121f5401ee7c06c492 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Mon, 27 Jul 2020 15:23:33 +0200 Subject: [PATCH 30/42] forgot to remove an exception in test_upsample_interpolate_dask --- xarray/tests/test_dataarray.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ffd245d5b1d..a56d9fabdc4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3169,16 +3169,6 @@ def test_upsample_interpolate_dask(self, chunked_time): for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: actual = array.chunk(chunks).resample(time="1H") - if chunked_time and (kind in ["quadratic", "cubic"]): - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - NotImplementedError, - "Only constant or linear interpolation are available in a chunked direction", - ): - actual.interpolate(kind) - continue - actual = actual.interpolate(kind) actual = actual.compute() f = interp1d( From dd2f2735cd110c2cb67c4f5e099a78543eb29a21 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 28 Jul 2020 13:58:55 +0200 Subject: [PATCH 31/42] fix filtering out-of-order warning --- xarray/core/missing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 070c7d49dd0..d7666de6b92 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -847,7 +847,11 @@ def rename_index(index, dim): res = Variable(data=res, dims=final_dims.values()) with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "out-of-order", da.PerformanceWarning) + warnings.filterwarnings( + "ignore", + r"Slicing with an out-of-order index is generating \d+ times more chunks", + da.PerformanceWarning, + ) 
for dim, idx in sorted_idx.items(): res = res.isel({final_dims[dim]: np.argsort(idx)}) From 49bdefafbf1560b12032a7c5caa6047f8fcaaa49 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Tue, 28 Jul 2020 13:59:30 +0200 Subject: [PATCH 32/42] use extrapolate to check external points --- xarray/tests/test_interp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 5e6982ce6e5..c21b6789014 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -749,6 +749,7 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): * np.exp(z), coords=[("x", x), ("y", y), ("z", z)], ) + kwargs = {"fill_value": "extrapolate"} # choose the data dimensions for data_dims in permutations(da.dims, data_ndim): @@ -801,8 +802,8 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): da.interp(method=method, **dest) return - actual = da.interp(method=method, **dest) - expected = da.compute().interp(method=method, **dest) + actual = da.interp(method=method, **dest, kwargs=kwargs) + expected = da.compute().interp(method=method, **dest, kwargs=kwargs) assert_allclose(actual, expected) break From d28086764359813a581472f171c327fe74f63d14 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Wed, 29 Jul 2020 10:26:52 +0200 Subject: [PATCH 33/42] Revert "add ghosts point in order to make quadratic and cubic method work in a chunked direction" --- xarray/core/missing.py | 27 ++++++++++----------------- xarray/tests/test_dataarray.py | 10 ++++++++++ xarray/tests/test_interp.py | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index d7666de6b92..824a3aa60d3 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -739,6 +739,11 @@ def interp_func(var, x, new_x, method, kwargs): drop_axis=drop_axis, ) + if method in ["quadratic", "cubic"]: + raise ValueError( + 
"Only constant or linear interpolation are possible in a chunked direction" + ) + for _x in new_x: if sum([s > 1 for s in _x.shape]) > 1: raise NotImplementedError( @@ -804,10 +809,7 @@ def rename_index(index, dim): ) # duplicate the ghost cells of the array in the interpolated dimensions - depth = {"quadratic": 2, "cubic": 3}.get(method, 1) - var, x, var_with_ghost, x_with_ghost = _add_interp_ghost( - var, x, nconst, depth=depth - ) + var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) # compute final chunks total_chunks = _compute_chunks(x, new_x) @@ -920,28 +922,19 @@ def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs return _interpnd(var, old_x, new_x, interp_func, interp_kwargs) -def _add_interp_ghost(var, x, nconst: int, depth=1): +def _add_interp_ghost(var, x, nconst: int): """ Duplicate the ghost cells of the array (values and coordinates)""" import dask.array as da bnd = {i: "none" for i in range(len(var.shape))} - depths = {i: 0 if i < nconst else depth for i in range(len(var.shape))} - - minchunk = min((min(chunks) for chunks in var.chunks[nconst:])) - if minchunk < depth: - warnings.warn( - "Chunks are too small to interpolate, rechunking.", da.PerformanceWarning - ) - var = var.rechunk() - # rechunks x - x = tuple(_x.rechunk(chunks) for _x, chunks in zip(x, var.chunks[nconst:])) + depths = {i: 0 if i < nconst else 1 for i in range(len(var.shape))} var_with_ghost = da.overlap.overlap(var, depth=depths, boundary=bnd) x_with_ghost = tuple( - da.overlap.overlap(_x, depth={0: depth}, boundary={0: "none"}) for _x in x + da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) for _x in x ) - return var, x, var_with_ghost, x_with_ghost + return var_with_ghost, x_with_ghost def _compute_chunks(x, new_x): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a56d9fabdc4..793d1dbb463 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3169,6 +3169,16 @@ def 
test_upsample_interpolate_dask(self, chunked_time): for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: actual = array.chunk(chunks).resample(time="1H") + if chunked_time and (kind in ["quadratic", "cubic"]): + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + ValueError, + "Only constant or linear interpolation are possible in a chunked direction", + ): + actual.interpolate(kind) + continue + actual = actual.interpolate(kind) actual = actual.compute() f = interp1d( diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index c21b6789014..1ffb56b53d6 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -73,6 +73,16 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) + if method == "cubic" and dim == "y" and case == 1: + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + ValueError, + "Only constant or linear interpolation are possible in a chunked direction", + ): + da.interp(method=method, **{dim: xdest}) + return + actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference @@ -802,6 +812,16 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): da.interp(method=method, **dest) return + if method in ["quadratic", "cubic"]: + # Check that an error is raised if an attempt is made to interpolate + # over a chunked dimension with high order method + with raises_regex( + ValueError, + "Only constant or linear interpolation are possible in a chunked direction", + ): + da.interp(method=method, **dest) + return + actual = da.interp(method=method, **dest, kwargs=kwargs) expected = da.compute().interp(method=method, **dest, kwargs=kwargs) From aeb7be1df6ad123c968e415a5996ed65a7a58266 Mon Sep 17 00:00:00 2001 From: 
Alexandre Poux Date: Wed, 29 Jul 2020 15:25:51 +0200 Subject: [PATCH 34/42] Complete rewrite using blockwise --- xarray/core/missing.py | 273 ++++++--------------------------- xarray/tests/test_dataarray.py | 11 -- xarray/tests/test_interp.py | 64 +++++--- 3 files changed, 90 insertions(+), 258 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 824a3aa60d3..9b20dba0a28 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -13,7 +13,7 @@ from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric from .options import _get_keep_attrs from .utils import OrderedSet, is_scalar -from .variable import IndexVariable, Variable, broadcast_variables +from .variable import Variable, broadcast_variables def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): @@ -544,13 +544,6 @@ def _get_valid_fill_mask(arr, dim, limit): ) <= limit -def _single_chunk(var, axes): - for axis in axes: - if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: - return False - return True - - def _localize(var, indexes_coords): """ Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation @@ -566,16 +559,6 @@ def _localize(var, indexes_coords): return var.isel(**indexes), indexes_coords -def _floatize_one_x(x): - """ Make x float. - This is particulary useful for datetime dtype. - x: np.ndarray - """ - if _contains_datetime_like_objects(x): - return x._to_numeric(dtype=np.float64) - return x - - def _floatize_x(x, new_x): """ Make x and new_x float. This is particulary useful for datetime dtype. 
@@ -721,147 +704,41 @@ def interp_func(var, x, new_x, method, kwargs): if isinstance(var, dask_array_type): import dask.array as da - # easyer, and allows advanced interpolation - if _single_chunk(var, range(var.ndim - len(x), var.ndim)): - chunks = var.chunks[: -len(x)] + new_x[0].shape - drop_axis = range(var.ndim - len(x), var.ndim) - new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) - return da.map_blocks( - _interpnd, - var, - x, - new_x, - func, - kwargs, - dtype=var.dtype, - chunks=chunks, - new_axis=new_axis, - drop_axis=drop_axis, - ) - - if method in ["quadratic", "cubic"]: - raise ValueError( - "Only constant or linear interpolation are possible in a chunked direction" - ) - - for _x in new_x: - if sum([s > 1 for s in _x.shape]) > 1: - raise NotImplementedError( - "Advanced interpolation is not implemented with chunked dimension" - ) - - # number of non interpolated dimensions nconst = var.ndim - len(x) - const_dims = [f"dim_{i}" for i in range(nconst)] - - # list of interpolated dimensions source - interp_dims = [_x.name for _x in x] - - # mapping to the interpolated dimensions destination - # (if interpolating on a point, removes de dim) - final_dims = dict( - **{d: d for d in const_dims}, - **{dim: _x.name for dim, _x in zip(interp_dims, new_x) if _x.size > 1}, - ) - # rename new_x to correspond to source dimension - def rename_index(index, dim): - if index.size == 1: - # a scalar has no name - return index - return IndexVariable( - dims=[dim], - data=index, - attrs=index.attrs, - encoding=index.encoding, - fastpath=True, - ) - - new_x = [rename_index(_x, dim) for dim, _x in zip(interp_dims, new_x)] - - new_x_float = [_floatize_one_x(_x) for _x in new_x] - unsorted = any( - (np.any(np.diff(_x.data) < 0) for _x in new_x_float if _x.size > 1) - ) - if unsorted: - warnings.warn( - "Interpolating to unsorted destination will rechunk the result", - da.PerformanceWarning, - ) - - sorted_idx = { - dim: _x.data.argsort() - for dim, _x in 
zip(interp_dims, new_x_float) - if _x.size > 1 - } - - new_x = [ - _x[sorted_idx[dim]] if dim in sorted_idx else _x - for dim, _x in zip(interp_dims, new_x) - ] - # add missing dimensions to new_x - new_x = tuple([_x.set_dims(interp_dims) for _x in new_x]) - - # chunks x - x = tuple( - da.from_array(_x, chunks=chunks) - for _x, chunks in zip(x, var.chunks[nconst:]) - ) - - # duplicate the ghost cells of the array in the interpolated dimensions - var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst) - - # compute final chunks - total_chunks = _compute_chunks(x, new_x) - final_chunks = var.chunks[: -len(x)] + tuple(total_chunks) - - # chunks new_x - new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x) - - # reshape x_with_ghost - # TODO: remove it (see _dask_aware_interpnd) - x_with_ghost = da.meshgrid(*x_with_ghost, indexing="ij") + out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim)) + new_axes = { + var.ndim + i: new_x[0].chunks[i] + if new_x[0].chunks is not None + else new_x[0].shape[i] + for i in range(new_x[0].ndim) + } + + # blockwise args format + x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] + x_arginds = [item for pair in x_arginds for item in pair] + new_x_arginds = [ + [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x + ] + new_x_arginds = [item for pair in new_x_arginds for item in pair] - # compute on chunks - res = da.map_blocks( + return da.blockwise( _dask_aware_interpnd, - var_with_ghost, - *x_with_ghost, - *new_x, + out_ind, + var, + range(var.ndim), + *x_arginds, + *new_x_arginds, + n_x=var.ndim - nconst, interp_func=func, interp_kwargs=kwargs, - n_coords=len(x_with_ghost), + method=method, + concatenate=True, + new_axes=new_axes, dtype=var.dtype, - chunks=final_chunks, + meta=np.ndarray, ) - # reshape res and remove empty chunks - # TODO: remove it by using drop_axis and new_axis in map_blocks - res = res.squeeze() - new_chunks = tuple( - 
[tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks] - ) - res = res.rechunk(new_chunks) - - if unsorted: - # Reorder the output - # use Variable for isel - - res = Variable(data=res, dims=final_dims.values()) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - r"Slicing with an out-of-order index is generating \d+ times more chunks", - da.PerformanceWarning, - ) - for dim, idx in sorted_idx.items(): - res = res.isel({final_dims[dim]: np.argsort(idx)}) - - # rechunk because out-of-order isel generates a lot of chunks - res = res.data.rechunk() - - return res - return _interpnd(var, x, new_x, func, kwargs) @@ -892,83 +769,27 @@ def _interpnd(var, x, new_x, func, kwargs): return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) -def _dask_aware_interpnd(var, *coords, n_coords: int, interp_func, interp_kwargs): - """Wrapper for `_interpnd` allowing dask array to be used in `map_blocks` +def _dask_aware_interpnd(var, *coords, n_x: int, interp_func, interp_kwargs, method): + """Wrapper for `_interpnd` through `blockwise` - The first `n_coords` arrays in `coords` are original coordinates, the rest are destination coordinate - Currently this need original coordinate to be full arrays (meshgrid) - - TODO: find a way to use 1d coordinates + The first `n_x` arrays in `coords` are original coordinates, + the `n_x` others (the rest) are destination coordinates """ - _old_x, _new_x = coords[:n_coords], coords[n_coords:] - - # reshape x (TODO REMOVE) - old_x = tuple( - [ - Variable(str(dim), np.moveaxis(tmp, dim, -1)[(0,) * (len(tmp.shape) - 1)]) - for dim, tmp in enumerate(_old_x) - ] - ) - - new_x = tuple( - [ - Variable( - [f"{outer_dim}{inner_dim}" for inner_dim in range(len(_x.shape))], _x - ) - for outer_dim, _x in enumerate(_new_x) - ] - ) - - return _interpnd(var, old_x, new_x, interp_func, interp_kwargs) - - -def _add_interp_ghost(var, x, nconst: int): - """ Duplicate the ghost cells of the array (values and coordinates)""" - 
import dask.array as da - - bnd = {i: "none" for i in range(len(var.shape))} - depths = {i: 0 if i < nconst else 1 for i in range(len(var.shape))} - - var_with_ghost = da.overlap.overlap(var, depth=depths, boundary=bnd) - - x_with_ghost = tuple( - da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) for _x in x - ) - return var_with_ghost, x_with_ghost - - -def _compute_chunks(x, new_x): - """Compute equilibrated chunks of new_x - - This routine assumes that x, x_with_ghost and new_x are sorted - - TODO: This only works if new_x is a set of 1d coordinate - more general function is needed for advanced interpolation with chunked dimension - """ - chunks_ends = [np.cumsum(sizes) - 1 for _x in x for sizes in _x.chunks] - total_chunks = [] - for dim, chunk_ends in enumerate(chunks_ends): - - # select one line along dim - line_x = new_x[dim].data[ - (0,) * dim + (slice(None),) + (0,) * (len(new_x) - dim - 1) - ] - - # the number of chunk of the output must be the same as the input (map_blocks) - new_x_ends = np.copy(chunk_ends) - for i, chunk_end in enumerate(chunk_ends[:-1]): - # number of points in line_x before the end of the current chunck - n_end = (line_x <= x[dim][chunk_end]).sum() - # number of points in line_x before the start of the next chunck - n_start = (line_x <= x[dim][chunk_end + 1]).sum() - - # put half of the points between two consecutive chunk - # on the left and the other half on the right - new_x_ends[i] = np.ceil(0.5 * (n_end + n_start)).astype(int) + # Convert all to Variable, in order to use _localize + nconst = len(var.shape) - n_x + x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] + var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) + new_x = [ + Variable( + [f"dim_{outer_dim}_{inner_dim}" for inner_dim in range(len(_x.shape))], _x + ) + for outer_dim, _x in enumerate(coords[n_x:]) + ] + indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} - # do not forget extra 
points at the end - new_x_ends[-1] = len(line_x) + # simple speed up for the local interpolation + if method in ["linear", "nearest"]: + var, indexes_coords = _localize(var, indexes_coords) + localized_x, localized_new_x = zip(*[indexes_coords[d] for d in indexes_coords]) - chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1]) - total_chunks.append(chunks) - return total_chunks + return _interpnd(var.data, localized_x, localized_new_x, interp_func, interp_kwargs) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 793d1dbb463..66ed35ef83a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3168,17 +3168,6 @@ def test_upsample_interpolate_dask(self, chunked_time): new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: actual = array.chunk(chunks).resample(time="1H") - - if chunked_time and (kind in ["quadratic", "cubic"]): - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - ValueError, - "Only constant or linear interpolation are possible in a chunked direction", - ): - actual.interpolate(kind) - continue - actual = actual.interpolate(kind) actual = actual.compute() f = interp1d( diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 1ffb56b53d6..a47f27729f6 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -72,17 +72,6 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) - - if method == "cubic" and dim == "y" and case == 1: - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - ValueError, - "Only constant or linear interpolation are possible in a chunked direction", - ): - da.interp(method=method, **{dim: xdest}) - return - 
actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference @@ -812,16 +801,6 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): da.interp(method=method, **dest) return - if method in ["quadratic", "cubic"]: - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - ValueError, - "Only constant or linear interpolation are possible in a chunked direction", - ): - da.interp(method=method, **dest) - return - actual = da.interp(method=method, **dest, kwargs=kwargs) expected = da.compute().interp(method=method, **dest, kwargs=kwargs) @@ -841,3 +820,46 @@ def test_interpolate_chunk_rename(): expected = da.compute().interp(x=xdest, method="linear") assert_allclose(actual, expected) + + +@requires_scipy +@requires_dask +def test_interpolate_chunk_advanced(): + """Interpolate nd array with an nd indexer sharing coordinates.""" + # Create original array + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 7) + z = np.linspace(-1, 1, 11) + t = np.linspace(0, 1, 13) + q = np.linspace(0, 1, 17) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis, np.newaxis, np.newaxis]) + * np.exp(z[:, np.newaxis, np.newaxis]) + * t[:, np.newaxis] + + q, + dims=("x", "y", "z", "t", "q"), + coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "toto"}, + ) + + theta = np.linspace(0, 2 * np.pi, 19) + w = np.linspace(-0.25, 0.25, 23) + + r = xr.DataArray( + data=1 + w[:, np.newaxis] * np.cos(theta), coords=[("w", w), ("theta", theta)], + ) + x = r * np.cos(theta) + y = r * np.sin(theta) + z = xr.DataArray( + data=w[:, np.newaxis] * np.sin(theta), coords=[("w", w), ("theta", theta)], + ) + + kwargs = {"fill_value": None} + # Create indexer into `a` with dimensions (y, x) + expected = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") + da = da.chunk(2) + x = x.chunk(2) + # y 
= y.chunk(2) + z = z.chunk(2) + actual = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") + assert_allclose(actual, expected) From 0bc35d232ca7de009b90c41260c4bf5a185b3977 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Wed, 29 Jul 2020 15:35:37 +0200 Subject: [PATCH 35/42] update whats-new.rst --- doc/whats-new.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d0da7205c69..a837264b937 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -162,11 +162,10 @@ New Features Enhancements ~~~~~~~~~~~~ - Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` - For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially + For orthogonal linear- and nearest-neighbor interpolation, we do 0d- and 1d-interpolation sequentially rather than interpolating in multidimensional space. (:issue:`2223`) By `Keisuke Fujii `_. -- :py:meth:`DataArray.interp` now support some interpolations over a chunked dimension - (low order and not advanced interpolation) (:pull:`4155`). By `Alexandre Poux `_. +- :py:meth:`DataArray.interp` now support some interpolations over a chunked dimension (:pull:`4155`). By `Alexandre Poux `_. - Major performance improvement for :py:meth:`Dataset.from_dataframe` when the dataframe has a MultiIndex (:pull:`4184`). By `Stephan Hoyer `_. 
From 0d5f6185c530220b532c299755866adce99d7c70 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Wed, 29 Jul 2020 15:36:04 +0200 Subject: [PATCH 36/42] reduce the diff --- xarray/tests/test_dataarray.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 66ed35ef83a..d7e88735fbf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3167,8 +3167,7 @@ def test_upsample_interpolate_dask(self, chunked_time): # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.chunk(chunks).resample(time="1H") - actual = actual.interpolate(kind) + actual = array.chunk(chunks).resample(time="1H").interpolate(kind) actual = actual.compute() f = interp1d( np.arange(len(times)), From 290a07520fbd4ed74397a8d1d37f3ecf683350e1 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Wed, 29 Jul 2020 17:56:59 +0200 Subject: [PATCH 37/42] more decomposition of orthogonal interpolation --- xarray/core/missing.py | 106 +++++++++++++++++++++--------------- xarray/tests/test_interp.py | 13 ----- 2 files changed, 62 insertions(+), 57 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 9b20dba0a28..763b5203a84 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -608,56 +608,42 @@ def interp(var, indexes_coords, method, **kwargs): if not indexes_coords: return var.copy() - # simple speed up for the local interpolation - if method in ["linear", "nearest"]: - var, indexes_coords = _localize(var, indexes_coords) - # default behavior kwargs["bounds_error"] = kwargs.get("bounds_error", False) - # check if the interpolation can be done in orthogonal manner - if ( - len(indexes_coords) > 1 - and method in ["linear", "nearest"] - and all(dest[1].ndim <= 1 for dest in indexes_coords.values()) - and len( - set( - [ - dest[1].dims[0] 
if dest[1].ndim == 1 else d - for d, dest in indexes_coords.items() - ] - ) + result = var + # decompose the interpolation into a succession of independant interpolation + for indexes_coords in decompose_interp(indexes_coords): + var = result + + # simple speed up for the local interpolation + if method in ["linear", "nearest"]: + var, indexes_coords = _localize(var, indexes_coords) + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func( + var.transpose(*original_dims).data, x, destination, method, kwargs ) - == len(indexes_coords) - ): - # interpolate sequentially - for dim, dest in indexes_coords.items(): - var = interp(var, {dim: dest}, method, **kwargs) - return var - - # target dimensions - dims = list(indexes_coords) - x, new_x = zip(*[indexes_coords[d] for d in dims]) - destination = broadcast_variables(*new_x) - - # transpose to make the interpolated axis to the last position - broadcast_dims = [d for d in var.dims if d not in dims] - original_dims = broadcast_dims + dims - new_dims = broadcast_dims + list(destination[0].dims) - interped = interp_func( - var.transpose(*original_dims).data, x, destination, method, kwargs - ) - result = Variable(new_dims, interped, attrs=var.attrs) + result = Variable(new_dims, interped, attrs=var.attrs) - # dimension of the output array - out_dims = OrderedSet() - for d in var.dims: - if d in dims: - out_dims.update(indexes_coords[d][1].dims) - else: - out_dims.add(d) - return result.transpose(*tuple(out_dims)) + # dimension of the output array + out_dims = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + result = 
result.transpose(*tuple(out_dims)) + return result def interp_func(var, x, new_x, method, kwargs): @@ -793,3 +779,35 @@ def _dask_aware_interpnd(var, *coords, n_x: int, interp_func, interp_kwargs, met localized_x, localized_new_x = zip(*[indexes_coords[d] for d in indexes_coords]) return _interpnd(var.data, localized_x, localized_new_x, interp_func, interp_kwargs) + + +def decompose_interp(indexes_coords): + """Decompose the interpolation into a succession of independant interpolation keeping the order""" + + dest_dims = [ + dest[1].dims if dest[1].ndim > 0 else [dim] + for dim, dest in indexes_coords.items() + ] + partial_dest_dims = [] + partial_indexes_coords = {} + for i, index_coords in enumerate(indexes_coords.items()): + partial_indexes_coords.update([index_coords]) + + if i == len(dest_dims) - 1: + break + + partial_dest_dims += [dest_dims[i]] + other_dims = dest_dims[i + 1 :] + + s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims} + s_other_dims = {dim for dims in other_dims for dim in dims} + + if not s_partial_dest_dims.intersection(s_other_dims): + # this interpolation is orthogonal to the rest + + yield partial_indexes_coords + + partial_dest_dims = [] + partial_indexes_coords = {} + + yield partial_indexes_coords diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index a47f27729f6..05f23ddf4ce 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -8,7 +8,6 @@ from xarray.tests import ( assert_allclose, assert_equal, - raises_regex, requires_cftime, requires_dask, requires_scipy, @@ -379,8 +378,6 @@ def test_errors(use_dask): # invalid method with pytest.raises(ValueError): da.interp(x=[2, 0], method="boo") - with pytest.raises(ValueError): - da.interp(x=[2, 0], y=2, method="cubic") with pytest.raises(ValueError): da.interp(y=[2, 0], method="boo") @@ -791,16 +788,6 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): xdest = xdest.flatten() dest[dim] = xdest - 
if interp_ndim > 1 and method not in ["linear", "nearest"]: - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension with high order method - with raises_regex( - ValueError, - f"{method} is not a valid interpolator for interpolating over multiple dimensions.", - ): - da.interp(method=method, **dest) - return - actual = da.interp(method=method, **dest, kwargs=kwargs) expected = da.compute().interp(method=method, **dest, kwargs=kwargs) From 3f8718e6a7939b09071729f17a229a2a27237a61 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Thu, 30 Jul 2020 15:29:12 +0200 Subject: [PATCH 38/42] simplify _dask_aware_interpnd a little --- xarray/core/missing.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 763b5203a84..db23324925b 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -708,6 +708,9 @@ def interp_func(var, x, new_x, method, kwargs): ] new_x_arginds = [item for pair in new_x_arginds for item in pair] + # if usefull, re-use localize for each chunk of new_x + localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + return da.blockwise( _dask_aware_interpnd, out_ind, @@ -715,14 +718,13 @@ def interp_func(var, x, new_x, method, kwargs): range(var.ndim), *x_arginds, *new_x_arginds, - n_x=var.ndim - nconst, interp_func=func, interp_kwargs=kwargs, - method=method, + localize=localize, concatenate=True, - new_axes=new_axes, - dtype=var.dtype, meta=np.ndarray, + dtype=var.dtype, + new_axes=new_axes, ) return _interpnd(var, x, new_x, func, kwargs) @@ -755,14 +757,16 @@ def _interpnd(var, x, new_x, func, kwargs): return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) -def _dask_aware_interpnd(var, *coords, n_x: int, interp_func, interp_kwargs, method): +def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): """Wrapper for `_interpnd` through `blockwise` The first `n_x` arrays 
in `coords` are original coordinates, the `n_x` others (the rest) are destination coordinates """ - # Convert all to Variable, in order to use _localize + n_x = len(coords) // 2 nconst = len(var.shape) - n_x + + # Convert all to Variable, in order to use _localize x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) new_x = [ @@ -774,7 +778,7 @@ def _dask_aware_interpnd(var, *coords, n_x: int, interp_func, interp_kwargs, met indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} # simple speed up for the local interpolation - if method in ["linear", "nearest"]: + if localize: var, indexes_coords = _localize(var, indexes_coords) localized_x, localized_new_x = zip(*[indexes_coords[d] for d in indexes_coords]) From 562d5aaeaf8b0d9d46867b272bcedd07f762096c Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Thu, 30 Jul 2020 15:30:32 +0200 Subject: [PATCH 39/42] fix dask interp when chunks are not aligned --- xarray/core/missing.py | 26 ++++++++++++++++---------- xarray/tests/test_interp.py | 4 ++-- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index db23324925b..407c1376760 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -693,12 +693,6 @@ def interp_func(var, x, new_x, method, kwargs): nconst = var.ndim - len(x) out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim)) - new_axes = { - var.ndim + i: new_x[0].chunks[i] - if new_x[0].chunks is not None - else new_x[0].shape[i] - for i in range(new_x[0].ndim) - } # blockwise args format x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] @@ -708,16 +702,28 @@ def interp_func(var, x, new_x, method, kwargs): ] new_x_arginds = [item for pair in new_x_arginds for item in pair] + args = var, range(var.ndim), *x_arginds, *new_x_arginds, + + _, rechunked = da.unify_chunks(*args) + + args = 
tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair]) + + new_x = rechunked[1 + (len(rechunked)-1) // 2:] + + new_axes = { + var.ndim + i: new_x[0].chunks[i] + if new_x[0].chunks is not None + else new_x[0].shape[i] + for i in range(new_x[0].ndim) + } + # if usefull, re-use localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) return da.blockwise( _dask_aware_interpnd, out_ind, - var, - range(var.ndim), - *x_arginds, - *new_x_arginds, + *args, interp_func=func, interp_kwargs=kwargs, localize=localize, diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 05f23ddf4ce..292dba2e4f8 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -845,8 +845,8 @@ def test_interpolate_chunk_advanced(): # Create indexer into `a` with dimensions (y, x) expected = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") da = da.chunk(2) - x = x.chunk(2) + x = x.chunk(3) # y = y.chunk(2) - z = z.chunk(2) + z = z.chunk(5) actual = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") assert_allclose(actual, expected) From 62f059c79dab3ddc574b3989123b76c51088dc1a Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 31 Jul 2020 15:34:09 +0200 Subject: [PATCH 40/42] continue simplifying _dask_aware_interpnd --- xarray/core/missing.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 407c1376760..a6bed408164 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -702,13 +702,18 @@ def interp_func(var, x, new_x, method, kwargs): ] new_x_arginds = [item for pair in new_x_arginds for item in pair] - args = var, range(var.ndim), *x_arginds, *new_x_arginds, + args = ( + var, + range(var.ndim), + *x_arginds, + *new_x_arginds, + ) _, rechunked = da.unify_chunks(*args) args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem 
in pair]) - new_x = rechunked[1 + (len(rechunked)-1) // 2:] + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] new_axes = { var.ndim + i: new_x[0].chunks[i] @@ -728,7 +733,6 @@ def interp_func(var, x, new_x, method, kwargs): interp_kwargs=kwargs, localize=localize, concatenate=True, - meta=np.ndarray, dtype=var.dtype, new_axes=new_axes, ) @@ -766,29 +770,33 @@ def _interpnd(var, x, new_x, func, kwargs): def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): """Wrapper for `_interpnd` through `blockwise` - The first `n_x` arrays in `coords` are original coordinates, - the `n_x` others (the rest) are destination coordinates + The first half arrays in `coords` are original coordinates, + the other half are destination coordinates """ n_x = len(coords) // 2 nconst = len(var.shape) - n_x - # Convert all to Variable, in order to use _localize + # _interpnd expect coords to be Variables x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] - var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) new_x = [ - Variable( - [f"dim_{outer_dim}_{inner_dim}" for inner_dim in range(len(_x.shape))], _x - ) - for outer_dim, _x in enumerate(coords[n_x:]) + Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) + for _x in coords[n_x:] ] - indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} - # simple speed up for the local interpolation if localize: + # _localize expect var to be a Variable + var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) + + indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} + + # simple speed up for the local interpolation var, indexes_coords = _localize(var, indexes_coords) - localized_x, localized_new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + x, new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + + # put var back as a ndarray + var = var.data - return _interpnd(var.data, 
localized_x, localized_new_x, interp_func, interp_kwargs) + return _interpnd(var, x, new_x, interp_func, interp_kwargs) def decompose_interp(indexes_coords): From 3d4d45cec0410a8415aa8fdc74612b82906f38af Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 31 Jul 2020 18:26:36 +0200 Subject: [PATCH 41/42] update whats-new.rst --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a837264b937..7e76066cb33 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -162,10 +162,10 @@ New Features Enhancements ~~~~~~~~~~~~ - Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` - For orthogonal linear- and nearest-neighbor interpolation, we do 0d- and 1d-interpolation sequentially - rather than interpolating in multidimensional space. (:issue:`2223`) + We perform independent interpolation sequentially rather than interpolating in + one large multidimensional space. (:issue:`2223`) By `Keisuke Fujii `_. -- :py:meth:`DataArray.interp` now support some interpolations over a chunked dimension (:pull:`4155`). By `Alexandre Poux `_. +- :py:meth:`DataArray.interp` now supports interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux `_. - Major performance improvement for :py:meth:`Dataset.from_dataframe` when the dataframe has a MultiIndex (:pull:`4184`). By `Stephan Hoyer `_. 
From b60cddf176d0524ed0a09c3cbb9a5acb76449e76 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Fri, 31 Jul 2020 18:27:00 +0200 Subject: [PATCH 42/42] clean tests --- xarray/tests/test_interp.py | 83 +++++++++++++++---------------------- 1 file changed, 34 insertions(+), 49 deletions(-) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 292dba2e4f8..17e418c3731 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -8,6 +8,7 @@ from xarray.tests import ( assert_allclose, assert_equal, + assert_identical, requires_cftime, requires_dask, requires_scipy, @@ -724,7 +725,7 @@ def test_decompose(method): @pytest.mark.parametrize( "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] ) -@pytest.mark.parametrize("sorted", [True, False]) +@pytest.mark.parametrize("chunked", [True, False]) @pytest.mark.parametrize( "data_ndim,interp_ndim,nscalar", [ @@ -734,7 +735,12 @@ def test_decompose(method): for nscalar in range(0, interp_ndim + 1) ], ) -def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): +def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked): + """Interpolate nd array with multiple independent indexers + + It should do a series of 1d interpolations + """ + # 3d non chunked data x = np.linspace(0, 1, 5) y = np.linspace(2, 4, 7) @@ -751,12 +757,12 @@ def test_interpolate_chunk(method, sorted, data_ndim, interp_ndim, nscalar): for data_dims in permutations(da.dims, data_ndim): # select only data_ndim dim - da = da.isel( + da = da.isel( # take the middle line {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims} ) # chunk data - da = da.chunk(chunks={dim: len(da.coords[dim]) // 3 for dim in da.dims}) + da = da.chunk(chunks={dim: i + 1 for i, dim in enumerate(da.dims)}) # choose the interpolation dimensions for interp_dims in permutations(da.dims, interp_ndim): @@ -764,54 +770,33 @@ def test_interpolate_chunk(method, sorted, 
data_ndim, interp_ndim, nscalar): for scalar_dims in combinations(interp_dims, nscalar): dest = {} for dim in interp_dims: - # choose a point between chunks - first_chunks = da.chunks[da.get_axis_num(dim)][0] - middle_point = 0.5 * ( - da.coords[dim][first_chunks - 1] + da.coords[dim][first_chunks] - ) if dim in scalar_dims: - # choose a point between chunks - dest[dim] = middle_point + # take the middle point + dest[dim] = 0.5 * (da.coords[dim][0] + da.coords[dim][-1]) else: - # pick some points, including outside the domain and bewteen chunks + # pick some points, including outside the domain before = 2 * da.coords[dim][0] - da.coords[dim][1] after = 2 * da.coords[dim][-1] - da.coords[dim][-2] - inside = da.coords[dim][first_chunks // 2] - - xdest = np.linspace( - inside, middle_point, 2 * (first_chunks // 2), - ) - xdest = np.concatenate([[before], xdest, [after]]) - if not sorted: - xdest = xdest.reshape((-1, 2)) - xdest[:, 1] = xdest[::-1, 1] - xdest = xdest.flatten() - dest[dim] = xdest + dest[dim] = np.linspace(before, after, len(da.coords[dim]) * 13) + if chunked: + dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) + dest[dim] = dest[dim].chunk(2) actual = da.interp(method=method, **dest, kwargs=kwargs) expected = da.compute().interp(method=method, **dest, kwargs=kwargs) - assert_allclose(actual, expected) + assert_identical(actual, expected) + + # all the combinations are usually not necessary break break + break @requires_scipy @requires_dask -def test_interpolate_chunk_rename(): - da = get_example_data(0).chunk({"x": 5}) - - # grid -> 1d-sample - xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims="renamed") - actual = da.interp(x=xdest, method="linear") - expected = da.compute().interp(x=xdest, method="linear") - - assert_allclose(actual, expected) - - -@requires_scipy -@requires_dask -def test_interpolate_chunk_advanced(): +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_interpolate_chunk_advanced(method): """Interpolate nd array 
with an nd indexer sharing coordinates.""" # Create original array x = np.linspace(-1, 1, 5) @@ -826,15 +811,16 @@ def test_interpolate_chunk_advanced(): * t[:, np.newaxis] + q, dims=("x", "y", "z", "t", "q"), - coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "toto"}, + coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "dummy_attr"}, ) - theta = np.linspace(0, 2 * np.pi, 19) - w = np.linspace(-0.25, 0.25, 23) - + # Create indexer into `da` with shared coordinate ("full-twist" Möbius strip) + theta = np.linspace(0, 2 * np.pi, 5) + w = np.linspace(-0.25, 0.25, 7) r = xr.DataArray( data=1 + w[:, np.newaxis] * np.cos(theta), coords=[("w", w), ("theta", theta)], ) + x = r * np.cos(theta) y = r * np.sin(theta) z = xr.DataArray( @@ -842,11 +828,10 @@ def test_interpolate_chunk_advanced(): ) kwargs = {"fill_value": None} - # Create indexer into `a` with dimensions (y, x) - expected = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") + expected = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method) + da = da.chunk(2) - x = x.chunk(3) - # y = y.chunk(2) - z = z.chunk(5) - actual = da.interp(x=x, y=y, z=z, t=0.5, kwargs=kwargs, method="linear") - assert_allclose(actual, expected) + x = x.chunk(1) + z = z.chunk(3) + actual = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method) + assert_identical(actual, expected)