Skip to content

Commit

Permalink
review feedback:
Browse files Browse the repository at this point in the history
1. skip index graph nodes.
2. var → name
3. quicker dataarray creation.
4. Add restrictions to docstring.
5. rename chunk construction task.
6. error when non-xarray object is returned.
7. restore non-coord dims.

review
  • Loading branch information
dcherian committed Sep 19, 2019
1 parent adbe48e commit 924bf69
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 41 deletions.
97 changes: 57 additions & 40 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def make_meta(obj):
from dask.array.utils import meta_from_array

if isinstance(obj, DataArray):
meta = DataArray(obj.data._meta, dims=obj.dims)
meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name)

if isinstance(obj, Dataset):
meta = Dataset()
Expand All @@ -45,9 +45,14 @@ def make_meta(obj):
else:
meta_obj = meta_from_array(obj[name].data)
meta[name] = DataArray(meta_obj, dims=obj[name].dims)
# meta[name] = DataArray(obj[name].dims, meta_obj)
else:
meta = obj

# TODO: deal with non-dim coords
# for coord_name in (set(obj.coords) - set(obj.dims)): # DataArrays should have _coord_names!
# coord = obj[coord_name]

return meta


Expand All @@ -65,7 +70,7 @@ def infer_template(func, obj, *args, **kwargs):
return template


