Skip to content
forked from pydata/xarray

Commit

Permalink
apply_func: Set meta=np.ndarray when vectorize=True and dask="parallelized"
Browse files Browse the repository at this point in the history

Closes pydata#3574
  • Loading branch information
dcherian committed Jan 2, 2020
1 parent b3d3b44 commit 7b86a5e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
14 changes: 13 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def apply_variable_ufunc(
output_dtypes=None,
output_sizes=None,
keep_attrs=False,
vectorize=False,
):
"""Apply a ndarray level function over Variable and/or ndarray objects.
"""
Expand Down Expand Up @@ -579,6 +580,7 @@ def apply_variable_ufunc(
elif dask == "parallelized":
input_dims = [broadcast_dims + dims for dims in signature.input_core_dims]
numpy_func = func
meta = np.ndarray if vectorize else None

def func(*arrays):
return _apply_blockwise(
Expand All @@ -589,6 +591,7 @@ def func(*arrays):
signature,
output_dtypes,
output_sizes,
meta,
)

elif dask == "allowed":
Expand Down Expand Up @@ -647,7 +650,14 @@ def func(*arrays):


def _apply_blockwise(
func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None
func,
args,
input_dims,
output_dims,
signature,
output_dtypes,
output_sizes=None,
meta=None,
):
import dask.array

Expand Down Expand Up @@ -719,6 +729,7 @@ def _apply_blockwise(
dtype=dtype,
concatenate=True,
new_axes=output_sizes,
meta=meta,
)


Expand Down Expand Up @@ -1005,6 +1016,7 @@ def earth_mover_distance(first_samples,
dask=dask,
output_dtypes=output_dtypes,
output_sizes=output_sizes,
vectorize=vectorize,
)

if any(isinstance(a, GroupBy) for a in args):
Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,24 @@ def test_vectorize_dask():
assert_identical(expected, actual)


@requires_dask
def test_vectorize_dask_new_output_dims():
    """Regression test for GH3574.

    Applies a function that prepends a new axis to a dask-chunked DataArray
    with ``vectorize=True`` and ``dask="parallelized"``, declaring the new
    axis as output core dimension "z" of size 1, and checks that the result
    is identical to ``expand_dims("z")`` on the unchunked input.
    """
    source = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
    expected = source.expand_dims("z")

    # One chunk per "x" entry so the dask code path is exercised.
    chunked = source.chunk({"x": 1})
    result = apply_ufunc(
        lambda x: x[np.newaxis, ...],
        chunked,
        output_core_dims=[["z"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        output_sizes={"z": 1},
    )

    # Reorder the result's dimensions to match the expected layout before
    # comparing.
    actual = result.transpose(*expected.dims)
    assert_identical(expected, actual)


def test_output_wrong_number():
variable = xr.Variable("x", np.arange(10))

Expand Down

0 comments on commit 7b86a5e

Please sign in to comment.