diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 8f1ac2d93ce..0590bcdbb77 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -35,7 +35,7 @@ def make_meta(obj):
     from dask.array.utils import meta_from_array
 
     if isinstance(obj, DataArray):
-        meta = DataArray(obj.data._meta, dims=obj.dims)
+        meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name)
 
     if isinstance(obj, Dataset):
         meta = Dataset()
@@ -45,9 +45,14 @@ def make_meta(obj):
         else:
             meta_obj = meta_from_array(obj[name].data)
             meta[name] = DataArray(meta_obj, dims=obj[name].dims)
+            # meta[name] = DataArray(obj[name].dims, meta_obj)
     else:
         meta = obj
 
+    # TODO: deal with non-dim coords
+    # for coord_name in (set(obj.coords) - set(obj.dims)):  # DataArrays should have _coord_names!
+    #     coord = obj[coord_name]
+
     return meta
 
 
@@ -65,7 +70,7 @@ def infer_template(func, obj, *args, **kwargs):
     return template
 
 
-def _make_dict(x):
+def make_dict(x):
     # Dataset.to_dict() is too complicated
     # maps variable name to numpy array
     if isinstance(x, DataArray):
@@ -93,6 +98,9 @@ def map_blocks(func, obj, *args, **kwargs):
         properties of the returned object such as dtype, variable names,
         new dimensions and new indexes (if any).
 
+        This function must
+        - return either a DataArray or a Dataset
+
         This function cannot
         - change size of existing dimensions.
         - add new chunked dimensions.
@@ -101,18 +109,24 @@ def map_blocks(func, obj, *args, **kwargs):
         Chunks of this object will be provided to 'func'. The function must not change
         shape of the provided DataArray.
     args:
-        Passed on to func.
+        Passed on to func. Cannot include chunked xarray objects.
     kwargs:
-        Passed on to func.
+        Passed on to func. Cannot include chunked xarray objects.
 
     Returns
     -------
     DataArray or Dataset
 
+    Notes
+    -----
+
+    This function is designed to work with dask-backed xarray objects. See apply_ufunc for
+    a similar function that works with numpy arrays.
+
     See Also
     --------
-    dask.array.map_blocks
+    dask.array.map_blocks, xarray.apply_ufunc
     """
 
     def _wrapper(func, obj, to_array, args, kwargs):
@@ -129,7 +143,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
                 % name
             )
 
-        to_return = _make_dict(result)
+        to_return = make_dict(result)
 
         return to_return
 
@@ -149,26 +163,30 @@ def _wrapper(func, obj, to_array, args, kwargs):
     if isinstance(template, DataArray):
         result_is_array = True
         template = template._to_temp_dataset()
-    else:
+    elif isinstance(template, Dataset):
         result_is_array = False
+    else:
+        raise ValueError(
+            "Function must return an xarray DataArray or Dataset. Instead it returned %r"
+            % type(template)
+        )
 
     # If two different variables have different chunking along the same dim
     # .chunks will raise an error.
     input_chunks = dataset.chunks
-    indexes = dict(dataset.indexes)
-    for dim in template.indexes:
-        if dim not in indexes:
-            indexes[dim] = template.indexes[dim]
+    # TODO: add a test that fails when template and dataset are switched
+    indexes = dict(template.indexes)
+    indexes.update(dataset.indexes)
 
     graph = {}
     gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))
 
     # map dims to list of chunk indexes
-    ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks}
+    ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}
 
     # mapping from chunk index to slice bounds
     chunk_index_bounds = {
-        dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks
+        dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
     }
 
     # iterate over all possible chunk combinations
@@ -185,17 +203,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
         for name, variable in dataset.variables.items():
             # make a task that creates tuple of (dims, chunk)
             if dask.is_dask_collection(variable.data):
-                var_dask_keys = variable.__dask_keys__()
-                # recursively index into dask_keys nested list to get chunk
-                chunk = var_dask_keys
+                chunk = variable.__dask_keys__()
                 for dim in variable.dims:
                     chunk = chunk[chunk_index_dict[dim]]
 
-                task_name = ("tuple-" + dask.base.tokenize(chunk),) + v
-                graph[task_name] = (tuple, [variable.dims, chunk])
+                chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v
+                graph[chunk_variable_task] = (tuple, [variable.dims, chunk])
             else:
-                # numpy array with possibly chunked dimensions
+                # non-dask array with possibly chunked dimensions
                 # index into variable appropriately
                 subsetter = dict()
                 for dim in variable.dims:
@@ -207,14 +223,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
                     )
 
                 subset = variable.isel(subsetter)
-                task_name = (name + dask.base.tokenize(subset),) + v
-                graph[task_name] = (tuple, [subset.dims, subset])
+                chunk_variable_task = (name + dask.base.tokenize(subset),) + v
+                graph[chunk_variable_task] = (tuple, [subset.dims, subset])
 
             # this task creates dict mapping variable name to above tuple
-            if name in dataset.data_vars:
-                data_vars.append([name, task_name])
-            if name in dataset.coords:
-                coords.append([name, task_name])
+            if name in dataset._coord_names:
+                coords.append([name, chunk_variable_task])
+            else:
+                data_vars.append([name, chunk_variable_task])
 
         from_wrapper = (gname,) + v
         graph[from_wrapper] = (
@@ -229,14 +245,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
     # mapping from variable name to dask graph key
     var_key_map = {}
     for name, variable in template.variables.items():
-        var_dims = variable.dims
+        if name in indexes:
+            continue
 
        # cannot tokenize "name" because the hash of <this-array> is not invariant!
        # This happens when the user function does not set a name on the returned DataArray
        gname_l = "%s-%s" % (gname, name)
        var_key_map[name] = gname_l
 
        key = (gname_l,)
-        for dim in var_dims:
+        for dim in variable.dims:
            if dim in chunk_index_dict:
                key += (chunk_index_dict[dim],)
            else:
@@ -248,26 +265,26 @@ def _wrapper(func, obj, to_array, args, kwargs):
     graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset])
 
     result = Dataset()
-    for var, key in var_key_map.items():
-        # indexes need to be known
-        # otherwise compute is called when DataArray is created
-        if var in indexes:
-            result[var] = indexes[var]
-            continue
-
-        dims = template[var].dims
+    # a quicker way to assign indexes?
+    # indexes need to be known
+    # otherwise compute is called when DataArray is created
+    for name in template.indexes:
+        result[name] = indexes[name]
+
+    for name, key in var_key_map.items():
+        dims = template[name].dims
         var_chunks = []
         for dim in dims:
             if dim in input_chunks:
                 var_chunks.append(input_chunks[dim])
-            else:
-                if dim in indexes:
-                    var_chunks.append((len(indexes[dim]),))
+            elif dim in indexes:
+                var_chunks.append((len(indexes[dim]),))
 
         data = dask.array.Array(
-            graph, name=key, chunks=var_chunks, dtype=template[var].dtype
+            graph, name=key, chunks=var_chunks, dtype=template[name].dtype
         )
-        result[var] = DataArray(data=data, dims=dims, name=var)
+        result[name] = (dims, data)
+
+    result = result.set_coords(template._coord_names)
 
     if result_is_array:
         result = _to_array(result)
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index a6d25d8cde7..8925f9c3ca7 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -891,7 +891,7 @@ def make_da():
 
 def make_ds():
     map_ds = xr.Dataset()
-    map_ds["a"] = map_da
+    map_ds["a"] = make_da()
     map_ds["b"] = map_ds.a + 50
     map_ds["c"] = map_ds.x + 20
     map_ds = map_ds.chunk({"x": 4, "y": 5})
@@ -909,6 +909,18 @@ def make_ds():
 map_ds = make_ds()
 
 
+# DataArray.chunks is not a dict but Dataset.chunks is!
+def assert_chunks_equal(a, b):
+
+    if isinstance(a, DataArray):
+        a = a._to_temp_dataset()
+
+    if isinstance(b, DataArray):
+        b = b._to_temp_dataset()
+
+    assert a.chunks == b.chunks
+
+
 def simple_func(obj):
     result = obj.x + 5 * obj.y
     return result
@@ -933,6 +945,12 @@ def bad_func(darray):
     with raises_regex(ValueError, "Length of the.* has changed."):
         xr.map_blocks(bad_func, map_da).compute()
 
+    def returns_numpy(darray):
+        return (darray * darray.x + 5 * darray.y).values
+
+    with raises_regex(ValueError, "Function must return an xarray DataArray"):
+        xr.map_blocks(returns_numpy, map_da)
+
 
 @pytest.mark.parametrize(
     "func, obj",
@@ -942,6 +960,7 @@ def test_map_blocks(func, obj):
 
     actual = xr.map_blocks(func, obj)
     expected = func(obj)
+    assert_chunks_equal(expected, actual)
     xr.testing.assert_equal(expected, actual)
 
 
@@ -951,4 +970,31 @@ def test_map_blocks_args(obj):
     expected = obj + 10
     actual = xr.map_blocks(operator.add, obj, 10)
+    assert_chunks_equal(expected, actual)
     xr.testing.assert_equal(expected, actual)
+
+
+def da_to_ds(da):
+    return da.to_dataset()
+
+
+def ds_to_da(ds):
+    return ds.to_array()
+
+
+@pytest.mark.parametrize(
+    "func, obj, return_type",
+    [[da_to_ds, map_da, xr.Dataset], [ds_to_da, map_ds, xr.DataArray]],
+)
+def test_map_blocks_transformations(func, obj, return_type):
+    assert isinstance(xr.map_blocks(func, obj), return_type)
+
+
+# TODO: cases that still need tests:
+# func(DataArray) -> Dataset
+# func(Dataset) -> DataArray
+# func output contains fewer variables
+# func output contains new variables
+# func changes dtypes
+# func output contains fewer (or more) dimensions
+# *args, **kwargs are passed through
+# IndexVariables don't accidentally cause the whole graph to be computed (the logic in the main function is quite subtle!)
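
For reference, here is a minimal usage sketch of the map_blocks API introduced by this diff. It is not part of the patch: the example array and the add_ten function are illustrative, and it assumes only the behavior documented in the docstring above (func receives one in-memory chunk at a time and must return a DataArray or Dataset of unchanged size along existing dimensions).

    import numpy as np
    import xarray as xr

    # a small chunked DataArray, mirroring make_da() in the tests
    da = xr.DataArray(
        np.arange(100.0).reshape(10, 10),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(10)},
        name="a",
    ).chunk({"x": 4, "y": 5})

    def add_ten(block):
        # receives one in-memory chunk at a time; must return a DataArray
        # or Dataset whose sizes along existing dimensions are unchanged
        return block + 10

    # lazy: this only builds the dask graph constructed in parallel.py
    actual = xr.map_blocks(add_ten, da)

    # block-wise application of an elementwise function matches global
    # application, and the input chunking is preserved
    expected = add_ten(da)
    xr.testing.assert_equal(expected.compute(), actual.compute())
    assert actual.chunks == da.chunks

The final assertion mirrors assert_chunks_equal in the tests: preserving the input chunking is part of the contract the graph-construction code is meant to uphold.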