def _make_dict(x):
def make_dict(x):
# Dataset.to_dict() is too complicated
# maps variable name to numpy array
if isinstance(x, DataArray):
Expand Down Expand Up @@ -93,6 +98,9 @@ def map_blocks(func, obj, *args, **kwargs):
properties of the returned object such as dtype, variable names,
new dimensions and new indexes (if any).
This function must
- return either a DataArray or a Dataset
This function cannot
- change size of existing dimensions.
- add new chunked dimensions.
Expand All @@ -101,18 +109,24 @@ def map_blocks(func, obj, *args, **kwargs):
Chunks of this object will be provided to 'func'. The function must not change
shape of the provided DataArray.
args:
Passed on to func.
Passed on to func. Cannot include chunked xarray objects.
kwargs:
Passed on to func.
Passed on to func. Cannot include chunked xarray objects.
Returns
-------
DataArray or Dataset
Notes
-----
This function is designed to work with dask-backed xarray objects. See apply_ufunc for
a similar function that works with numpy arrays.
See Also
--------
dask.array.map_blocks
dask.array.map_blocks, xarray.apply_ufunc
"""

def _wrapper(func, obj, to_array, args, kwargs):
Expand All @@ -129,7 +143,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
% name
)

to_return = _make_dict(result)
to_return = make_dict(result)

return to_return

Expand All @@ -149,26 +163,30 @@ def _wrapper(func, obj, to_array, args, kwargs):
if isinstance(template, DataArray):
result_is_array = True
template = template._to_temp_dataset()
else:
elif isinstance(template, Dataset):
result_is_array = False
else:
raise ValueError(
"Function must return an xarray DataArray or Dataset. Instead it returned %r"
% type(template)
)

# If two different variables have different chunking along the same dim
# .chunks will raise an error.
input_chunks = dataset.chunks

indexes = dict(dataset.indexes)
for dim in template.indexes:
if dim not in indexes:
indexes[dim] = template.indexes[dim]
# TODO: add a test that fails when template and dataset are switched
indexes = dict(template.indexes)
indexes.update(dataset.indexes)

graph = {}
gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))

# map dims to list of chunk indexes
ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks}
ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}
# mapping from chunk index to slice bounds
chunk_index_bounds = {
dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks
dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
}

# iterate over all possible chunk combinations
Expand All @@ -185,17 +203,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
for name, variable in dataset.variables.items():
# make a task that creates tuple of (dims, chunk)
if dask.is_dask_collection(variable.data):
var_dask_keys = variable.__dask_keys__()

# recursively index into dask_keys nested list to get chunk
chunk = var_dask_keys
chunk = variable.__dask_keys__()
for dim in variable.dims:
chunk = chunk[chunk_index_dict[dim]]

task_name = ("tuple-" + dask.base.tokenize(chunk),) + v
graph[task_name] = (tuple, [variable.dims, chunk])
chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v
graph[chunk_variable_task] = (tuple, [variable.dims, chunk])
else:
# numpy array with possibly chunked dimensions
# non-dask array with possibly chunked dimensions
# index into variable appropriately
subsetter = dict()
for dim in variable.dims:
Expand All @@ -207,14 +223,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
)

subset = variable.isel(subsetter)
task_name = (name + dask.base.tokenize(subset),) + v
graph[task_name] = (tuple, [subset.dims, subset])
chunk_variable_task = (name + dask.base.tokenize(subset),) + v
graph[chunk_variable_task] = (tuple, [subset.dims, subset])

# this task creates dict mapping variable name to above tuple
if name in dataset.data_vars:
data_vars.append([name, task_name])
if name in dataset.coords:
coords.append([name, task_name])
if name in dataset._coord_names:
coords.append([name, chunk_variable_task])
else:
data_vars.append([name, chunk_variable_task])

from_wrapper = (gname,) + v
graph[from_wrapper] = (
Expand All @@ -229,14 +245,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
# mapping from variable name to dask graph key
var_key_map = {}
for name, variable in template.variables.items():
var_dims = variable.dims
if name in indexes:
continue
# cannot tokenize "name" because the hash of <this-array> is not invariant!
# This happens when the user function does not set a name on the returned DataArray
gname_l = "%s-%s" % (gname, name)
var_key_map[name] = gname_l

key = (gname_l,)
for dim in var_dims:
for dim in variable.dims:
if dim in chunk_index_dict:
key += (chunk_index_dict[dim],)
else:
Expand All @@ -248,26 +265,26 @@ def _wrapper(func, obj, to_array, args, kwargs):
graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset])

result = Dataset()
for var, key in var_key_map.items():
# indexes need to be known
# otherwise compute is called when DataArray is created
if var in indexes:
result[var] = indexes[var]
continue

dims = template[var].dims
# a quicker way to assign indexes?
# indexes need to be known
# otherwise compute is called when DataArray is created
for name in template.indexes:
result[name] = indexes[name]
for name, key in var_key_map.items():
dims = template[name].dims
var_chunks = []
for dim in dims:
if dim in input_chunks:
var_chunks.append(input_chunks[dim])
else:
if dim in indexes:
var_chunks.append((len(indexes[dim]),))
elif dim in indexes:
var_chunks.append((len(indexes[dim]),))

data = dask.array.Array(
graph, name=key, chunks=var_chunks, dtype=template[var].dtype
graph, name=key, chunks=var_chunks, dtype=template[name].dtype
)
result[var] = DataArray(data=data, dims=dims, name=var)
result[name] = (dims, data)

result = result.set_coords(template._coord_names)

if result_is_array:
result = _to_array(result)
Expand Down
48 changes: 47 additions & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def make_da():

def make_ds():
map_ds = xr.Dataset()
map_ds["a"] = map_da
map_ds["a"] = make_da()
map_ds["b"] = map_ds.a + 50
map_ds["c"] = map_ds.x + 20
map_ds = map_ds.chunk({"x": 4, "y": 5})
Expand All @@ -909,6 +909,18 @@ def make_ds():
map_ds = make_ds()


# DataArray.chunks is not a dict but Dataset.chunks is!
def assert_chunks_equal(a, b):

if isinstance(a, DataArray):
a = a._to_temp_dataset()

if isinstance(b, DataArray):
b = b._to_temp_dataset()

assert a.chunks == b.chunks


def simple_func(obj):
result = obj.x + 5 * obj.y
return result
Expand All @@ -933,6 +945,12 @@ def bad_func(darray):
with raises_regex(ValueError, "Length of the.* has changed."):
xr.map_blocks(bad_func, map_da).compute()

def returns_numpy(darray):
return (darray * darray.x + 5 * darray.y).values

with raises_regex(ValueError, "Function must return an xarray DataArray"):
xr.map_blocks(returns_numpy, map_da)


@pytest.mark.parametrize(
"func, obj",
Expand All @@ -942,6 +960,7 @@ def test_map_blocks(func, obj):

actual = xr.map_blocks(func, obj)
expected = func(obj)
assert_chunks_equal(expected, actual)
xr.testing.assert_equal(expected, actual)


Expand All @@ -951,4 +970,31 @@ def test_map_blocks_args(obj):

expected = obj + 10
actual = xr.map_blocks(operator.add, obj, 10)
assert_chunks_equal(expected, actual)
xr.testing.assert_equal(expected, actual)


def da_to_ds(da):
return da.to_dataset()


def ds_to_da(ds):
return ds.to_array()


@pytest.mark.parametrize(
"func, obj, return_type",
[[da_to_ds, map_da, xr.Dataset], [ds_to_da, map_ds, xr.DataArray]],
)
def map_blocks_transformations(func, obj, return_type):
assert isinstance(xr.map_blocks(func, obj), return_type)


# func(DataArray) -> Dataset
# func(Dataset) -> DataArray
# func output contains less variables
# func output contains new variables
# func changes dtypes
# func output contains less (or more) dimensions
# *args, **kwargs are passed through
# IndexVariables don't accidentally cause the whole graph to be computed (the logic you wrote in the main function is quite subtle!)

0 comments on commit 924bf69

Please sign in to comment.