From b090a9ed92988de6fd019550ccaf330d0517d306 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 2 Sep 2019 09:07:01 -0600 Subject: [PATCH 01/76] map_block attempt 2 --- xarray/__init__.py | 1 + xarray/core/parallel.py | 141 ++++++++++++++++++++++++++++++++++++++ xarray/tests/test_dask.py | 28 ++++++++ 3 files changed, 170 insertions(+) create mode 100644 xarray/core/parallel.py diff --git a/xarray/__init__.py b/xarray/__init__.py index a3df034f7c7..1023bd01b89 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -17,6 +17,7 @@ from .core.dataarray import DataArray from .core.merge import merge, MergeError from .core.options import set_options +from .core.parallel import map_blocks from .backends.api import ( open_dataset, diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py new file mode 100644 index 00000000000..2e0780a207b --- /dev/null +++ b/xarray/core/parallel.py @@ -0,0 +1,141 @@ +try: + import dask + import dask.array + from dask.highlevelgraph import HighLevelGraph + +except ImportError: + pass + +import itertools +import numpy as np + +from .dataarray import DataArray +from .dataset import Dataset + + +def map_blocks(func, obj, *args, **kwargs): + """ + Apply a function to each chunk of a DataArray or Dataset. + + Parameters + ---------- + func: callable + User-provided function that should accept DataArrays corresponding to one chunk. + obj: DataArray, Dataset + Chunks of this object will be provided to 'func'. The function must not change + shape of the provided DataArray. + args, kwargs: + Passed on to func. + + Returns + ------- + DataArray + + See Also + -------- + dask.array.map_blocks + """ + + def _wrapper(func, obj, to_array, args, kwargs): + if to_array: + # this should be easier + obj = obj.to_array().squeeze().drop("variable") + + result = func(obj, *args, **kwargs) + + if not isinstance(result, type(obj)): + raise ValueError("Result is not the same type as input.") + if result.shape != obj.shape: + raise ValueError("Result does not have the same shape as input.") + + return result + + # if not isinstance(obj, DataArray): + # raise ValueError("map_blocks can only be used with DataArrays at present.") + + if isinstance(obj, DataArray): + dataset = obj._to_temp_dataset() + to_array = True + else: + dataset = obj + to_array = False + + dataset_dims = list(dataset.dims) + + graph = {} + gname = "map-%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) + + # map dims to list of chunk indexes + # If two different variables have different chunking along the same dim + # .chunks will raise an error. + chunks = dataset.chunks + ichunk = {dim: range(len(chunks[dim])) for dim in chunks} + # mapping from chunk index to slice bounds + chunk_index_bounds = {dim: np.cumsum((0,) + chunks[dim]) for dim in chunks} + + # iterate over all possible chunk combinations + for v in itertools.product(*ichunk.values()): + chunk_index_dict = dict(zip(dataset_dims, v)) + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] 
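+        # e.g. (an assumed illustration: one dask-backed data variable "a"
+        # and one numpy-backed coordinate "x", at chunk index v == (0, 1)):
+        #   data_vars == [["a", ("tuple-<token>", 0, 1)]]
+        #   coords    == [["x", ("x<token>", 0, 1)]]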
+ # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + var_dask_keys = variable.__dask_keys__() + + # recursively index into dask_keys nested list to get chunk + chunk = var_dask_keys + for dim in variable.dims: + chunk = chunk[chunk_index_dict[dim]] + + task_name = ("tuple-" + dask.base.tokenize(chunk),) + v + graph[task_name] = (tuple, [variable.dims, chunk]) + else: + # numpy array with possibly chunked dimensions + # index into variable appropriately + subsetter = dict() + for dim in variable.dims: + if dim in chunk_index_dict: + which_chunk = chunk_index_dict[dim] + subsetter[dim] = slice( + chunk_index_bounds[dim][which_chunk], + chunk_index_bounds[dim][which_chunk + 1], + ) + + subset = variable.isel(subsetter) + task_name = (name + dask.base.tokenize(subset),) + v + graph[task_name] = (tuple, [subset.dims, subset]) + + # this task creates dict mapping variable name to above tuple + if name in dataset.data_vars: + data_vars.append([name, task_name]) + if name in dataset.coords: + coords.append([name, task_name]) + + graph[(gname,) + v] = ( + _wrapper, + func, + (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), + to_array, + args, + kwargs, + ) + + final_graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) + + if isinstance(obj, DataArray): + result = DataArray( + dask.array.Array( + final_graph, name=gname, chunks=obj.data.chunks, meta=obj.data._meta + ), + dims=obj.dims, + coords=obj.coords, + ) + + return result diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e3fc6f65e0f..9d4bb093456 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -878,3 +878,31 @@ def test_dask_layers_and_dependencies(): assert set(x.foo.__dask_graph__().dependencies).issuperset( ds.__dask_graph__().dependencies ) + + +def test_map_blocks(): + darray = xr.DataArray( + dask.array.ones((10, 20), chunks=[4, 5]), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(100, 120)}, + ) + darray.name = None + + def good_func(darray): + return darray * darray.x + 5 * darray.y + + def bad_func(darray): + return (darray * darray.x + 5 * darray.y)[:1, :1] + + actual = xr.map_blocks(good_func, darray) + expected = good_func(darray) + xr.testing.assert_equal(expected, actual) + + with raises_regex(ValueError, "not have the same shape"): + xr.map_blocks(bad_func, darray).compute() + + import operator + + expected = darray + 10 + actual = xr.map_blocks(operator.add, darray, 10) + xr.testing.assert_equal(expected, actual) From 39487985537f1e5301b53c22c2f54fdca61c9669 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 2 Sep 2019 17:11:53 -0600 Subject: [PATCH 02/76] Address reviews: errors, args + kwargs support. --- xarray/core/parallel.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2e0780a207b..7071ce445dc 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -13,7 +13,7 @@ from .dataset import Dataset -def map_blocks(func, obj, *args, **kwargs): +def map_blocks(func, obj, *args, dtype=None, **kwargs): """ Apply a function to each chunk of a DataArray or Dataset. @@ -24,8 +24,13 @@ def map_blocks(func, obj, *args, **kwargs): obj: DataArray, Dataset Chunks of this object will be provided to 'func'. 
The function must not change shape of the provided DataArray. - args, kwargs: + args: Passed on to func. + dtype: + dtype of the DataArray returned by func. + kwargs: + Passed on to func. + Returns ------- @@ -50,8 +55,21 @@ def _wrapper(func, obj, to_array, args, kwargs): return result - # if not isinstance(obj, DataArray): - # raise ValueError("map_blocks can only be used with DataArrays at present.") + if not isinstance(obj, DataArray): + raise ValueError("map_blocks can only be used with DataArrays at present.") + + if not dask.is_dask_collection(obj): + raise ValueError( + "map_blocks can only be used with dask-backed DataArrays. Use .chunk() to convert to a Dask array." + ) + + try: + meta_array = DataArray(obj.data._meta, dims=obj.dims) + result_meta = func(meta_array, *args, **kwargs) + if dtype is None: + dtype = result_meta.dtype + except ValueError: + raise ValueError("Cannot infer return type from user-provided function.") if isinstance(obj, DataArray): dataset = obj._to_temp_dataset() @@ -132,7 +150,7 @@ def _wrapper(func, obj, to_array, args, kwargs): if isinstance(obj, DataArray): result = DataArray( dask.array.Array( - final_graph, name=gname, chunks=obj.data.chunks, meta=obj.data._meta + final_graph, name=gname, chunks=obj.data.chunks, meta=result_meta ), dims=obj.dims, coords=obj.coords, From 4f159c8390d7b04e58b2350109e86d79efd28912 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 3 Sep 2019 07:19:17 -0600 Subject: [PATCH 03/76] Works with datasets! --- xarray/core/parallel.py | 171 +++++++++++++++++++++++++++++--------- xarray/tests/test_dask.py | 59 ++++++++++--- 2 files changed, 180 insertions(+), 50 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 7071ce445dc..a8ec5a61aa7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -6,14 +6,68 @@ except ImportError: pass -import itertools import numpy as np +import operator from .dataarray import DataArray from .dataset import Dataset -def map_blocks(func, obj, *args, dtype=None, **kwargs): +def _to_dataset(obj): + if obj.name is not None: + dataset = obj.to_dataset() + else: + dataset = obj._to_temp_dataset() + + return dataset + + +def _to_array(obj): + if not isinstance(obj, Dataset): + raise ValueError("Trying to convert DataArray to DataArray!") + + if len(obj.data_vars) > 1: + raise ValueError( + "Trying to convert Dataset with more than one variable to DataArray" + ) + + name = list(obj.data_vars)[0] + da = obj.to_array().squeeze().drop("variable") + da.name = name + return da + + +def make_meta(obj): + if isinstance(obj, DataArray): + meta = DataArray(obj.data._meta, dims=obj.dims) + + if isinstance(obj, Dataset): + meta = Dataset() + for name, variable in obj.variables.items(): + if dask.is_dask_collection(variable): + meta[name] = DataArray(obj[name].data._meta, dims=obj[name].dims) + else: + continue + + return meta + + +def _make_dict(x): + # Dataset.to_dict() is too complicated + # maps variable name to numpy array + if isinstance(x, DataArray): + x = _to_dataset(x) + + to_return = dict() + for var in x.variables: + # if var not in x: + # raise ValueError("Variable %r not found in returned object." % var) + to_return[var] = x[var].values + + return to_return + + +def map_blocks(func, obj, *args, template=None, **kwargs): """ Apply a function to each chunk of a DataArray or Dataset. @@ -26,67 +80,77 @@ def map_blocks(func, obj, *args, dtype=None, **kwargs): shape of the provided DataArray. args: Passed on to func. 
- dtype: - dtype of the DataArray returned by func. + template: + template object representing result kwargs: Passed on to func. Returns ------- - DataArray + DataArray or Dataset See Also -------- dask.array.map_blocks """ + import itertools + def _wrapper(func, obj, to_array, args, kwargs): if to_array: # this should be easier - obj = obj.to_array().squeeze().drop("variable") + obj = _to_array(obj) result = func(obj, *args, **kwargs) - if not isinstance(result, type(obj)): - raise ValueError("Result is not the same type as input.") - if result.shape != obj.shape: - raise ValueError("Result does not have the same shape as input.") + # if isinstance(result, DataArray): + # if result.shape != obj.shape: + # raise ValueError("Result does not have the same shape as input.") - return result + to_return = _make_dict(result) - if not isinstance(obj, DataArray): - raise ValueError("map_blocks can only be used with DataArrays at present.") + return to_return if not dask.is_dask_collection(obj): raise ValueError( "map_blocks can only be used with dask-backed DataArrays. Use .chunk() to convert to a Dask array." ) - try: - meta_array = DataArray(obj.data._meta, dims=obj.dims) - result_meta = func(meta_array, *args, **kwargs) - if dtype is None: - dtype = result_meta.dtype - except ValueError: - raise ValueError("Cannot infer return type from user-provided function.") - if isinstance(obj, DataArray): - dataset = obj._to_temp_dataset() - to_array = True + dataset = _to_dataset(obj) + input_is_array = True else: dataset = obj - to_array = False + input_is_array = False dataset_dims = list(dataset.dims) + # infer template / meta information here + if template is None: + try: + meta = make_meta(obj) + result_meta = func(meta, *args, **kwargs) + except ValueError: + raise ValueError("Cannot infer return type from user-provided function.") + + template = result_meta + + if isinstance(template, DataArray): + result_is_array = True + template = _to_dataset(template) + else: + result_is_array = False + + template_vars = list(template.variables) + graph = {} - gname = "map-%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) + gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) - # map dims to list of chunk indexes # If two different variables have different chunking along the same dim # .chunks will raise an error. 
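    # (An assumed illustration: for a 10x20 array chunked {"x": 4, "y": 5},
    # the three mappings built just below are
    #   chunks             == {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
    #   ichunk             == {"x": range(3), "y": range(4)}
    #   chunk_index_bounds == {"x": [0, 4, 8, 10], "y": [0, 5, 10, 15, 20]})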
chunks = dataset.chunks + # map dims to list of chunk indexes ichunk = {dim: range(len(chunks[dim])) for dim in chunks} # mapping from chunk index to slice bounds chunk_index_bounds = {dim: np.cumsum((0,) + chunks[dim]) for dim in chunks} @@ -136,24 +200,57 @@ def _wrapper(func, obj, to_array, args, kwargs): if name in dataset.coords: coords.append([name, task_name]) - graph[(gname,) + v] = ( + from_wrapper = (gname,) + v + graph[from_wrapper] = ( _wrapper, func, (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), - to_array, + input_is_array, args, kwargs, ) - final_graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) + # mapping from variable name to dask graph key + var_key_map = {} + for var in template_vars: + var_dims = template.variables[var].dims + gname_l = gname + dask.base.tokenize(var) + var_key_map[var] = gname_l - if isinstance(obj, DataArray): - result = DataArray( - dask.array.Array( - final_graph, name=gname, chunks=obj.data.chunks, meta=result_meta - ), - dims=obj.dims, - coords=obj.coords, - ) + key = (gname_l,) + for dim in var_dims: + if dim in chunk_index_dict: + key += (chunk_index_dict[dim],) + else: + # unchunked dimensions in the input have one chunk in the result + key += (0,) + + # this is a list [name, values, dims, attrs] + graph[key] = (operator.getitem, from_wrapper, var) + + graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) + + result = Dataset() + for var, key in var_key_map.items(): + # indexes need to be known + # otherwise compute is called when DataArray is created + if var in template.indexes: + result[var] = template[var] + continue + + name = var + dims = template[var].dims + chunks = [ + template.chunks[dim] if dim in template.chunks else (len(template[dim]),) + for dim in dims + ] + dtype = template[var].dtype + + data = dask.array.Array(graph, name=key, chunks=chunks, dtype=dtype) + result[name] = DataArray(data=data, dims=dims, name=name) + + result = Dataset(result) + if result_is_array: + result = _to_array(result) return result diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 9d4bb093456..ac577a5cefd 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -880,29 +880,62 @@ def test_dask_layers_and_dependencies(): ) -def test_map_blocks(): - darray = xr.DataArray( - dask.array.ones((10, 20), chunks=[4, 5]), +def make_da(): + return xr.DataArray( + np.ones((10, 20)), dims=["x", "y"], coords={"x": np.arange(10), "y": np.arange(100, 120)}, - ) - darray.name = None + name="a", + ).chunk({"x": 4, "y": 5}) + + +def make_ds(): + map_ds = xr.Dataset() + map_ds["a"] = map_da + map_ds["b"] = map_ds.a + 50 + map_ds["c"] = map_ds.x + 20 + map_ds = map_ds.chunk({"x": 4, "y": 5}) + map_ds["d"] = ("z", [1, 1, 1, 1]) + map_ds["z"] = [0, 1, 2, 3] + map_ds["e"] = map_ds.x + map_ds.y + map_ds.z + map_ds.attrs["test"] = "test" + + return map_ds + + +# work around mypy error +# xarray/tests/test_dask.py:888: error: Dict entry 0 has incompatible type "str": "int"; expected "Hashable": "Union[None, Number, Tuple[Number, ...]]" +map_da = make_da() +map_ds = make_ds() - def good_func(darray): - return darray * darray.x + 5 * darray.y +@pytest.mark.xfail(reason="Not implemented yet.") +def test_map_blocks_error(): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - actual = xr.map_blocks(good_func, darray) - expected = good_func(darray) + with raises_regex(ValueError, "not have the same shape"): + xr.map_blocks(bad_func, map_da).compute() + + 
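+# A usage sketch of the new Dataset path (assumed data, not an actual test):
+#
+#   ds = xr.Dataset({"a": (("x", "y"), np.ones((10, 20)))}).chunk({"x": 4})
+#   actual = xr.map_blocks(lambda block: block + 1, ds)
+#   xr.testing.assert_equal(actual.compute(), (ds + 1).compute())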
+@pytest.mark.parametrize("obj, template", [[map_da, map_da], [map_ds, map_ds.a]]) +def test_map_blocks(obj, template): + def good_func(obj): + result = obj.x + 5 * obj.y + # TODO: this needs to be fixed. + if isinstance(result, DataArray): + result.name = "a" + return result + + actual = xr.map_blocks(good_func, obj, template=template) + expected = good_func(obj) xr.testing.assert_equal(expected, actual) - with raises_regex(ValueError, "not have the same shape"): - xr.map_blocks(bad_func, darray).compute() +@pytest.mark.parametrize("obj", [map_da, map_ds]) +def test_map_blocks_args(obj): import operator - expected = darray + 10 - actual = xr.map_blocks(operator.add, darray, 10) + expected = obj + 10 + actual = xr.map_blocks(operator.add, obj, 10, template=obj) xr.testing.assert_equal(expected, actual) From 9179f0b653f6734784fdbc86b04056005edb8cf7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 7 Sep 2019 22:06:44 -0600 Subject: [PATCH 04/76] remove wrong comment. --- xarray/core/parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a8ec5a61aa7..0dea7a371b6 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -225,7 +225,6 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - # this is a list [name, values, dims, attrs] graph[key] = (operator.getitem, from_wrapper, var) graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) From 20c5d5b76e2fb467e47083435240cbd38a99fa1a Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 7 Sep 2019 23:32:17 -0600 Subject: [PATCH 05/76] Support chunks. --- xarray/core/parallel.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 0dea7a371b6..dc31571d8c3 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -67,9 +67,10 @@ def _make_dict(x): return to_return -def map_blocks(func, obj, *args, template=None, **kwargs): +def map_blocks(func, obj, *args, template=None, chunks=None, **kwargs): """ - Apply a function to each chunk of a DataArray or Dataset. + Apply a function to each chunk of a DataArray or Dataset. This function is experimental + and its signature may change. Parameters ---------- @@ -149,11 +150,13 @@ def _wrapper(func, obj, to_array, args, kwargs): # If two different variables have different chunking along the same dim # .chunks will raise an error. 
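    # (Assumed example of that error: if ds["a"] is chunked (4, 4, 2) along
    # "x" while ds["b"] is chunked (5, 5) along the same "x", accessing
    # ds.chunks raises ValueError("inconsistent chunks").)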
- chunks = dataset.chunks + input_chunks = dataset.chunks # map dims to list of chunk indexes - ichunk = {dim: range(len(chunks[dim])) for dim in chunks} + ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks} # mapping from chunk index to slice bounds - chunk_index_bounds = {dim: np.cumsum((0,) + chunks[dim]) for dim in chunks} + chunk_index_bounds = { + dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks + } # iterate over all possible chunk combinations for v in itertools.product(*ichunk.values()): @@ -229,6 +232,15 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) + if chunks is None: + if template is not None: + chunks = template.chunks + else: + chunks = input_chunks + for dim in chunks: + if dim not in template.dims: + chunks.pop(dim) + result = Dataset() for var, key in var_key_map.items(): # indexes need to be known @@ -239,13 +251,12 @@ def _wrapper(func, obj, to_array, args, kwargs): name = var dims = template[var].dims - chunks = [ - template.chunks[dim] if dim in template.chunks else (len(template[dim]),) - for dim in dims + var_chunks = [ + chunks[dim] if dim in chunks else (len(template[dim]),) for dim in dims ] dtype = template[var].dtype - data = dask.array.Array(graph, name=key, chunks=chunks, dtype=dtype) + data = dask.array.Array(graph, name=key, chunks=var_chunks, dtype=dtype) result[name] = DataArray(data=data, dims=dims, name=name) result = Dataset(result) From b16b237482285b80626f4b83dc365deca65b3607 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 00:25:21 -0600 Subject: [PATCH 06/76] infer template. --- xarray/core/parallel.py | 109 +++++++++++++++++++++++--------------- xarray/tests/test_dask.py | 34 ++++++++---- 2 files changed, 90 insertions(+), 53 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index dc31571d8c3..fe9d4dc9812 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -38,6 +38,9 @@ def _to_array(obj): def make_meta(obj): + + from dask.array.utils import meta_from_array + if isinstance(obj, DataArray): meta = DataArray(obj.data._meta, dims=obj.dims) @@ -45,13 +48,30 @@ def make_meta(obj): meta = Dataset() for name, variable in obj.variables.items(): if dask.is_dask_collection(variable): - meta[name] = DataArray(obj[name].data._meta, dims=obj[name].dims) + meta_obj = obj[name].data._meta else: - continue + meta_obj = meta_from_array(obj[name].data) + meta[name] = DataArray(meta_obj, dims=obj[name].dims) + else: + meta = obj return meta +def infer_template(func, obj, *args, **kwargs): + meta_args = [] + for arg in (obj,) + args: + meta_args.append(make_meta(arg)) + + # infer template / meta information here + try: + template = func(*meta_args, **kwargs) + except ValueError: + raise ValueError("Cannot infer return type from user-provided function.") + + return template + + def _make_dict(x): # Dataset.to_dict() is too complicated # maps variable name to numpy array @@ -67,7 +87,7 @@ def _make_dict(x): return to_return -def map_blocks(func, obj, *args, template=None, chunks=None, **kwargs): +def map_blocks(func, obj, *args, chunks=None, **kwargs): """ Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. @@ -76,13 +96,21 @@ def map_blocks(func, obj, *args, template=None, chunks=None, **kwargs): ---------- func: callable User-provided function that should accept DataArrays corresponding to one chunk. 
+ The function will be run on a small piece of data that looks like 'obj' to determine + properties of the returned object such as dtype, variable names, + new dimensions and new indexes (if any). + + This function cannot + - change size of existing dimensions. + - add new chunked dimensions. + obj: DataArray, Dataset Chunks of this object will be provided to 'func'. The function must not change shape of the provided DataArray. args: Passed on to func. - template: - template object representing result + chunks: + dict mapping index name to chunk size. kwargs: Passed on to func. @@ -127,30 +155,32 @@ def _wrapper(func, obj, to_array, args, kwargs): dataset_dims = list(dataset.dims) - # infer template / meta information here - if template is None: - try: - meta = make_meta(obj) - result_meta = func(meta, *args, **kwargs) - except ValueError: - raise ValueError("Cannot infer return type from user-provided function.") - - template = result_meta - + template = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template = _to_dataset(template) else: result_is_array = False - template_vars = list(template.variables) + # If two different variables have different chunking along the same dim + # .chunks will raise an error. + input_chunks = dataset.chunks + # get output chunks + if chunks is None: + # assume chunking doesn't change + chunks = input_chunks + for dim in chunks: + if dim not in template.dims: + chunks.pop(dim) + + indexes = dict(dataset.indexes) + for dim in template.indexes: + if dim not in indexes: + indexes[dim] = template.indexes[dim] graph = {} gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) - # If two different variables have different chunking along the same dim - # .chunks will raise an error. 
- input_chunks = dataset.chunks # map dims to list of chunk indexes ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks} # mapping from chunk index to slice bounds @@ -215,10 +245,10 @@ def _wrapper(func, obj, to_array, args, kwargs): # mapping from variable name to dask graph key var_key_map = {} - for var in template_vars: - var_dims = template.variables[var].dims - gname_l = gname + dask.base.tokenize(var) - var_key_map[var] = gname_l + for name, variable in template.variables.items(): + var_dims = variable.dims + gname_l = gname + dask.base.tokenize(name) + var_key_map[name] = gname_l key = (gname_l,) for dim in var_dims: @@ -228,36 +258,31 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - graph[key] = (operator.getitem, from_wrapper, var) + graph[key] = (operator.getitem, from_wrapper, name) graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) - if chunks is None: - if template is not None: - chunks = template.chunks - else: - chunks = input_chunks - for dim in chunks: - if dim not in template.dims: - chunks.pop(dim) - result = Dataset() for var, key in var_key_map.items(): # indexes need to be known # otherwise compute is called when DataArray is created - if var in template.indexes: - result[var] = template[var] + if var in indexes: + result[var] = indexes[var] continue - name = var dims = template[var].dims - var_chunks = [ - chunks[dim] if dim in chunks else (len(template[dim]),) for dim in dims - ] - dtype = template[var].dtype + var_chunks = [] + for dim in dims: + if dim in chunks: + var_chunks.append(chunks[dim]) + else: + if dim in indexes: + var_chunks.append((len(indexes[dim]),)) - data = dask.array.Array(graph, name=key, chunks=var_chunks, dtype=dtype) - result[name] = DataArray(data=data, dims=dims, name=name) + data = dask.array.Array( + graph, name=key, chunks=var_chunks, dtype=template[var].dtype + ) + result[var] = DataArray(data=data, dims=dims, name=var) result = Dataset(result) if result_is_array: diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index ac577a5cefd..245ffb257bc 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -909,6 +909,21 @@ def make_ds(): map_ds = make_ds() +def simple_func(obj): + result = obj.x + 5 * obj.y + # TODO: this needs to be fixed. + if isinstance(result, DataArray): + result.name = "a" + return result + + +def complicated_func(obj): + new = obj.copy() + new = new[["a", "b"]].rename({"a": "new_var1"}).expand_dims(k=[0, 1, 2]) + new["b"] = new.b.astype("int32") + return new + + @pytest.mark.xfail(reason="Not implemented yet.") def test_map_blocks_error(): def bad_func(darray): @@ -918,17 +933,14 @@ def bad_func(darray): xr.map_blocks(bad_func, map_da).compute() -@pytest.mark.parametrize("obj, template", [[map_da, map_da], [map_ds, map_ds.a]]) -def test_map_blocks(obj, template): - def good_func(obj): - result = obj.x + 5 * obj.y - # TODO: this needs to be fixed. 
- if isinstance(result, DataArray): - result.name = "a" - return result +@pytest.mark.parametrize( + "func, obj", + [[simple_func, map_da], [simple_func, map_ds], [complicated_func, map_ds]], +) +def test_map_blocks(func, obj): - actual = xr.map_blocks(good_func, obj, template=template) - expected = good_func(obj) + actual = xr.map_blocks(func, obj) + expected = func(obj) xr.testing.assert_equal(expected, actual) @@ -937,5 +949,5 @@ def test_map_blocks_args(obj): import operator expected = obj + 10 - actual = xr.map_blocks(operator.add, obj, 10, template=obj) + actual = xr.map_blocks(operator.add, obj, 10) xr.testing.assert_equal(expected, actual) From 43ef2b723b4c7973519bf66bad639606389b1ffd Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 00:25:31 -0600 Subject: [PATCH 07/76] cleanup --- xarray/core/parallel.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fe9d4dc9812..b10deea90aa 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -6,6 +6,7 @@ except ImportError: pass +import itertools import numpy as np import operator @@ -124,8 +125,6 @@ def map_blocks(func, obj, *args, chunks=None, **kwargs): dask.array.map_blocks """ - import itertools - def _wrapper(func, obj, to_array, args, kwargs): if to_array: # this should be easier @@ -153,8 +152,6 @@ def _wrapper(func, obj, to_array, args, kwargs): dataset = obj input_is_array = False - dataset_dims = list(dataset.dims) - template = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True @@ -190,7 +187,7 @@ def _wrapper(func, obj, to_array, args, kwargs): # iterate over all possible chunk combinations for v in itertools.product(*ichunk.values()): - chunk_index_dict = dict(zip(dataset_dims, v)) + chunk_index_dict = dict(zip(dataset.dims, v)) # this will become [[name1, variable1], # [name2, variable2], @@ -284,7 +281,6 @@ def _wrapper(func, obj, to_array, args, kwargs): ) result[var] = DataArray(data=data, dims=dims, name=var) - result = Dataset(result) if result_is_array: result = _to_array(result) From 5ebf7387500b826e33558c96854fd08c6528fec6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 00:27:53 -0600 Subject: [PATCH 08/76] cleanup2 --- xarray/core/parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index b10deea90aa..f80f960b0a9 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -64,11 +64,10 @@ def infer_template(func, obj, *args, **kwargs): for arg in (obj,) + args: meta_args.append(make_meta(arg)) - # infer template / meta information here try: template = func(*meta_args, **kwargs) except ValueError: - raise ValueError("Cannot infer return type from user-provided function.") + raise ValueError("Cannot infer object returned by user-provided function.") return template From 8a460bbf619b5d8b8120682a9ec3fc55717d428c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 00:34:25 -0600 Subject: [PATCH 09/76] api.rst --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 872e7786e1b..610977b0798 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -30,6 +30,7 @@ Top-level functions zeros_like ones_like dot + map_blocks Dataset ======= From 505f3f05f059b56b4657d6199276a07e53b8762e Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 08:52:18 -0600 Subject: [PATCH 10/76] simple shape change error check. 
--- xarray/core/parallel.py | 11 ++++++++--- xarray/tests/test_dask.py | 3 +-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f80f960b0a9..036ca8be3d7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -60,6 +60,7 @@ def make_meta(obj): def infer_template(func, obj, *args, **kwargs): + """ Infer return object by running the function on meta objects. """ meta_args = [] for arg in (obj,) + args: meta_args.append(make_meta(arg)) @@ -131,9 +132,13 @@ def _wrapper(func, obj, to_array, args, kwargs): result = func(obj, *args, **kwargs) - # if isinstance(result, DataArray): - # if result.shape != obj.shape: - # raise ValueError("Result does not have the same shape as input.") + for name, index in result.indexes.items(): + if name in obj.indexes: + if len(index) != len(obj.indexes[name]): + raise ValueError( + "Length of the %r dimension has changed. This is not allowed." + % name + ) to_return = _make_dict(result) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 245ffb257bc..c5a3d0ddaa7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -924,12 +924,11 @@ def complicated_func(obj): return new -@pytest.mark.xfail(reason="Not implemented yet.") def test_map_blocks_error(): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "not have the same shape"): + with raises_regex(ValueError, "Length of the.* has changed."): xr.map_blocks(bad_func, map_da).compute() From fe1982f26ef2a958d10229294c06f5c332d19fc9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 08:53:20 -0600 Subject: [PATCH 11/76] Make test more complicated. --- xarray/tests/test_dask.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c5a3d0ddaa7..b7b55a7f7f5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -919,7 +919,12 @@ def simple_func(obj): def complicated_func(obj): new = obj.copy() - new = new[["a", "b"]].rename({"a": "new_var1"}).expand_dims(k=[0, 1, 2]) + new = ( + new[["a", "b"]] + .rename({"a": "new_var1"}) + .expand_dims(k=[0, 1, 2]) + .transpose("k", "y", "x") + ) new["b"] = new.b.astype("int32") return new From 066eb598ed278b99efeffa1d0b558b2a59ae67bf Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 20:43:18 -0600 Subject: [PATCH 12/76] Fix for when user function doesn't set DataArray.name --- xarray/core/parallel.py | 6 ++++-- xarray/tests/test_dask.py | 3 --- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 036ca8be3d7..189bc22e8d3 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -33,6 +33,7 @@ def _to_array(obj): ) name = list(obj.data_vars)[0] + # this should be easier da = obj.to_array().squeeze().drop("variable") da.name = name return da @@ -127,7 +128,6 @@ def map_blocks(func, obj, *args, chunks=None, **kwargs): def _wrapper(func, obj, to_array, args, kwargs): if to_array: - # this should be easier obj = _to_array(obj) result = func(obj, *args, **kwargs) @@ -248,7 +248,9 @@ def _wrapper(func, obj, to_array, args, kwargs): var_key_map = {} for name, variable in template.variables.items(): var_dims = variable.dims - gname_l = gname + dask.base.tokenize(name) + # cannot tokenize "name" because the hash of is not invariant! 
+ # This happens when the user function does not set a name on the returned DataArray + gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l key = (gname_l,) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index b7b55a7f7f5..a6d25d8cde7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -911,9 +911,6 @@ def make_ds(): def simple_func(obj): result = obj.x + 5 * obj.y - # TODO: this needs to be fixed. - if isinstance(result, DataArray): - result.name = "a" return result From 83eb3102c80b1cbfc8484c4df16b57141b6de45a Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 20:50:20 -0600 Subject: [PATCH 13/76] Now _to_temp_dataset works. --- xarray/core/parallel.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 189bc22e8d3..be332d6e415 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -14,15 +14,6 @@ from .dataset import Dataset -def _to_dataset(obj): - if obj.name is not None: - dataset = obj.to_dataset() - else: - dataset = obj._to_temp_dataset() - - return dataset - - def _to_array(obj): if not isinstance(obj, Dataset): raise ValueError("Trying to convert DataArray to DataArray!") @@ -78,7 +69,7 @@ def _make_dict(x): # Dataset.to_dict() is too complicated # maps variable name to numpy array if isinstance(x, DataArray): - x = _to_dataset(x) + x = x._to_temp_dataset() to_return = dict() for var in x.variables: @@ -150,7 +141,7 @@ def _wrapper(func, obj, to_array, args, kwargs): ) if isinstance(obj, DataArray): - dataset = _to_dataset(obj) + dataset = obj._to_temp_dataset() input_is_array = True else: dataset = obj @@ -159,7 +150,7 @@ def _wrapper(func, obj, to_array, args, kwargs): template = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True - template = _to_dataset(template) + template = template._to_temp_dataset() else: result_is_array = False From 008ce29425fa2b90b465ada57c07649213bad07a Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 20:53:54 -0600 Subject: [PATCH 14/76] Add whats-new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8796c79da4c..661da88b851 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -61,6 +61,9 @@ New functions/methods This requires `sparse>=0.8.0`. By `Nezar Abdennur `_ and `Guido Imperiale `_. +- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks` + By `Deepak Cherian `_. + - :py:meth:`~Dataset.from_dataframe` and :py:meth:`~DataArray.from_series` now support ``sparse=True`` for converting pandas objects into xarray objects wrapping sparse arrays. This is particularly useful with sparsely populated From adbe48e01efe5bf041897d84563f30f925572393 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 8 Sep 2019 21:06:57 -0600 Subject: [PATCH 15/76] chunks kwarg makes no sense right now. --- xarray/core/parallel.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index be332d6e415..8f1ac2d93ce 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -80,7 +80,7 @@ def _make_dict(x): return to_return -def map_blocks(func, obj, *args, chunks=None, **kwargs): +def map_blocks(func, obj, *args, **kwargs): """ Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. 
@@ -102,8 +102,6 @@ def map_blocks(func, obj, *args, chunks=None, **kwargs): shape of the provided DataArray. args: Passed on to func. - chunks: - dict mapping index name to chunk size. kwargs: Passed on to func. @@ -157,13 +155,6 @@ def _wrapper(func, obj, to_array, args, kwargs): # If two different variables have different chunking along the same dim # .chunks will raise an error. input_chunks = dataset.chunks - # get output chunks - if chunks is None: - # assume chunking doesn't change - chunks = input_chunks - for dim in chunks: - if dim not in template.dims: - chunks.pop(dim) indexes = dict(dataset.indexes) for dim in template.indexes: @@ -267,8 +258,8 @@ def _wrapper(func, obj, to_array, args, kwargs): dims = template[var].dims var_chunks = [] for dim in dims: - if dim in chunks: - var_chunks.append(chunks[dim]) + if dim in input_chunks: + var_chunks.append(input_chunks[dim]) else: if dim in indexes: var_chunks.append((len(indexes[dim]),)) From 924bf692f037ab2e51fb09c4804dde990ea77323 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 08:51:08 -0600 Subject: [PATCH 16/76] review feedback: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. skip index graph nodes. 2. var → name 3. quicker dataarray creation. 4. Add restrictions to docstring. 5. rename chunk construction task. 6. error when non-xarray object is returned. 7. restore non-coord dims. review --- xarray/core/parallel.py | 97 +++++++++++++++++++++++---------------- xarray/tests/test_dask.py | 48 ++++++++++++++++++- 2 files changed, 104 insertions(+), 41 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8f1ac2d93ce..0590bcdbb77 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -35,7 +35,7 @@ def make_meta(obj): from dask.array.utils import meta_from_array if isinstance(obj, DataArray): - meta = DataArray(obj.data._meta, dims=obj.dims) + meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name) if isinstance(obj, Dataset): meta = Dataset() @@ -45,9 +45,14 @@ def make_meta(obj): else: meta_obj = meta_from_array(obj[name].data) meta[name] = DataArray(meta_obj, dims=obj[name].dims) + # meta[name] = DataArray(obj[name].dims, meta_obj) else: meta = obj + # TODO: deal with non-dim coords + # for coord_name in (set(obj.coords) - set(obj.dims)): # DataArrays should have _coord_names! + # coord = obj[coord_name] + return meta @@ -65,7 +70,7 @@ def infer_template(func, obj, *args, **kwargs): return template -def _make_dict(x): +def make_dict(x): # Dataset.to_dict() is too complicated # maps variable name to numpy array if isinstance(x, DataArray): @@ -93,6 +98,9 @@ def map_blocks(func, obj, *args, **kwargs): properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any). + This function must + - return either a DataArray or a Dataset + This function cannot - change size of existing dimensions. - add new chunked dimensions. @@ -101,18 +109,24 @@ def map_blocks(func, obj, *args, **kwargs): Chunks of this object will be provided to 'func'. The function must not change shape of the provided DataArray. args: - Passed on to func. + Passed on to func. Cannot include chunked xarray objects. kwargs: - Passed on to func. + Passed on to func. Cannot include chunked xarray objects. Returns ------- DataArray or Dataset + Notes + ----- + + This function is designed to work with dask-backed xarray objects. See apply_ufunc for + a similar function that works with numpy arrays. 
+ See Also -------- - dask.array.map_blocks + dask.array.map_blocks, xarray.apply_ufunc """ def _wrapper(func, obj, to_array, args, kwargs): @@ -129,7 +143,7 @@ def _wrapper(func, obj, to_array, args, kwargs): % name ) - to_return = _make_dict(result) + to_return = make_dict(result) return to_return @@ -149,26 +163,30 @@ def _wrapper(func, obj, to_array, args, kwargs): if isinstance(template, DataArray): result_is_array = True template = template._to_temp_dataset() - else: + elif isinstance(template, Dataset): result_is_array = False + else: + raise ValueError( + "Function must return an xarray DataArray or Dataset. Instead it returned %r" + % type(template) + ) # If two different variables have different chunking along the same dim # .chunks will raise an error. input_chunks = dataset.chunks - indexes = dict(dataset.indexes) - for dim in template.indexes: - if dim not in indexes: - indexes[dim] = template.indexes[dim] + # TODO: add a test that fails when template and dataset are switched + indexes = dict(template.indexes) + indexes.update(dataset.indexes) graph = {} gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) # map dims to list of chunk indexes - ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks} + ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} # mapping from chunk index to slice bounds chunk_index_bounds = { - dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() } # iterate over all possible chunk combinations @@ -185,17 +203,15 @@ def _wrapper(func, obj, to_array, args, kwargs): for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): - var_dask_keys = variable.__dask_keys__() - # recursively index into dask_keys nested list to get chunk - chunk = var_dask_keys + chunk = variable.__dask_keys__() for dim in variable.dims: chunk = chunk[chunk_index_dict[dim]] - task_name = ("tuple-" + dask.base.tokenize(chunk),) + v - graph[task_name] = (tuple, [variable.dims, chunk]) + chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v + graph[chunk_variable_task] = (tuple, [variable.dims, chunk]) else: - # numpy array with possibly chunked dimensions + # non-dask array with possibly chunked dimensions # index into variable appropriately subsetter = dict() for dim in variable.dims: @@ -207,14 +223,14 @@ def _wrapper(func, obj, to_array, args, kwargs): ) subset = variable.isel(subsetter) - task_name = (name + dask.base.tokenize(subset),) + v - graph[task_name] = (tuple, [subset.dims, subset]) + chunk_variable_task = (name + dask.base.tokenize(subset),) + v + graph[chunk_variable_task] = (tuple, [subset.dims, subset]) # this task creates dict mapping variable name to above tuple - if name in dataset.data_vars: - data_vars.append([name, task_name]) - if name in dataset.coords: - coords.append([name, task_name]) + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) from_wrapper = (gname,) + v graph[from_wrapper] = ( @@ -229,14 +245,15 @@ def _wrapper(func, obj, to_array, args, kwargs): # mapping from variable name to dask graph key var_key_map = {} for name, variable in template.variables.items(): - var_dims = variable.dims + if name in indexes: + continue # cannot tokenize "name" because the hash of is not invariant! 
# This happens when the user function does not set a name on the returned DataArray gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l key = (gname_l,) - for dim in var_dims: + for dim in variable.dims: if dim in chunk_index_dict: key += (chunk_index_dict[dim],) else: @@ -248,26 +265,26 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) result = Dataset() - for var, key in var_key_map.items(): - # indexes need to be known - # otherwise compute is called when DataArray is created - if var in indexes: - result[var] = indexes[var] - continue - - dims = template[var].dims + # a quicker way to assign indexes? + # indexes need to be known + # otherwise compute is called when DataArray is created + for name in template.indexes: + result[name] = indexes[name] + for name, key in var_key_map.items(): + dims = template[name].dims var_chunks = [] for dim in dims: if dim in input_chunks: var_chunks.append(input_chunks[dim]) - else: - if dim in indexes: - var_chunks.append((len(indexes[dim]),)) + elif dim in indexes: + var_chunks.append((len(indexes[dim]),)) data = dask.array.Array( - graph, name=key, chunks=var_chunks, dtype=template[var].dtype + graph, name=key, chunks=var_chunks, dtype=template[name].dtype ) - result[var] = DataArray(data=data, dims=dims, name=var) + result[name] = (dims, data) + + result = result.set_coords(template._coord_names) if result_is_array: result = _to_array(result) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a6d25d8cde7..8925f9c3ca7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -891,7 +891,7 @@ def make_da(): def make_ds(): map_ds = xr.Dataset() - map_ds["a"] = map_da + map_ds["a"] = make_da() map_ds["b"] = map_ds.a + 50 map_ds["c"] = map_ds.x + 20 map_ds = map_ds.chunk({"x": 4, "y": 5}) @@ -909,6 +909,18 @@ def make_ds(): map_ds = make_ds() +# DataArray.chunks is not a dict but Dataset.chunks is! 
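+# (Concretely, assuming a 10x20 array chunked {"x": 4, "y": 5}:
+#   DataArray.chunks -> ((4, 4, 2), (5, 5, 5, 5))            # tuple of tuples
+#   Dataset.chunks   -> {"x": (4, 4, 2), "y": (5, 5, 5, 5)}  # dim -> chunks)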
+def assert_chunks_equal(a, b): + + if isinstance(a, DataArray): + a = a._to_temp_dataset() + + if isinstance(b, DataArray): + b = b._to_temp_dataset() + + assert a.chunks == b.chunks + + def simple_func(obj): result = obj.x + 5 * obj.y return result @@ -933,6 +945,12 @@ def bad_func(darray): with raises_regex(ValueError, "Length of the.* has changed."): xr.map_blocks(bad_func, map_da).compute() + def returns_numpy(darray): + return (darray * darray.x + 5 * darray.y).values + + with raises_regex(ValueError, "Function must return an xarray DataArray"): + xr.map_blocks(returns_numpy, map_da) + @pytest.mark.parametrize( "func, obj", @@ -942,6 +960,7 @@ def test_map_blocks(func, obj): actual = xr.map_blocks(func, obj) expected = func(obj) + assert_chunks_equal(expected, actual) xr.testing.assert_equal(expected, actual) @@ -951,4 +970,31 @@ def test_map_blocks_args(obj): expected = obj + 10 actual = xr.map_blocks(operator.add, obj, 10) + assert_chunks_equal(expected, actual) xr.testing.assert_equal(expected, actual) + + +def da_to_ds(da): + return da.to_dataset() + + +def ds_to_da(ds): + return ds.to_array() + + +@pytest.mark.parametrize( + "func, obj, return_type", + [[da_to_ds, map_da, xr.Dataset], [ds_to_da, map_ds, xr.DataArray]], +) +def map_blocks_transformations(func, obj, return_type): + assert isinstance(xr.map_blocks(func, obj), return_type) + + +# func(DataArray) -> Dataset +# func(Dataset) -> DataArray +# func output contains less variables +# func output contains new variables +# func changes dtypes +# func output contains less (or more) dimensions +# *args, **kwargs are passed through +# IndexVariables don't accidentally cause the whole graph to be computed (the logic you wrote in the main function is quite subtle!) From 8aed8e74f57feec0d5120a808481fd3678c12f90 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 14:58:47 -0600 Subject: [PATCH 17/76] Support nondim coords in make_meta. --- xarray/core/parallel.py | 12 +++++++----- xarray/tests/test_dask.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 0590bcdbb77..3ea5288e635 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -37,7 +37,7 @@ def make_meta(obj): if isinstance(obj, DataArray): meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name) - if isinstance(obj, Dataset): + elif isinstance(obj, Dataset): meta = Dataset() for name, variable in obj.variables.items(): if dask.is_dask_collection(variable): @@ -49,9 +49,11 @@ def make_meta(obj): else: meta = obj - # TODO: deal with non-dim coords - # for coord_name in (set(obj.coords) - set(obj.dims)): # DataArrays should have _coord_names! - # coord = obj[coord_name] + # TODO: DataArrays should have _coord_names! + + if isinstance(obj, (DataArray, Dataset)): + for coord_name in set(obj.coords) - set(obj.dims): + meta = meta.set_coords(coord_name) return meta @@ -268,7 +270,7 @@ def _wrapper(func, obj, to_array, args, kwargs): # a quicker way to assign indexes? 
# indexes need to be known # otherwise compute is called when DataArray is created - for name in template.indexes: + for name in template.dims: result[name] = indexes[name] for name, key in var_key_map.items(): dims = template[name].dims diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8925f9c3ca7..f5a194ab4b5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -990,6 +990,20 @@ def map_blocks_transformations(func, obj, return_type): assert isinstance(xr.map_blocks(func, obj), return_type) +def test_make_meta(): + from ..core.parallel import make_meta + + meta = make_meta(map_ds) + + for variable in map_ds._coord_names: + assert variable in meta._coord_names + assert meta.coords[variable].shape == (0,) * meta.coords[variable].ndim + + for variable in map_ds.data_vars: + assert variable in meta.data_vars + assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim + + # func(DataArray) -> Dataset # func(Dataset) -> DataArray # func output contains less variables From d0797f6bc9086828c664a7193f4281864b8d7cb7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 14:56:46 -0600 Subject: [PATCH 18/76] Add Dataset.unify_chunks --- xarray/core/dataarray.py | 4 +++ xarray/core/dataset.py | 56 ++++++++++++++++++++++++++++++++++++++- xarray/core/parallel.py | 3 ++- xarray/tests/test_dask.py | 33 ++++++++++++++++++++--- 4 files changed, 90 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e5d53b1943a..de95d2d0fe4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2981,6 +2981,10 @@ def integrate( ds = self._to_temp_dataset().integrate(dim, datetime_unit) return self._from_temp_dataset(ds) + def unify_chunks(self): + ds = self.copy()._to_temp_dataset().unify_chunks() + return self._from_temp_dataset(ds) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names # noqa str = property(StringAccessor) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f3ad4650b38..2485e9d0c2c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1669,7 +1669,10 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: if v.chunks is not None: for dim, c in zip(v.dims, v.chunks): if dim in chunks and c != chunks[dim]: - raise ValueError("inconsistent chunks") + raise ValueError( + "Object has inconsistent chunks along dimension %r. This can be fixed by calling unify_chunks()." + % dim + ) chunks[dim] = c return Frozen(SortedKeysDict(chunks)) @@ -4997,5 +5000,56 @@ def filter_by_attrs(self, **kwargs): selection.append(var_name) return self[selection] + def unify_chunks(self): + """ Unifies chunksize along all chunked dimensions of this Dataset. 
+ + Returns + ------- + + Dataset with consistent chunk sizes for all dask-array variables + + See Also + -------- + + dask.array.core.unify_chunks + """ + import dask.array + + ds = self.copy() + + alphabet = "abcdefghijklmnopqrstuvwxyz" # this is stupid :) + alphamap = dict(zip(ds.dims, alphabet)) + + dask_arrays = {} + for name, variable in ds.variables.items(): + if isinstance(variable.data, dask.array.Array): + index_string = "".join([alphamap[dim] for dim in variable.dims]) + dask_arrays[name] = (variable.data, index_string) + + args = [] + for value in dask_arrays.values(): + args.append(value[0]) + args.append(value[1]) + + new_chunks, rechunked_arrays = dask.array.core.unify_chunks(*args) + + # invert the map + for dim, alpha in alphamap.items(): + if alpha not in new_chunks: + continue + new_chunks[dim] = new_chunks[alpha] + del new_chunks[alpha] + + try: + if new_chunks == ds.chunks: + return ds + except ValueError: # "inconsistent chunks" + pass + + for name, new_array in zip(dask_arrays.keys(), rechunked_arrays): + ds.variables[name]._data = new_array + + return ds + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3ea5288e635..441668215de 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -174,7 +174,8 @@ def _wrapper(func, obj, to_array, args, kwargs): ) # If two different variables have different chunking along the same dim - # .chunks will raise an error. + # fix that by "unifying chunks" + dataset = dataset.unify_chunks() input_chunks = dataset.chunks # TODO: add a test that fails when template and dataset are switched diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f5a194ab4b5..f9f8613edea 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -899,6 +899,7 @@ def make_ds(): map_ds["z"] = [0, 1, 2, 3] map_ds["e"] = map_ds.x + map_ds.y + map_ds.z map_ds.attrs["test"] = "test" + map_ds["xx"] = (map_ds["a"] * map_ds.y).chunk({"y": 20}) return map_ds @@ -909,20 +910,44 @@ def make_ds(): map_ds = make_ds() -# DataArray.chunks is not a dict but Dataset.chunks is! +def test_unify_chunks(): + with raises_regex(ValueError, "inconsistent chunks"): + map_ds.chunks + + expected = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} + actual = map_ds.unify_chunks().chunks + assert expected == actual + + +# TODO: DataArray.chunks is not a dict but Dataset.chunks is! def assert_chunks_equal(a, b): + # unchunked dimensions in input to map_blocks become 1 chunk. 
+ # do this manually before comparing + def at_least_one_chunk(obj): + new_chunks = {} + for dim in obj.dims: + if dim not in obj.chunks: + new_chunks[dim] = len(obj[dim]) + else: + # must preserve old chunks too + new_chunks[dim] = obj.chunks[dim][0] + return obj.chunk(new_chunks) + if isinstance(a, DataArray): a = a._to_temp_dataset() if isinstance(b, DataArray): b = b._to_temp_dataset() + a = at_least_one_chunk(a.unify_chunks()) + b = at_least_one_chunk(b.unify_chunks()) + assert a.chunks == b.chunks def simple_func(obj): - result = obj.x + 5 * obj.y + result = obj + obj.x + 5 * obj.y return result @@ -959,7 +984,7 @@ def returns_numpy(darray): def test_map_blocks(func, obj): actual = xr.map_blocks(func, obj) - expected = func(obj) + expected = func(obj).unify_chunks() assert_chunks_equal(expected, actual) xr.testing.assert_equal(expected, actual) @@ -968,7 +993,7 @@ def test_map_blocks(func, obj): def test_map_blocks_args(obj): import operator - expected = obj + 10 + expected = obj.unify_chunks() + 10 actual = xr.map_blocks(operator.add, obj, 10) assert_chunks_equal(expected, actual) xr.testing.assert_equal(expected, actual) From 765ca5d44219e6efb831d613660d9cc36c385289 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:15:04 -0600 Subject: [PATCH 19/76] doc updates. --- doc/api.rst | 2 ++ doc/whats-new.rst | 10 +++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 4940f890dd9..bcd83fa4075 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -500,6 +500,7 @@ Dataset methods Dataset.persist Dataset.load Dataset.chunk + Dataset.unify_chunks Dataset.filter_by_attrs Dataset.info @@ -530,6 +531,7 @@ DataArray methods DataArray.persist DataArray.load DataArray.chunk + DataArray.unify_chunks GroupBy objects =============== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4670eeca157..1268f7f3982 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -18,6 +18,13 @@ What's New v0.13.1 (unreleased) -------------------- + +New functions/methods +~~~~~~~~~~~~~~~~~~~~~ + +- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks` + By `Deepak Cherian `_. + .. _whats-new.0.13.0: v0.13.0 (17 Sep 2019) @@ -86,9 +93,6 @@ New functions/methods This requires `sparse>=0.8.0`. By `Nezar Abdennur `_ and `Guido Imperiale `_. -- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks` - By `Deepak Cherian `_. - - :py:meth:`~Dataset.from_dataframe` and :py:meth:`~DataArray.from_series` now support ``sparse=True`` for converting pandas objects into xarray objects wrapping sparse arrays. This is particularly useful with sparsely populated From f0de1db72a47473cef4bd721f2297fbb959a2d59 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:16:08 -0600 Subject: [PATCH 20/76] minor. --- doc/whats-new.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e460e1e230c..aba95c09c73 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,7 @@ New functions/methods - Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks` By `Deepak Cherian `_. + Bug fixes ~~~~~~~~~ - Reintroduce support for :mod:`weakref` (broken in v0.13.0). Support has been From 1251a5dd99b36718c5e949bd5ebe570507fc8131 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:22:57 -0600 Subject: [PATCH 21/76] update comment. 
--- xarray/core/parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 441668215de..bd02f7260f5 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -250,8 +250,9 @@ def _wrapper(func, obj, to_array, args, kwargs): for name, variable in template.variables.items(): if name in indexes: continue - # cannot tokenize "name" because the hash of is not invariant! - # This happens when the user function does not set a name on the returned DataArray + # cannot tokenize "name" because the hash of ReprObject () + # is a function of its value. This happens when the user function does not + # set a name on the returned DataArray gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l From 47a0e39872e6fcaa2983de908959c2ed96601748 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:25:04 -0600 Subject: [PATCH 22/76] More complicated test dataset. Tests fail :X --- xarray/tests/test_dask.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 56e08bc1df0..a8a02576531 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -937,6 +937,10 @@ def make_ds(): map_ds["d"] = ("z", [1, 1, 1, 1]) map_ds["z"] = [0, 1, 2, 3] map_ds["e"] = map_ds.x + map_ds.y + map_ds.z + map_ds.coords["c1"] = 0.5 + map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) + map_ds.coords["cxy"] = (("x", "y"), map_da.x * map_da.y) + map_ds.coords["cxy"] = map_ds.cxy.chunk({"y": 10}) map_ds.attrs["test"] = "test" map_ds["xx"] = (map_ds["a"] * map_ds.y).chunk({"y": 20}) From fa44d32cddb264c3ab569437bf407ba08f09bb70 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:28:19 -0600 Subject: [PATCH 23/76] Don't know why compute is needed. --- xarray/tests/test_dask.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a8a02576531..e116b1355ef 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1029,7 +1029,8 @@ def test_map_blocks(func, obj): actual = xr.map_blocks(func, obj) expected = func(obj).unify_chunks() assert_chunks_equal(expected, actual) - xr.testing.assert_equal(expected, actual) + # why is compute needed? + xr.testing.assert_equal(expected.compute(), actual.compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1039,7 +1040,8 @@ def test_map_blocks_args(obj): expected = obj.unify_chunks() + 10 actual = xr.map_blocks(operator.add, obj, 10) assert_chunks_equal(expected, actual) - xr.testing.assert_equal(expected, actual) + # why is compute needed? + xr.testing.assert_equal(expected.compute(), actual.compute()) def da_to_ds(da): From a6e84efd0f41ca8226c5ae3093ce4432d8df254a Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 16:35:29 -0600 Subject: [PATCH 24/76] work with DataArray nondim coords. 
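The case being fixed: non-dimension coordinates on a DataArray (like the scalar
`c2` added to the test fixture below) must survive meta construction. A minimal
sketch of the idea, assuming the internal `_to_temp_dataset` round-trip that
`make_meta` now performs:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.ones(4), dims="x", coords={"x": np.arange(4)})
    da.coords["c2"] = 0.5  # scalar, non-dimension coordinate

    # round-tripping through a temporary Dataset preserves the
    # coordinate/data-variable split, so the inferred template keeps
    # "c2" as a coordinate instead of promoting it to a data variable
    ds = da._to_temp_dataset()
    assert "c2" in ds.coords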
--- xarray/core/parallel.py | 15 ++++++++++----- xarray/tests/test_dask.py | 4 +++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index bd02f7260f5..10f690c02be 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -35,9 +35,13 @@ def make_meta(obj): from dask.array.utils import meta_from_array if isinstance(obj, DataArray): - meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name) + to_array = True + obj_array = obj.copy() + obj = obj._to_temp_dataset() + else: + to_array = False - elif isinstance(obj, Dataset): + if isinstance(obj, Dataset): meta = Dataset() for name, variable in obj.variables.items(): if dask.is_dask_collection(variable): @@ -49,12 +53,13 @@ def make_meta(obj): else: meta = obj - # TODO: DataArrays should have _coord_names! - - if isinstance(obj, (DataArray, Dataset)): + if isinstance(obj, Dataset): for coord_name in set(obj.coords) - set(obj.dims): meta = meta.set_coords(coord_name) + if to_array: + meta = obj_array._from_temp_dataset(meta) + return meta diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e116b1355ef..7fa6e520566 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -920,12 +920,14 @@ def test_dask_layers_and_dependencies(): def make_da(): - return xr.DataArray( + da = xr.DataArray( np.ones((10, 20)), dims=["x", "y"], coords={"x": np.arange(10), "y": np.arange(100, 120)}, name="a", ).chunk({"x": 4, "y": 5}) + da.coords["c2"] = 0.5 + return da def make_ds(): From c28b4020c4f890a8b5ede1c55eb5af295c2513fc Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Sep 2019 18:51:13 -0600 Subject: [PATCH 25/76] fastpath unify_chunks --- xarray/core/dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 25ea2a873e9..4e3f4495746 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5164,6 +5164,12 @@ def unify_chunks(self): """ import dask.array + try: + if self.chunks: + return self + except ValueError: # "inconsistent chunks" + pass + ds = self.copy() alphabet = "abcdefghijklmnopqrstuvwxyz" # this is stupid :) @@ -5189,12 +5195,6 @@ def unify_chunks(self): new_chunks[dim] = new_chunks[alpha] del new_chunks[alpha] - try: - if new_chunks == ds.chunks: - return ds - except ValueError: # "inconsistent chunks" - pass - for name, new_array in zip(dask_arrays.keys(), rechunked_arrays): ds.variables[name]._data = new_array From 1694d0330b5cf9391d133dd0cd91fb50ba63739c Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 07:57:30 -0600 Subject: [PATCH 26/76] comment. --- xarray/core/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4e3f4495746..d051cf2d17c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5172,7 +5172,9 @@ def unify_chunks(self): ds = self.copy() - alphabet = "abcdefghijklmnopqrstuvwxyz" # this is stupid :) + # dask unify_chunks needs dimensions named using a single character + # map our dimension names to a single character + alphabet = "abcdefghijklmnopqrstuvwxyz0123456789" alphamap = dict(zip(ds.dims, alphabet)) dask_arrays = {} From cf04ec8676097a21edceaffbb312e4de621ce7ad Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 07:57:40 -0600 Subject: [PATCH 27/76] much improved tests. 
--- xarray/tests/test_dask.py | 77 +++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 7fa6e520566..3663a5d8ca2 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -991,23 +991,6 @@ def at_least_one_chunk(obj): assert a.chunks == b.chunks -def simple_func(obj): - result = obj + obj.x + 5 * obj.y - return result - - -def complicated_func(obj): - new = obj.copy() - new = ( - new[["a", "b"]] - .rename({"a": "new_var1"}) - .expand_dims(k=[0, 1, 2]) - .transpose("k", "y", "x") - ) - new["b"] = new.b.astype("int32") - return new - - def test_map_blocks_error(): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] @@ -1022,13 +1005,14 @@ def returns_numpy(darray): xr.map_blocks(returns_numpy, map_da) -@pytest.mark.parametrize( - "func, obj", - [[simple_func, map_da], [simple_func, map_ds], [complicated_func, map_ds]], -) -def test_map_blocks(func, obj): +@pytest.mark.parametrize("obj", [map_da, map_ds]) +def test_map_blocks(obj): + def func(obj): + result = obj + obj.x + 5 * obj.y + return result - actual = xr.map_blocks(func, obj) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj) expected = func(obj).unify_chunks() assert_chunks_equal(expected, actual) # why is compute needed? @@ -1040,26 +1024,41 @@ def test_map_blocks_args(obj): import operator expected = obj.unify_chunks() + 10 - actual = xr.map_blocks(operator.add, obj, 10) + with raise_if_dask_computes(): + actual = xr.map_blocks(operator.add, obj, 10) assert_chunks_equal(expected, actual) # why is compute needed? xr.testing.assert_equal(expected.compute(), actual.compute()) -def da_to_ds(da): - return da.to_dataset() - - -def ds_to_da(ds): - return ds.to_array() +@pytest.mark.parametrize("obj", [map_da, map_ds]) +def test_map_blocks_kwargs(obj): + expected = xr.full_like(obj, fill_value=np.nan) + with raise_if_dask_computes(): + actual = xr.map_blocks(xr.full_like, obj, fill_value=np.nan) + assert_chunks_equal(expected, actual) + # why is compute needed? 
+ xr.testing.assert_equal(expected.compute(), actual.compute()) @pytest.mark.parametrize( - "func, obj, return_type", - [[da_to_ds, map_da, xr.Dataset], [ds_to_da, map_ds, xr.DataArray]], + "func, obj", + [ + [lambda x: x.to_dataset(), map_da], + [lambda x: x.to_array(), map_ds], + [lambda x: x.drop("a"), map_ds], + [lambda x: x.expand_dims(k=[1, 2, 3]), map_ds], + [lambda x: x.expand_dims(k=[1, 2, 3]), map_da], + [lambda x: x.isel(x=1), map_ds], + [lambda x: x.isel(x=1).drop("x"), map_da], + [lambda x: x.assign_coords(new_coord=("y", x.y * 2)), map_da], + [lambda x: x.astype(np.int32), map_da], + [lambda x: x.rename({"a": "new1", "b": "new2"}), map_ds], + ], ) -def map_blocks_transformations(func, obj, return_type): - assert isinstance(xr.map_blocks(func, obj), return_type) +def map_blocks_transformations(func, obj, expected): + with raise_if_dask_computes(): + assert_equal(xr.map_blocks(func, obj), func(obj)) def test_make_meta(): @@ -1074,13 +1073,3 @@ def test_make_meta(): for variable in map_ds.data_vars: assert variable in meta.data_vars assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim - - -# func(DataArray) -> Dataset -# func(Dataset) -> DataArray -# func output contains less variables -# func output contains new variables -# func changes dtypes -# func output contains less (or more) dimensions -# *args, **kwargs are passed through -# IndexVariables don't accidentally cause the whole graph to be computed (the logic you wrote in the main function is quite subtle!) From 3e9db261a2011c01a6b6c8eafcb4bc5f792c250a Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 08:31:03 -0600 Subject: [PATCH 28/76] Change args, kwargs syntax. --- xarray/core/parallel.py | 17 +++++++++++------ xarray/tests/test_dask.py | 15 ++++++++++----- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 10f690c02be..c2e955a7ec9 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -92,7 +92,7 @@ def make_dict(x): return to_return -def map_blocks(func, obj, *args, **kwargs): +def map_blocks(func, obj, args=[], kwargs={}): """ Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. @@ -115,11 +115,10 @@ def map_blocks(func, obj, *args, **kwargs): obj: DataArray, Dataset Chunks of this object will be provided to 'func'. The function must not change shape of the provided DataArray. - args: - Passed on to func. Cannot include chunked xarray objects. - kwargs: - Passed on to func. Cannot include chunked xarray objects. - + args: list + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + kwargs: dict + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. Returns ------- @@ -154,6 +153,12 @@ def _wrapper(func, obj, to_array, args, kwargs): return to_return + if not isinstance(args, list): + raise ValueError("args must be a list.") + + if not isinstance(kwargs, dict): + raise ValueError("kwargs must be a dictionary.") + if not dask.is_dask_collection(obj): raise ValueError( "map_blocks can only be used with dask-backed DataArrays. Use .chunk() to convert to a Dask array." 
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 3663a5d8ca2..244ee5873a6 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,3 +1,4 @@ +import operator import pickle from collections import OrderedDict from contextlib import suppress @@ -1004,6 +1005,12 @@ def returns_numpy(darray): with raises_regex(ValueError, "Function must return an xarray DataArray"): xr.map_blocks(returns_numpy, map_da) + with raises_regex(ValueError, "args must be"): + xr.map_blocks(operator.add, map_da, args=10) + + with raises_regex(ValueError, "kwargs must be"): + xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) + @pytest.mark.parametrize("obj", [map_da, map_ds]) def test_map_blocks(obj): @@ -1020,12 +1027,10 @@ def func(obj): @pytest.mark.parametrize("obj", [map_da, map_ds]) -def test_map_blocks_args(obj): - import operator - +def test_map_blocks_convert_args_to_list(obj): expected = obj.unify_chunks() + 10 with raise_if_dask_computes(): - actual = xr.map_blocks(operator.add, obj, 10) + actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected, actual) # why is compute needed? xr.testing.assert_equal(expected.compute(), actual.compute()) @@ -1035,7 +1040,7 @@ def test_map_blocks_args(obj): def test_map_blocks_kwargs(obj): expected = xr.full_like(obj, fill_value=np.nan) with raise_if_dask_computes(): - actual = xr.map_blocks(xr.full_like, obj, fill_value=np.nan) + actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected, actual) # why is compute needed? xr.testing.assert_equal(expected.compute(), actual.compute()) From 20fdde64312db3d2510c0ce9ca7492737f6b967a Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 08:47:16 -0600 Subject: [PATCH 29/76] Add dataset, dataarray methods. 
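The new methods simply forward to the module-level function, so both spellings in
this sketch are equivalent (`add_one` is an illustrative name):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(10))}).chunk({"x": 5})

    def add_one(block):
        return block + 1

    from_method = ds.map_blocks(add_one)
    from_function = xr.map_blocks(add_one, ds)
    xr.testing.assert_identical(from_method.compute(), from_function.compute())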
--- xarray/core/dataarray.py | 5 +++++ xarray/core/dataset.py | 5 +++++ xarray/tests/test_dask.py | 13 +++++++++++++ 3 files changed, 23 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 690b73d12b3..20559c372a9 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3038,6 +3038,11 @@ def unify_chunks(self): ds = self.copy()._to_temp_dataset().unify_chunks() return self._from_temp_dataset(ds) + def map_blocks(self, func, args=[], kwargs={}): + from .parallel import map_blocks + + return map_blocks(func, self, args, kwargs) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names # noqa str = property(StringAccessor) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d051cf2d17c..577b7153475 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5202,5 +5202,10 @@ def unify_chunks(self): return ds + def map_blocks(self, func, args=[], kwargs={}): + from .parallel import map_blocks + + return map_blocks(func, self, args, kwargs) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 244ee5873a6..da8d00930f2 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1066,6 +1066,19 @@ def map_blocks_transformations(func, obj, expected): assert_equal(xr.map_blocks(func, obj), func(obj)) +@pytest.mark.parametrize("obj", [map_da, map_ds]) +def test_map_blocks_object_method(obj): + def func(obj): + result = obj + obj.x + 5 * obj.y + return result + + with raise_if_dask_computes(): + expected = xr.map_blocks(func, obj) + actual = obj.map_blocks(func) + + assert_equal(expected.compute(), actual.compute()) + + def test_make_meta(): from ..core.parallel import make_meta From 22e9c4e3987760cd936333cff4954ad8f73532ff Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 08:53:58 -0600 Subject: [PATCH 30/76] api.rst --- doc/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index bcd83fa4075..98ac63d29c1 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -501,6 +501,7 @@ Dataset methods Dataset.load Dataset.chunk Dataset.unify_chunks + Dataset.map_blocks Dataset.filter_by_attrs Dataset.info @@ -532,6 +533,7 @@ DataArray methods DataArray.load DataArray.chunk DataArray.unify_chunks + DataArray.map_blocks GroupBy objects =============== From b145787b68ff2676be1251171aca8f9a42aeeca6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 20 Sep 2019 08:59:55 -0600 Subject: [PATCH 31/76] docstrings. --- xarray/core/dataarray.py | 43 ++++++++++++++++++++++++++++++++++++++++ xarray/core/dataset.py | 43 ++++++++++++++++++++++++++++++++++++++++ xarray/core/parallel.py | 10 +++++++--- 3 files changed, 93 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 20559c372a9..15528c4191e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3039,6 +3039,49 @@ def unify_chunks(self): return self._from_temp_dataset(ds) def map_blocks(self, func, args=[], kwargs={}): + """ + Apply a function to each chunk of this DataArray. This function is experimental + and its signature may change. + + Parameters + ---------- + func: callable + User-provided function that should accept xarray objects. + This function will receive a subset of this dataset, corresponding to one chunk along + each chunked dimension. 
+ The function will be run on a small piece of data that looks like 'obj' to determine + properties of the returned object such as dtype, variable names, + new dimensions and new indexes (if any). + + This function must + - return either a single DataArray or a single Dataset + + This function cannot + - change size of existing dimensions. + - add new chunked dimensions. + + If your function expects numpy arrays, see `xarray.apply_ufunc` + + args: list + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + kwargs: dict + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + + Returns + ------- + A single DataArray or Dataset + + Notes + ----- + + This function is designed to work with dask-backed xarray objects. See apply_ufunc for + a similar function that works with numpy arrays. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc + """ + from .parallel import map_blocks return map_blocks(func, self, args, kwargs) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 577b7153475..0e8ed60483d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5203,6 +5203,49 @@ def unify_chunks(self): return ds def map_blocks(self, func, args=[], kwargs={}): + """ + Apply a function to each chunk of this Dataset. This function is experimental + and its signature may change. + + Parameters + ---------- + func: callable + User-provided function that should accept xarray objects. + This function will receive a subset of this dataset, corresponding to one chunk along + each chunked dimension. + The function will be run on a small piece of data that looks like 'obj' to determine + properties of the returned object such as dtype, variable names, + new dimensions and new indexes (if any). + + This function must + - return either a single DataArray or a single Dataset + + This function cannot + - change size of existing dimensions. + - add new chunked dimensions. + + If your function expects numpy arrays, see `xarray.apply_ufunc` + + args: list + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + kwargs: dict + Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + + Returns + ------- + A single DataArray or Dataset + + Notes + ----- + + This function is designed to work with dask-backed xarray objects. See apply_ufunc for + a similar function that works with numpy arrays. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc + """ + from .parallel import map_blocks return map_blocks(func, self, args, kwargs) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index c2e955a7ec9..528a822a366 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -100,18 +100,22 @@ def map_blocks(func, obj, args=[], kwargs={}): Parameters ---------- func: callable - User-provided function that should accept DataArrays corresponding to one chunk. + User-provided function that should accept xarray objects. + This function will receive a subset of this dataset, corresponding to one chunk along + each chunked dimension. The function will be run on a small piece of data that looks like 'obj' to determine properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any). This function must - - return either a DataArray or a Dataset + - return either a single DataArray or a single Dataset This function cannot - change size of existing dimensions. 
- add new chunked dimensions. + If your function expects numpy arrays, see `xarray.apply_ufunc` + obj: DataArray, Dataset Chunks of this object will be provided to 'func'. The function must not change shape of the provided DataArray. @@ -122,7 +126,7 @@ def map_blocks(func, obj, args=[], kwargs={}): Returns ------- - DataArray or Dataset + A single DataArray or Dataset Notes ----- From f600c4a78a17462f4ca07888859a55f299e24133 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 10:14:55 -0600 Subject: [PATCH 32/76] Fix unify_chunks. --- xarray/core/dataarray.py | 12 ++++++++++++ xarray/core/dataset.py | 38 +++++++++++++------------------------- xarray/tests/test_dask.py | 7 ++++--- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 15528c4191e..3387caf3821 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3035,6 +3035,18 @@ def integrate( return self._from_temp_dataset(ds) def unify_chunks(self): + """ Unifies chunksize along all chunked dimensions of this DataArray. + + Returns + ------- + + DataArray with consistent chunk sizes for all dask-array variables + + See Also + -------- + + dask.array.core.unify_chunks + """ ds = self.copy()._to_temp_dataset().unify_chunks() return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0e8ed60483d..1e148b2e8db 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5162,42 +5162,30 @@ def unify_chunks(self): dask.array.core.unify_chunks """ - import dask.array - try: - if self.chunks: - return self + self.chunks + return self.copy() except ValueError: # "inconsistent chunks" pass + import dask.array + ds = self.copy() - # dask unify_chunks needs dimensions named using a single character - # map our dimension names to a single character - alphabet = "abcdefghijklmnopqrstuvwxyz0123456789" - alphamap = dict(zip(ds.dims, alphabet)) + dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)} - dask_arrays = {} + dask_array_names = [] + dask_unify_args = [] for name, variable in ds.variables.items(): if isinstance(variable.data, dask.array.Array): - index_string = "".join([alphamap[dim] for dim in variable.dims]) - dask_arrays[name] = (variable.data, index_string) + dims_tuple = [dims_pos_map[dim] for dim in variable.dims] + dask_array_names.append(name) + dask_unify_args.append(variable.data) + dask_unify_args.append(dims_tuple) - args = [] - for value in dask_arrays.values(): - args.append(value[0]) - args.append(value[1]) - - new_chunks, rechunked_arrays = dask.array.core.unify_chunks(*args) - - # invert the map - for dim, alpha in alphamap.items(): - if alpha not in new_chunks: - continue - new_chunks[dim] = new_chunks[alpha] - del new_chunks[alpha] + _, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args) - for name, new_array in zip(dask_arrays.keys(), rechunked_arrays): + for name, new_array in zip(dask_array_names, rechunked_arrays): ds.variables[name]._data = new_array return ds diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index da8d00930f2..f56dbe5415a 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -960,9 +960,10 @@ def test_unify_chunks(): with raises_regex(ValueError, "inconsistent chunks"): map_ds.chunks - expected = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} - actual = map_ds.unify_chunks().chunks - assert expected == actual + expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} + actual_chunks = 
map_ds.unify_chunks().chunks + assert expected_chunks == actual_chunks + assert_identical(map_ds, map_ds.unify_chunks()) # TODO: DataArray.chunks is not a dict but Dataset.chunks is! From 4af5a67490e40bdc664e0eb10d6eebe74bde540f Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 10:28:56 -0600 Subject: [PATCH 33/76] Move assert_chunks_equal to xarray.testing. --- doc/api.rst | 1 + xarray/testing.py | 23 +++++++++++++++++++++++ xarray/tests/test_dask.py | 28 +--------------------------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 98ac63d29c1..eeb61e20994 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -634,6 +634,7 @@ Testing testing.assert_equal testing.assert_identical testing.assert_allclose + testing.assert_chunks_equal Exceptions ========== diff --git a/xarray/testing.py b/xarray/testing.py index 9fa58b64001..7e04a586fba 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -142,6 +142,29 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): raise TypeError("{} not supported by assertion comparison".format(type(a))) +def assert_chunks_equal(a, b): + """ + Assert that chunksizes along chunked dimensions are equal. + + Parameters + ---------- + a : xarray.Dataset, xarray.DataArray or xarray.Variable + The first object to compare. + b : xarray.Dataset, xarray.DataArray or xarray.Variable + The second object to compare. + """ + + if isinstance(a, DataArray): + a = a._to_temp_dataset() + + if isinstance(b, DataArray): + b = b._to_temp_dataset() + + left = a.chunk().unify_chunks() + right = b.chunk().unify_chunks() + assert left.chunks == right.chunks + + def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): assert isinstance(indexes, OrderedDict), indexes assert all(isinstance(v, pd.Index) for v in indexes.values()), { diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f56dbe5415a..1461b26a281 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -13,6 +13,7 @@ import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable from xarray.tests import mock +from xarray.testing import assert_chunks_equal from . import ( assert_allclose, @@ -966,33 +967,6 @@ def test_unify_chunks(): assert_identical(map_ds, map_ds.unify_chunks()) -# TODO: DataArray.chunks is not a dict but Dataset.chunks is! -def assert_chunks_equal(a, b): - - # unchunked dimensions in input to map_blocks become 1 chunk. - # do this manually before comparing - def at_least_one_chunk(obj): - new_chunks = {} - for dim in obj.dims: - if dim not in obj.chunks: - new_chunks[dim] = len(obj[dim]) - else: - # must preserve old chunks too - new_chunks[dim] = obj.chunks[dim][0] - return obj.chunk(new_chunks) - - if isinstance(a, DataArray): - a = a._to_temp_dataset() - - if isinstance(b, DataArray): - b = b._to_temp_dataset() - - a = at_least_one_chunk(a.unify_chunks()) - b = at_least_one_chunk(b.unify_chunks()) - - assert a.chunks == b.chunks - - def test_map_blocks_error(): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] From 3ca4b7bf747dd2ac8cdcde8ee088b5a60835165a Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 10:30:55 -0600 Subject: [PATCH 34/76] minor changes. 
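With `assert_chunks_equal` now public in `xarray.testing` (previous patch), chunk
layouts can be compared directly after unification; for example (an illustrative
sketch):

    import numpy as np
    import xarray as xr
    from xarray.testing import assert_chunks_equal

    a = xr.DataArray(np.ones((8, 8)), dims=["x", "y"]).chunk({"x": 4})
    b = xr.DataArray(np.zeros((8, 8)), dims=["x", "y"]).chunk({"x": 4})

    # passes: both unify to chunks {"x": (4, 4), "y": (8,)}
    assert_chunks_equal(a, b)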
--- xarray/core/parallel.py | 74 ++++++++++++++++----------------------- xarray/tests/test_dask.py | 4 +-- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 528a822a366..5db1d556051 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -2,6 +2,7 @@ import dask import dask.array from dask.highlevelgraph import HighLevelGraph + from dask.array.utils import meta_from_array except ImportError: pass @@ -13,14 +14,16 @@ from .dataarray import DataArray from .dataset import Dataset +from typing import Sequence, Mapping + -def _to_array(obj): +def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): - raise ValueError("Trying to convert DataArray to DataArray!") + raise TypeError("Expected Dataset, got %s" % type(obj)) if len(obj.data_vars) > 1: - raise ValueError( - "Trying to convert Dataset with more than one variable to DataArray" + raise TypeError( + "Trying to convert Dataset with more than one data variable to DataArray" ) name = list(obj.data_vars)[0] @@ -32,8 +35,6 @@ def _to_array(obj): def make_meta(obj): - from dask.array.utils import meta_from_array - if isinstance(obj, DataArray): to_array = True obj_array = obj.copy() @@ -47,15 +48,13 @@ def make_meta(obj): if dask.is_dask_collection(variable): meta_obj = obj[name].data._meta else: - meta_obj = meta_from_array(obj[name].data) - meta[name] = DataArray(meta_obj, dims=obj[name].dims) - # meta[name] = DataArray(obj[name].dims, meta_obj) + meta_obj = meta_from_array(variable.data) + meta[name] = DataArray(meta_obj, dims=variable.dims) else: meta = obj if isinstance(obj, Dataset): - for coord_name in set(obj.coords) - set(obj.dims): - meta = meta.set_coords(coord_name) + meta = meta.set_coords(obj.coords) if to_array: meta = obj_array._from_temp_dataset(meta) @@ -65,14 +64,9 @@ def make_meta(obj): def infer_template(func, obj, *args, **kwargs): """ Infer return object by running the function on meta objects. """ - meta_args = [] - for arg in (obj,) + args: - meta_args.append(make_meta(arg)) + meta_args = [make_meta(arg) for arg in (obj,) + args] - try: - template = func(*meta_args, **kwargs) - except ValueError: - raise ValueError("Cannot infer object returned by user-provided function.") + template = func(*meta_args, **kwargs) return template @@ -83,13 +77,7 @@ def make_dict(x): if isinstance(x, DataArray): x = x._to_temp_dataset() - to_return = dict() - for var in x.variables: - # if var not in x: - # raise ValueError("Variable %r not found in returned object." % var) - to_return[var] = x[var].values - - return to_return + return {k: v.data for k, v in x.variables.items()} def map_blocks(func, obj, args=[], kwargs={}): @@ -103,9 +91,10 @@ def map_blocks(func, obj, args=[], kwargs={}): User-provided function that should accept xarray objects. This function will receive a subset of this dataset, corresponding to one chunk along each chunked dimension. - The function will be run on a small piece of data that looks like 'obj' to determine - properties of the returned object such as dtype, variable names, - new dimensions and new indexes (if any). + To determine properties of the returned object such as type (DataArray or Dataset), dtypes, + and new/removed dimensions and/or variables, the function will be run on dummy data + with the same variables, dimension names, and data types as this DataArray, but zero-sized + dimensions. 
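+
+        As a hypothetical sketch, a valid ``func`` might be::
+
+            def func(block):
+                # ``block`` covers exactly one chunk of ``obj``
+                return block - block.mean()
+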
This function must - return either a single DataArray or a single Dataset @@ -114,11 +103,11 @@ def map_blocks(func, obj, args=[], kwargs={}): - change size of existing dimensions. - add new chunked dimensions. - If your function expects numpy arrays, see `xarray.apply_ufunc` - + This function is designed to work with whole xarray objects. If your function can be applied + to numpy or dask arrays (e.g. it doesn't need indices, variable names, etc.), + you should consider using :func:~xarray.apply_ufunc instead. obj: DataArray, Dataset - Chunks of this object will be provided to 'func'. The function must not change - shape of the provided DataArray. + Chunks of this object will be provided to 'func'. args: list Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. kwargs: dict @@ -136,12 +125,12 @@ def map_blocks(func, obj, args=[], kwargs={}): See Also -------- - dask.array.map_blocks, xarray.apply_ufunc + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks """ def _wrapper(func, obj, to_array, args, kwargs): if to_array: - obj = _to_array(obj) + obj = dataset_to_dataarray(obj) result = func(obj, *args, **kwargs) @@ -157,14 +146,14 @@ def _wrapper(func, obj, to_array, args, kwargs): return to_return - if not isinstance(args, list): - raise ValueError("args must be a list.") + if not isinstance(args, Sequence): + raise TypeError("args must be a sequence.") - if not isinstance(kwargs, dict): - raise ValueError("kwargs must be a dictionary.") + if not isinstance(kwargs, Mapping): + raise TypeError("kwargs must be a mapping.") if not dask.is_dask_collection(obj): - raise ValueError( + raise TypeError( "map_blocks can only be used with dask-backed DataArrays. Use .chunk() to convert to a Dask array." ) @@ -230,7 +219,7 @@ def _wrapper(func, obj, to_array, args, kwargs): else: # non-dask array with possibly chunked dimensions # index into variable appropriately - subsetter = dict() + subsetter = {} for dim in variable.dims: if dim in chunk_index_dict: which_chunk = chunk_index_dict[dim] @@ -264,9 +253,6 @@ def _wrapper(func, obj, to_array, args, kwargs): for name, variable in template.variables.items(): if name in indexes: continue - # cannot tokenize "name" because the hash of ReprObject () - # is a function of its value. 
This happens when the user function does not - # set a name on the returned DataArray gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l @@ -305,6 +291,6 @@ def _wrapper(func, obj, to_array, args, kwargs): result = result.set_coords(template._coord_names) if result_is_array: - result = _to_array(result) + result = dataset_to_dataarray(result) return result diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 1461b26a281..e876e5f9b16 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -980,10 +980,10 @@ def returns_numpy(darray): with raises_regex(ValueError, "Function must return an xarray DataArray"): xr.map_blocks(returns_numpy, map_da) - with raises_regex(ValueError, "args must be"): + with raises_regex(TypeError, "args must be"): xr.map_blocks(operator.add, map_da, args=10) - with raises_regex(ValueError, "kwargs must be"): + with raises_regex(TypeError, "kwargs must be"): xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) From 3345d25b330ff5ad5a82cbf3021559f73ba2c782 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 12:09:25 -0600 Subject: [PATCH 35/76] Better error handling when inferring returned object --- xarray/core/parallel.py | 18 ++++++++++++------ xarray/tests/test_dask.py | 6 ++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 5db1d556051..c9b216fa644 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -66,7 +66,18 @@ def infer_template(func, obj, *args, **kwargs): """ Infer return object by running the function on meta objects. """ meta_args = [make_meta(arg) for arg in (obj,) + args] - template = func(*meta_args, **kwargs) + try: + template = func(*meta_args, **kwargs) + except Exception as e: + raise Exception( + "Cannot infer object returned from running user provided function." + ) from e + + if not isinstance(template, (Dataset, DataArray)): + raise TypeError( + "Function must return an xarray DataArray or Dataset. Instead it returned %r" + % type(template) + ) return template @@ -170,11 +181,6 @@ def _wrapper(func, obj, to_array, args, kwargs): template = template._to_temp_dataset() elif isinstance(template, Dataset): result_is_array = False - else: - raise ValueError( - "Function must return an xarray DataArray or Dataset. 
Instead it returned %r" - % type(template) - ) # If two different variables have different chunking along the same dim # fix that by "unifying chunks" diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e876e5f9b16..23167b6ad11 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -986,6 +986,12 @@ def returns_numpy(darray): with raises_regex(TypeError, "kwargs must be"): xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) + def really_bad_func(darray): + raise ValueError("couldn't do anything.") + + with raises_regex(Exception, "Cannot infer"): + xr.map_blocks(really_bad_func, map_da) + @pytest.mark.parametrize("obj", [map_da, map_ds]) def test_map_blocks(obj): From 54c77dd7262d2ef149421f83e4d24c2386cf59f5 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 12:35:50 -0600 Subject: [PATCH 36/76] wip --- doc/whats-new.rst | 4 +++- xarray/core/parallel.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aba95c09c73..bbf46e42555 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,7 +22,9 @@ v0.13.1 (unreleased) New functions/methods ~~~~~~~~~~~~~~~~~~~~~ -- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks` +- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`. + Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and + :py:meth:`testing.assert_chunks_equal`. By `Deepak Cherian `_. Bug fixes diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index c9b216fa644..2cac6acd430 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -274,7 +274,7 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) - result = Dataset() + result = Dataset(coords=indexes) # a quicker way to assign indexes? # indexes need to be known # otherwise compute is called when DataArray is created From fb1ff0b850d00f2fb8d7098e67fffbff7f203658 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 09:01:08 -0600 Subject: [PATCH 37/76] Docstrings + nicer error message. --- xarray/core/dataarray.py | 4 ++-- xarray/core/dataset.py | 2 +- xarray/core/parallel.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3387caf3821..5f0cb9e76ff 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3059,7 +3059,7 @@ def map_blocks(self, func, args=[], kwargs={}): ---------- func: callable User-provided function that should accept xarray objects. - This function will receive a subset of this dataset, corresponding to one chunk along + This function will receive a subset of this DataArray, corresponding to one chunk along each chunked dimension. 
The function will be run on a small piece of data that looks like 'obj' to determine properties of the returned object such as dtype, variable names, @@ -3091,7 +3091,7 @@ def map_blocks(self, func, args=[], kwargs={}): See Also -------- - dask.array.map_blocks, xarray.apply_ufunc + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, xarray.Dataset.map_blocks """ from .parallel import map_blocks diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1e148b2e8db..ef19158684f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5231,7 +5231,7 @@ def map_blocks(self, func, args=[], kwargs={}): See Also -------- - dask.array.map_blocks, xarray.apply_ufunc + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, xarray.DataArray.map_blocks """ from .parallel import map_blocks diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2cac6acd430..52d668b8f33 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -114,9 +114,9 @@ def map_blocks(func, obj, args=[], kwargs={}): - change size of existing dimensions. - add new chunked dimensions. - This function is designed to work with whole xarray objects. If your function can be applied - to numpy or dask arrays (e.g. it doesn't need indices, variable names, etc.), - you should consider using :func:~xarray.apply_ufunc instead. + This function should work with whole xarray objects. If your function can be applied + to numpy or dask arrays (e.g. it doesn't need additional metadata such as dimension names, + variable names, etc.), you should consider using :py:func:`~xarray.apply_ufunc` instead. obj: DataArray, Dataset Chunks of this object will be provided to 'func'. args: list @@ -158,10 +158,10 @@ def _wrapper(func, obj, to_array, args, kwargs): return to_return if not isinstance(args, Sequence): - raise TypeError("args must be a sequence.") + raise TypeError("args must be a sequence (for example, a list).") if not isinstance(kwargs, Mapping): - raise TypeError("kwargs must be a mapping.") + raise TypeError("kwargs must be a mapping (for example, a dict)") if not dask.is_dask_collection(obj): raise TypeError( From bad0855569ca691fcb7d968be1d3d44821f29539 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 16:52:49 -0600 Subject: [PATCH 38/76] wip --- xarray/tests/test_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 23167b6ad11..731f4676fb4 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -977,7 +977,7 @@ def bad_func(darray): def returns_numpy(darray): return (darray * darray.x + 5 * darray.y).values - with raises_regex(ValueError, "Function must return an xarray DataArray"): + with raises_regex(TypeError, "Function must return an xarray DataArray"): xr.map_blocks(returns_numpy, map_da) with raises_regex(TypeError, "args must be"): From 291e6e675d9ea201aaec1262ddb38e9e4dabbfa9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 23 Sep 2019 16:54:44 -0600 Subject: [PATCH 39/76] better to_array --- xarray/core/parallel.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 52d668b8f33..5a32f294940 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -26,11 +26,7 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray: "Trying to convert Dataset with more than one data variable to DataArray" ) - name = list(obj.data_vars)[0] - # this should be easier - da = 
obj.to_array().squeeze().drop("variable") - da.name = name - return da + return next(iter(obj.data_vars.values())) def make_meta(obj): From b31537cf22e32f7d0940b925f578248767d1c246 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 09:11:16 -0600 Subject: [PATCH 40/76] remove unify_chunks in map_blocks + better tests. --- xarray/core/parallel.py | 4 +--- xarray/tests/test_dask.py | 28 +++++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 5a32f294940..9535bbe24f1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -178,9 +178,7 @@ def _wrapper(func, obj, to_array, args, kwargs): elif isinstance(template, Dataset): result_is_array = False - # If two different variables have different chunking along the same dim - # fix that by "unifying chunks" - dataset = dataset.unify_chunks() + # if chunks are inconsistent, this will raise an error. input_chunks = dataset.chunks # TODO: add a test that fails when template and dataset are switched diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 731f4676fb4..89d1a05c784 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -929,6 +929,8 @@ def make_da(): name="a", ).chunk({"x": 4, "y": 5}) da.coords["c2"] = 0.5 + da.coords["ndcoord"] = da.x ** 2 + return da @@ -937,34 +939,34 @@ def make_ds(): map_ds["a"] = make_da() map_ds["b"] = map_ds.a + 50 map_ds["c"] = map_ds.x + 20 - map_ds = map_ds.chunk({"x": 4, "y": 5}) map_ds["d"] = ("z", [1, 1, 1, 1]) map_ds["z"] = [0, 1, 2, 3] - map_ds["e"] = map_ds.x + map_ds.y + map_ds.z + map_ds["e"] = map_ds.x + map_ds.y map_ds.coords["c1"] = 0.5 map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) map_ds.coords["cxy"] = (("x", "y"), map_da.x * map_da.y) - map_ds.coords["cxy"] = map_ds.cxy.chunk({"y": 10}) map_ds.attrs["test"] = "test" - map_ds["xx"] = (map_ds["a"] * map_ds.y).chunk({"y": 20}) + map_ds["xx"] = map_ds["a"] * map_ds.y + map_ds = map_ds.chunk({"x": 4, "y": 5}) return map_ds -# work around mypy error -# xarray/tests/test_dask.py:888: error: Dict entry 0 has incompatible type "str": "int"; expected "Hashable": "Union[None, Number, Tuple[Number, ...]]" map_da = make_da() map_ds = make_ds() def test_unify_chunks(): + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + with raises_regex(ValueError, "inconsistent chunks"): - map_ds.chunks + ds_copy.chunks - expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} - actual_chunks = map_ds.unify_chunks().chunks - assert expected_chunks == actual_chunks - assert_identical(map_ds, map_ds.unify_chunks()) + expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)} + actual_chunks = ds_copy.unify_chunks().chunks + expected_chunks == actual_chunks + assert_identical(map_ds, ds_copy.unify_chunks()) def test_map_blocks_error(): @@ -1001,7 +1003,7 @@ def func(obj): with raise_if_dask_computes(): actual = xr.map_blocks(func, obj) - expected = func(obj).unify_chunks() + expected = func(obj) assert_chunks_equal(expected, actual) # why is compute needed? 
xr.testing.assert_equal(expected.compute(), actual.compute()) @@ -1009,7 +1011,7 @@ def func(obj): @pytest.mark.parametrize("obj", [map_da, map_ds]) def test_map_blocks_convert_args_to_list(obj): - expected = obj.unify_chunks() + 10 + expected = obj + 10 with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected, actual) From 72e791308a2a6a99b5c176dc1698901e7e523921 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 09:41:45 -0600 Subject: [PATCH 41/76] typing for unify_chunks --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 5f0cb9e76ff..a491c1f512e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3034,7 +3034,7 @@ def integrate( ds = self._to_temp_dataset().integrate(dim, datetime_unit) return self._from_temp_dataset(ds) - def unify_chunks(self): + def unify_chunks(self) -> "DataArray": """ Unifies chunksize along all chunked dimensions of this DataArray. Returns diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ef19158684f..3db240feecc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5149,7 +5149,7 @@ def filter_by_attrs(self, **kwargs): selection.append(var_name) return self[selection] - def unify_chunks(self): + def unify_chunks(self) -> "Dataset": """ Unifies chunksize along all chunked dimensions of this Dataset. Returns From 0a6bbedba8e8917f260ab7149f85c4402d978c1c Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 09:48:43 -0600 Subject: [PATCH 42/76] address more review comments. --- xarray/core/dataarray.py | 6 +++--- xarray/core/dataset.py | 6 +++--- xarray/core/parallel.py | 19 ++++++------------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a491c1f512e..c5857e80b74 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3050,7 +3050,7 @@ def unify_chunks(self) -> "DataArray": ds = self.copy()._to_temp_dataset().unify_chunks() return self._from_temp_dataset(ds) - def map_blocks(self, func, args=[], kwargs={}): + def map_blocks(self, func, args=(), kwargs=None): """ Apply a function to each chunk of this DataArray. This function is experimental and its signature may change. @@ -3074,9 +3074,9 @@ def map_blocks(self, func, args=[], kwargs={}): If your function expects numpy arrays, see `xarray.apply_ufunc` - args: list + args: Sequence Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - kwargs: dict + kwargs: Mapping Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. Returns diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3db240feecc..dd2eff55d44 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5190,7 +5190,7 @@ def unify_chunks(self) -> "Dataset": return ds - def map_blocks(self, func, args=[], kwargs={}): + def map_blocks(self, func, args=(), kwargs=None): """ Apply a function to each chunk of this Dataset. This function is experimental and its signature may change. @@ -5214,9 +5214,9 @@ def map_blocks(self, func, args=[], kwargs={}): If your function expects numpy arrays, see `xarray.apply_ufunc` - args: list + args: Sequence Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - kwargs: dict + kwargs: Mapping Passed on to func after unpacking. 
xarray objects, if any, will not be split by chunks. Returns diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 9535bbe24f1..4c5838d4a45 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -87,7 +87,7 @@ def make_dict(x): return {k: v.data for k, v in x.variables.items()} -def map_blocks(func, obj, args=[], kwargs={}): +def map_blocks(func, obj, args=(), kwargs=None): """ Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. @@ -156,7 +156,10 @@ def _wrapper(func, obj, to_array, args, kwargs): if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list).") - if not isinstance(kwargs, Mapping): + if kwargs is None: + kwargs = {} + + elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") if not dask.is_dask_collection(obj): @@ -178,13 +181,8 @@ def _wrapper(func, obj, to_array, args, kwargs): elif isinstance(template, Dataset): result_is_array = False - # if chunks are inconsistent, this will raise an error. input_chunks = dataset.chunks - - # TODO: add a test that fails when template and dataset are switched - indexes = dict(template.indexes) - indexes.update(dataset.indexes) - + indexes = {dim: dataset.indexes[dim] for dim in template.dims} graph = {} gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) @@ -269,11 +267,6 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) result = Dataset(coords=indexes) - # a quicker way to assign indexes? - # indexes need to be known - # otherwise compute is called when DataArray is created - for name in template.dims: - result[name] = indexes[name] for name, key in var_key_map.items(): dims = template[name].dims var_chunks = [] From 210987e8fa3ddc252ec5fcbfc3be3c28f069a7ab Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 10:03:39 -0600 Subject: [PATCH 43/76] more unify_chunks tests. 
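As an illustration of the behaviour under test (a sketch; the common chunking is
whatever `dask.array.core.unify_chunks` computes as the finest common refinement
of the inputs):

    import numpy as np
    import xarray as xr

    a = xr.DataArray(np.ones(10), dims="x").chunk({"x": 4})  # chunks (4, 4, 2)
    b = xr.DataArray(np.ones(10), dims="x").chunk({"x": 5})  # chunks (5, 5)
    ds = xr.Dataset({"a": a, "b": b})

    # accessing ds.chunks here raises ValueError("inconsistent chunks");
    # unify_chunks rechunks both variables to a common grid
    print(ds.unify_chunks().chunks)  # expected: {"x": (4, 1, 3, 2)}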
--- xarray/tests/test_dask.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 89d1a05c784..215e9dc7eb8 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -929,7 +929,8 @@ def make_da(): name="a", ).chunk({"x": 4, "y": 5}) da.coords["c2"] = 0.5 - da.coords["ndcoord"] = da.x ** 2 + da.coords["ndcoord"] = (da.x ** 2).compute() + da.coords["cxy"] = da.x * da.y return da @@ -944,7 +945,6 @@ def make_ds(): map_ds["e"] = map_ds.x + map_ds.y map_ds.coords["c1"] = 0.5 map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) - map_ds.coords["cxy"] = (("x", "y"), map_da.x * map_da.y) map_ds.attrs["test"] = "test" map_ds["xx"] = map_ds["a"] * map_ds.y @@ -956,7 +956,11 @@ def make_ds(): map_ds = make_ds() -def test_unify_chunks(): +# DataArray.unify_chunks +# invoke unify_chunks when chunks are already unified (returned object must be a shallow copy) +# invoke unify_chunks when there are no chunks to begin with (returned object must be a shallow copy) +@pytest.mark.parametrize("obj", [map_ds, map_da]) +def test_unify_chunks(obj): ds_copy = map_ds.copy() ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) From 582e0d534ce6acdb0e2a8156f814447d8413a8d3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 26 Sep 2019 17:43:16 -0600 Subject: [PATCH 44/76] Just use dask.core.utils.meta_from_array --- xarray/core/parallel.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 4c5838d4a45..ab230c916c6 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -41,11 +41,8 @@ def make_meta(obj): if isinstance(obj, Dataset): meta = Dataset() for name, variable in obj.variables.items(): - if dask.is_dask_collection(variable): - meta_obj = obj[name].data._meta - else: - meta_obj = meta_from_array(variable.data) - meta[name] = DataArray(meta_obj, dims=variable.dims) + meta_obj = meta_from_array(variable.data) + meta[name] = (variable.dims, meta_obj) else: meta = obj From d0fd87e9e50f46984c490377cc632ab4741ebc8a Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 06:43:21 -0600 Subject: [PATCH 45/76] get tests working. assert_equal needs a lot of fixing. --- xarray/tests/test_dask.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 215e9dc7eb8..b3773fe5d32 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -940,15 +940,15 @@ def make_ds(): map_ds["a"] = make_da() map_ds["b"] = map_ds.a + 50 map_ds["c"] = map_ds.x + 20 + map_ds = map_ds.chunk({"x": 4, "y": 5}) map_ds["d"] = ("z", [1, 1, 1, 1]) map_ds["z"] = [0, 1, 2, 3] map_ds["e"] = map_ds.x + map_ds.y map_ds.coords["c1"] = 0.5 map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) map_ds.attrs["test"] = "test" - map_ds["xx"] = map_ds["a"] * map_ds.y + map_ds["xx"] = (map_ds["a"] * map_ds.y).chunk() - map_ds = map_ds.chunk({"x": 4, "y": 5}) return map_ds @@ -1009,8 +1009,7 @@ def func(obj): actual = xr.map_blocks(func, obj) expected = func(obj) assert_chunks_equal(expected, actual) - # why is compute needed? 
- xr.testing.assert_equal(expected.compute(), actual.compute()) + xr.testing.assert_equal(actual, expected) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1019,8 +1018,7 @@ def test_map_blocks_convert_args_to_list(obj): with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected, actual) - # why is compute needed? - xr.testing.assert_equal(expected.compute(), actual.compute()) + xr.testing.assert_equal(actual, expected) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1029,8 +1027,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected, actual) - # why is compute needed? - xr.testing.assert_equal(expected.compute(), actual.compute()) + xr.testing.assert_equal(actual, expected) @pytest.mark.parametrize( From 875264a0b8129f5bb70023f431ba909169b58936 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 06:43:44 -0600 Subject: [PATCH 46/76] more unify_chunks test. --- xarray/core/dataset.py | 6 ++++++ xarray/tests/test_dask.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd2eff55d44..97b68da7186 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5162,6 +5162,12 @@ def unify_chunks(self) -> "Dataset": dask.array.core.unify_chunks """ + + import dask + + if not dask.is_dask_collection(self): + return self.copy() + try: self.chunks return self.copy() diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index b3773fe5d32..2c4d3700e17 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -956,9 +956,6 @@ def make_ds(): map_ds = make_ds() -# DataArray.unify_chunks -# invoke unify_chunks when chunks are already unified (returned object must be a shallow copy) -# invoke unify_chunks when there are no chunks to begin with (returned object must be a shallow copy) @pytest.mark.parametrize("obj", [map_ds, map_da]) def test_unify_chunks(obj): ds_copy = map_ds.copy() @@ -973,6 +970,15 @@ def test_unify_chunks(obj): assert_identical(map_ds, ds_copy.unify_chunks()) +@pytest.mark.parametrize( + "obj", + [map_ds.compute(), map_da.compute(), map_ds.unify_chunks(), map_da.unify_chunks()], +) +def test_unify_chunks_shallow_copy(obj): + unified = obj.unify_chunks() + assert_identical(obj, unified) and obj is not obj.unify_chunks() + + def test_map_blocks_error(): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] From 0f03e3730fa29b6ca169f4773ca91cd7cf1fbe21 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 06:49:10 -0600 Subject: [PATCH 47/76] assert_chunks_equal fixes. --- xarray/testing.py | 11 ++++------- xarray/tests/test_dask.py | 6 +++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/xarray/testing.py b/xarray/testing.py index 7e04a586fba..21bd4292b33 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -154,14 +154,11 @@ def assert_chunks_equal(a, b): The second object to compare. 
""" - if isinstance(a, DataArray): - a = a._to_temp_dataset() + if isinstance(a, DataArray) != isinstance(b, DataArray): + raise TypeError("a and b have mismatched types") - if isinstance(b, DataArray): - b = b._to_temp_dataset() - - left = a.chunk().unify_chunks() - right = b.chunk().unify_chunks() + left = a.unify_chunks() + right = b.unify_chunks() assert left.chunks == right.chunks diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 2c4d3700e17..b1d71d7d9c9 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1014,7 +1014,7 @@ def func(obj): with raise_if_dask_computes(): actual = xr.map_blocks(func, obj) expected = func(obj) - assert_chunks_equal(expected, actual) + assert_chunks_equal(expected.chunk(), actual) xr.testing.assert_equal(actual, expected) @@ -1023,7 +1023,7 @@ def test_map_blocks_convert_args_to_list(obj): expected = obj + 10 with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) - assert_chunks_equal(expected, actual) + assert_chunks_equal(expected.chunk(), actual) xr.testing.assert_equal(actual, expected) @@ -1032,7 +1032,7 @@ def test_map_blocks_kwargs(obj): expected = xr.full_like(obj, fill_value=np.nan) with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) - assert_chunks_equal(expected, actual) + assert_chunks_equal(expected.chunk(), actual) xr.testing.assert_equal(actual, expected) From 8175d73f3eedc1da0f82a77942c876400e17f98c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 07:02:55 -0600 Subject: [PATCH 48/76] copy over meta_from_array. --- xarray/core/dask_array_compat.py | 87 ++++++++++++++++++++++++++++++++ xarray/core/parallel.py | 2 +- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index fe2cdc5c553..2fdd0c5f301 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -171,3 +171,90 @@ def gradient(f, *varargs, axis=None, **kwargs): results = results[0] return results + + +if LooseVersion(dask_version) > LooseVersion("2.0.0"): + meta_from_array = da.utils.meta_from_array +else: + # Copied from dask v2.4.0 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import numbers + + def meta_from_array(x, ndim=None, dtype=None): + """ Normalize an array to appropriate meta object + + Parameters + ---------- + x: array-like, callable + Either an object that looks sufficiently like a Numpy array, + or a callable that accepts shape and dtype keywords + ndim: int + Number of dimensions of the array + dtype: Numpy dtype + A valid input for ``np.dtype`` + + Returns + ------- + array-like with zero elements of the correct dtype + """ + # If using x._meta, x must be a Dask Array, some libraries (e.g. 
zarr) + # implement a _meta attribute that are incompatible with Dask Array._meta + if hasattr(x, "_meta") and isinstance(x, da.Array): + x = x._meta + + if dtype is None and x is None: + raise ValueError("You must specify the meta or dtype of the array") + + if np.isscalar(x): + x = np.array(x) + + if x is None: + x = np.ndarray + + if isinstance(x, type): + x = x(shape=(0,) * (ndim or 0), dtype=dtype) + + if ( + not hasattr(x, "shape") + or not hasattr(x, "dtype") + or not isinstance(x.shape, tuple) + ): + return x + + if isinstance(x, list) or isinstance(x, tuple): + ndims = [ + 0 + if isinstance(a, numbers.Number) + else a.ndim + if hasattr(a, "ndim") + else len(a) + for a in x + ] + a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)] + return a if isinstance(x, list) else tuple(x) + + if ndim is None: + ndim = x.ndim + + try: + meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))] + if meta.ndim != ndim: + if ndim > x.ndim: + meta = meta[ + (Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim)) + ] + meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))] + elif ndim == 0: + meta = meta.sum() + else: + meta = meta.reshape((0,) * ndim) + except Exception: + meta = np.empty((0,) * ndim, dtype=dtype or x.dtype) + + if np.isscalar(meta): + meta = np.array(meta) + + if dtype and meta.dtype != dtype: + meta = meta.astype(dtype) + + return meta diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index ab230c916c6..1ca198248f7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -2,7 +2,7 @@ import dask import dask.array from dask.highlevelgraph import HighLevelGraph - from dask.array.utils import meta_from_array + from .dask_array_compat import meta_from_array except ImportError: pass From 6ab8737dfad3ee54b8837b5bf77b5a660ab5bbb8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 07:03:09 -0600 Subject: [PATCH 49/76] minor fixes. --- xarray/core/parallel.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 1ca198248f7..2b48fb835cf 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -111,7 +111,8 @@ def map_blocks(func, obj, args=(), kwargs=None): to numpy or dask arrays (e.g. it doesn't need additional metadata such as dimension names, variable names, etc.), you should consider using :py:func:`~xarray.apply_ufunc` instead. obj: DataArray, Dataset - Chunks of this object will be provided to 'func'. + Chunks of this object will be provided to 'func'. If passed a numpy object, the function will + be run eagerly. args: list Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. kwargs: dict @@ -151,7 +152,7 @@ def _wrapper(func, obj, to_array, args, kwargs): return to_return if not isinstance(args, Sequence): - raise TypeError("args must be a sequence (for example, a list).") + raise TypeError("args must be a sequence (for example, a list or tuple).") if kwargs is None: kwargs = {} @@ -160,9 +161,7 @@ def _wrapper(func, obj, to_array, args, kwargs): raise TypeError("kwargs must be a mapping (for example, a dict)") if not dask.is_dask_collection(obj): - raise TypeError( - "map_blocks can only be used with dask-backed DataArrays. Use .chunk() to convert to a Dask array." 
- ) + return func(obj, *args, **kwargs) if isinstance(obj, DataArray): dataset = obj._to_temp_dataset() From 08c41b9f1bf9dce0d00a61832c88f572796406d9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 07:19:03 -0600 Subject: [PATCH 50/76] raise chunks error earlier and test for map_blocks raising chunk error --- xarray/core/parallel.py | 3 ++- xarray/tests/test_dask.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2b48fb835cf..cde739a2e91 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -170,6 +170,8 @@ def _wrapper(func, obj, to_array, args, kwargs): dataset = obj input_is_array = False + input_chunks = dataset.chunks + template = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True @@ -177,7 +179,6 @@ def _wrapper(func, obj, to_array, args, kwargs): elif isinstance(template, Dataset): result_is_array = False - input_chunks = dataset.chunks indexes = {dim: dataset.indexes[dim] for dim in template.dims} graph = {} gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index b1d71d7d9c9..57a338de24b 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -947,7 +947,7 @@ def make_ds(): map_ds.coords["c1"] = 0.5 map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) map_ds.attrs["test"] = "test" - map_ds["xx"] = (map_ds["a"] * map_ds.y).chunk() + map_ds["xx"] = map_ds["a"] * map_ds.y return map_ds @@ -1004,6 +1004,12 @@ def really_bad_func(darray): with raises_regex(Exception, "Cannot infer"): xr.map_blocks(really_bad_func, map_da) + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + + with raises_regex(ValueError, "inconsistent chunks"): + xr.map_blocks(bad_func, ds_copy) + @pytest.mark.parametrize("obj", [map_da, map_ds]) def test_map_blocks(obj): From 76bc23cd39859275131c7bf970221e95b13a9213 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Sep 2019 07:20:15 -0600 Subject: [PATCH 51/76] fix. 
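The redundant `.copy()` before `_to_temp_dataset()` is dropped below: the
temporary dataset is already a new object, so unifying its chunks cannot
mutate `self`. For reference, the inconsistent-chunks case that the previous
commit now rejects early can be reproduced with a toy dataset (illustrative
sketch only, not part of the diff):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.ones(10)), "b": ("x", np.ones(10))})
    ds["a"] = ds.a.chunk({"x": 4})
    ds["b"] = ds.b.chunk({"x": 5})  # disagrees with "a" along "x"
    try:
        ds.chunks  # same consistency check that map_blocks now triggers early
    except ValueError as err:
        print(err)  # message mentions "inconsistent chunks"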
--- xarray/core/dataarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c5857e80b74..cec1e5531c6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3047,7 +3047,7 @@ def unify_chunks(self) -> "DataArray": dask.array.core.unify_chunks """ - ds = self.copy()._to_temp_dataset().unify_chunks() + ds = self._to_temp_dataset().unify_chunks() return self._from_temp_dataset(ds) def map_blocks(self, func, args=(), kwargs=None): From 49d3899387cb910233a43f44d540693405d0e123 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 1 Oct 2019 11:34:09 +0100 Subject: [PATCH 52/76] Type annotations --- xarray/core/dataarray.py | 10 +++- xarray/core/dataset.py | 10 +++- xarray/core/parallel.py | 101 +++++++++++++++++++++++--------------- xarray/tests/test_dask.py | 2 +- 4 files changed, 81 insertions(+), 42 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cec1e5531c6..2de26744747 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -15,6 +15,7 @@ Optional, Sequence, Tuple, + TypeVar, Union, cast, overload, @@ -64,6 +65,8 @@ ) if TYPE_CHECKING: + T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) + try: from dask.delayed import Delayed except ImportError: @@ -3050,7 +3053,12 @@ def unify_chunks(self) -> "DataArray": ds = self._to_temp_dataset().unify_chunks() return self._from_temp_dataset(ds) - def map_blocks(self, func, args=(), kwargs=None): + def map_blocks( + self, + func: "Callable[..., T_DSorDA]", + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, + ) -> "T_DSorDA": """ Apply a function to each chunk of this DataArray. This function is experimental and its signature may change. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 97b68da7186..8e2c5a5e065 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -22,6 +22,7 @@ Sequence, Set, Tuple, + TypeVar, Union, cast, overload, @@ -87,6 +88,8 @@ from .dataarray import DataArray from .merge import DatasetLike + T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset") + try: from dask.delayed import Delayed except ImportError: @@ -5196,7 +5199,12 @@ def unify_chunks(self) -> "Dataset": return ds - def map_blocks(self, func, args=(), kwargs=None): + def map_blocks( + self, + func: "Callable[..., T_DSorDA]", + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, + ) -> "T_DSorDA": """ Apply a function to each chunk of this Dataset. This function is experimental and its signature may change. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index cde739a2e91..7f3c6ff4300 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -14,7 +14,19 @@ from .dataarray import DataArray from .dataset import Dataset -from typing import Sequence, Mapping +from typing import ( + Any, + Callable, + Dict, + Hashable, + Mapping, + Sequence, + Tuple, + TypeVar, + Union, +) + +T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) def dataset_to_dataarray(obj: Dataset) -> DataArray: @@ -55,7 +67,9 @@ def make_meta(obj): return meta -def infer_template(func, obj, *args, **kwargs): +def infer_template( + func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs +) -> T_DSorDA: """ Infer return object by running the function on meta objects. 
""" meta_args = [make_meta(arg) for arg in (obj,) + args] @@ -75,30 +89,34 @@ def infer_template(func, obj, *args, **kwargs): return template -def make_dict(x): - # Dataset.to_dict() is too complicated - # maps variable name to numpy array +def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: + """Map variable name to numpy(-like) data + (Dataset.to_dict() is too complicated). + """ if isinstance(x, DataArray): x = x._to_temp_dataset() return {k: v.data for k, v in x.variables.items()} -def map_blocks(func, obj, args=(), kwargs=None): - """ - Apply a function to each chunk of a DataArray or Dataset. This function is experimental - and its signature may change. +def map_blocks( + func: Callable[..., T_DSorDA], + obj: Union[DataArray, Dataset], + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, +) -> T_DSorDA: + """Apply a function to each chunk of a DataArray or Dataset. This function is + experimental and its signature may change. Parameters ---------- func: callable - User-provided function that should accept xarray objects. - This function will receive a subset of this dataset, corresponding to one chunk along - each chunked dimension. - To determine properties of the returned object such as type (DataArray or Dataset), dtypes, - and new/removed dimensions and/or variables, the function will be run on dummy data - with the same variables, dimension names, and data types as this DataArray, but zero-sized - dimensions. + User-provided function that should accept xarray objects. This function will + receive a subset of this dataset, corresponding to one chunk along each chunked + dimension. To determine properties of the returned object such as type + (DataArray or Dataset), dtypes, and new/removed dimensions and/or variables, the + function will be run on dummy data with the same variables, dimension names, and + data types as this DataArray, but zero-sized dimensions. This function must - return either a single DataArray or a single Dataset @@ -107,16 +125,19 @@ def map_blocks(func, obj, args=(), kwargs=None): - change size of existing dimensions. - add new chunked dimensions. - This function should work with whole xarray objects. If your function can be applied - to numpy or dask arrays (e.g. it doesn't need additional metadata such as dimension names, - variable names, etc.), you should consider using :py:func:`~xarray.apply_ufunc` instead. + This function should work with whole xarray objects. If your function can be + applied to numpy or dask arrays (e.g. it doesn't need additional metadata such + as dimension names, variable names, etc.), you should consider using + :py:func:`~xarray.apply_ufunc` instead. obj: DataArray, Dataset - Chunks of this object will be provided to 'func'. If passed a numpy object, the function will - be run eagerly. - args: list - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Chunks of this object will be provided to 'func'. If passed a numpy object, the + function will be run eagerly. + args: tuple + Passed on to func after unpacking. xarray objects, if any, will not be split by + chunks. kwargs: dict - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Passed on to func after unpacking. xarray objects, if any, will not be split by + chunks. Returns ------- @@ -125,12 +146,13 @@ def map_blocks(func, obj, args=(), kwargs=None): Notes ----- - This function is designed to work with dask-backed xarray objects. 
See apply_ufunc for - a similar function that works with numpy arrays. + This function is designed to work with dask-backed xarray objects. See apply_ufunc + for a similar function that works with numpy arrays. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + xarray.DataArray.map_blocks """ def _wrapper(func, obj, to_array, args, kwargs): @@ -147,9 +169,7 @@ def _wrapper(func, obj, to_array, args, kwargs): % name ) - to_return = make_dict(result) - - return to_return + return make_dict(result) if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list or tuple).") @@ -172,15 +192,19 @@ def _wrapper(func, obj, to_array, args, kwargs): input_chunks = dataset.chunks - template = infer_template(func, obj, *args, **kwargs) + template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template = template._to_temp_dataset() elif isinstance(template, Dataset): result_is_array = False + else: + raise TypeError( + "func output must be DataArray or Dataset; got %s" % type(template) + ) indexes = {dim: dataset.indexes[dim] for dim in template.dims} - graph = {} + graph = {} # type: Dict[Any, Any] gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) # map dims to list of chunk indexes @@ -244,14 +268,14 @@ def _wrapper(func, obj, to_array, args, kwargs): ) # mapping from variable name to dask graph key - var_key_map = {} + var_key_map = {} # type: Dict[Hashable, str] for name, variable in template.variables.items(): if name in indexes: continue gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l - key = (gname_l,) + key = (gname_l,) # type: Tuple[Any, ...] 
for dim in variable.dims: if dim in chunk_index_dict: key += (chunk_index_dict[dim],) @@ -264,7 +288,7 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) result = Dataset(coords=indexes) - for name, key in var_key_map.items(): + for name, gname_l in var_key_map.items(): dims = template[name].dims var_chunks = [] for dim in dims: @@ -274,13 +298,12 @@ def _wrapper(func, obj, to_array, args, kwargs): var_chunks.append((len(indexes[dim]),)) data = dask.array.Array( - graph, name=key, chunks=var_chunks, dtype=template[name].dtype + graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data) result = result.set_coords(template._coord_names) if result_is_array: - result = dataset_to_dataarray(result) - - return result + return dataset_to_dataarray(result) # type: ignore + return result # type: ignore diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 57a338de24b..0287d194d26 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1057,7 +1057,7 @@ def test_map_blocks_kwargs(obj): [lambda x: x.rename({"a": "new1", "b": "new2"}), map_ds], ], ) -def map_blocks_transformations(func, obj, expected): +def test_map_blocks_transformations(func, obj): with raise_if_dask_computes(): assert_equal(xr.map_blocks(func, obj), func(obj)) From ae53b85c1f391556a343f21e286c9dec136b3652 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 1 Oct 2019 11:46:04 +0100 Subject: [PATCH 53/76] py35 compat --- xarray/core/parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 7f3c6ff4300..c5272ac2652 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -192,7 +192,9 @@ def _wrapper(func, obj, to_array, args, kwargs): input_chunks = dataset.chunks - template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) + template = infer_template( + func, obj, *args, **kwargs + ) # type: Union[DataArray, Dataset] if isinstance(template, DataArray): result_is_array = True template = template._to_temp_dataset() From f6dfb128eef8990756e5c621480512f499cce466 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 1 Oct 2019 08:06:41 -0600 Subject: [PATCH 54/76] make sure unify_chunks does not compute. 
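`raise_if_dask_computes` is the guard these tests rely on. A minimal sketch
of how such a guard can be written, assuming dask accepts a callable
scheduler through `dask.config.set` (the in-tree helper may differ in
details):

    import dask

    class CountingScheduler:
        """Scheduler that raises if more than max_computes computes happen."""

        def __init__(self, max_computes=0):
            self.total_computes = 0
            self.max_computes = max_computes

        def __call__(self, dsk, keys, **kwargs):
            self.total_computes += 1
            if self.total_computes > self.max_computes:
                raise RuntimeError(
                    "Too many computes: %d > max %d"
                    % (self.total_computes, self.max_computes)
                )
            return dask.get(dsk, keys, **kwargs)

    def raise_if_dask_computes(max_computes=0):
        # dask.config.set is usable as a context manager, so this can be
        # used directly in a `with` statement around the lazy operation.
        return dask.config.set(scheduler=CountingScheduler(max_computes))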
--- xarray/tests/test_dask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 0287d194d26..c3bd47d9b55 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -965,7 +965,8 @@ def test_unify_chunks(obj): ds_copy.chunks expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)} - actual_chunks = ds_copy.unify_chunks().chunks + with raise_if_dask_computes(): + actual_chunks = ds_copy.unify_chunks().chunks expected_chunks == actual_chunks assert_identical(map_ds, ds_copy.unify_chunks()) From c73eda1866c79ba724396e517a81d2c21b235278 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 1 Oct 2019 10:17:24 -0600 Subject: [PATCH 55/76] Make tests functional by call compute before assert_equal --- xarray/core/parallel.py | 18 +++++++++++++++--- xarray/tests/test_dask.py | 17 +++++++++++------ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index c5272ac2652..94479a38751 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -184,7 +184,13 @@ def _wrapper(func, obj, to_array, args, kwargs): return func(obj, *args, **kwargs) if isinstance(obj, DataArray): - dataset = obj._to_temp_dataset() + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. + if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() input_is_array = True else: dataset = obj @@ -205,7 +211,13 @@ def _wrapper(func, obj, to_array, args, kwargs): "func output must be DataArray or Dataset; got %s" % type(template) ) - indexes = {dim: dataset.indexes[dim] for dim in template.dims} + template_indexes = set(template.indexes) + dataset_indexes = set(dataset.indexes) + preserved_indexes = template_indexes.intersection(dataset_indexes) + new_indexes = set(template_indexes) - set(dataset_indexes) + indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + indexes.update({k: template.indexes[k] for k in new_indexes}) + graph = {} # type: Dict[Any, Any] gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) @@ -287,7 +299,7 @@ def _wrapper(func, obj, to_array, args, kwargs): graph[key] = (operator.getitem, from_wrapper, name) - graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset]) + graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) result = Dataset(coords=indexes) for name, gname_l in var_key_map.items(): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c3bd47d9b55..0d10bbd9ef2 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1022,7 +1022,7 @@ def func(obj): actual = xr.map_blocks(func, obj) expected = func(obj) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual, expected) + xr.testing.assert_equal(actual.compute(), expected.compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1031,7 +1031,7 @@ def test_map_blocks_convert_args_to_list(obj): with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual, expected) + xr.testing.assert_equal(actual.compute(), expected.compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1040,7 +1040,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) 
assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual, expected) + xr.testing.assert_equal(actual.compute(), expected.compute()) @pytest.mark.parametrize( @@ -1048,11 +1048,14 @@ def test_map_blocks_kwargs(obj): [ [lambda x: x.to_dataset(), map_da], [lambda x: x.to_array(), map_ds], + [lambda x: x.drop("cxy"), map_ds], [lambda x: x.drop("a"), map_ds], + [lambda x: x.drop("x"), map_da], + [lambda x: x.drop("x"), map_ds], [lambda x: x.expand_dims(k=[1, 2, 3]), map_ds], [lambda x: x.expand_dims(k=[1, 2, 3]), map_da], - [lambda x: x.isel(x=1), map_ds], - [lambda x: x.isel(x=1).drop("x"), map_da], + # TODO: [lambda x: x.isel(x=1), map_ds], + # TODO: [lambda x: x.isel(x=1).drop("x"), map_da], [lambda x: x.assign_coords(new_coord=("y", x.y * 2)), map_da], [lambda x: x.astype(np.int32), map_da], [lambda x: x.rename({"a": "new1", "b": "new2"}), map_ds], @@ -1060,7 +1063,9 @@ def test_map_blocks_kwargs(obj): ) def test_map_blocks_transformations(func, obj): with raise_if_dask_computes(): - assert_equal(xr.map_blocks(func, obj), func(obj)) + actual = xr.map_blocks(func, obj) + + assert_equal(actual.compute(), func(obj).compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) From 8ad882b1b9b3bd7bd5f40d0e0e37109727beff10 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 08:27:52 -0600 Subject: [PATCH 56/76] Update whats-new --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bbf46e42555..482302502d0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,8 +24,8 @@ New functions/methods - Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`. Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and - :py:meth:`testing.assert_chunks_equal`. - By `Deepak Cherian `_. + :py:meth:`testing.assert_chunks_equal`. By `Deepak Cherian `_ + and `Guido Imperiale `_. Bug fixes ~~~~~~~~~ From 3cda5ac39be175136d51a4582bf64c7c1cfbc359 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 08:45:08 -0600 Subject: [PATCH 57/76] Work with attributes. 
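Variable and global attributes are now threaded through `make_meta` and the
reassembled result, so attrs round-trip through `map_blocks`. A toy check
(illustrative only, not part of the diff):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4.0))}, attrs={"title": "demo"})
    ds.a.attrs["units"] = "m"
    ds = ds.chunk({"x": 2})

    result = xr.map_blocks(lambda obj: obj, ds)
    assert result.attrs["title"] == "demo"  # global attrs from the template
    assert result.a.attrs["units"] == "m"   # variable attrs from the template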
--- xarray/core/parallel.py | 7 ++++--- xarray/tests/test_dask.py | 16 ++++++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 94479a38751..af743d2fb00 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -54,7 +54,8 @@ def make_meta(obj): meta = Dataset() for name, variable in obj.variables.items(): meta_obj = meta_from_array(variable.data) - meta[name] = (variable.dims, meta_obj) + meta[name] = (variable.dims, meta_obj, variable.attrs) + meta.attrs = obj.attrs else: meta = obj @@ -301,7 +302,7 @@ def _wrapper(func, obj, to_array, args, kwargs): graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) - result = Dataset(coords=indexes) + result = Dataset(coords=indexes, attrs=template.attrs) for name, gname_l in var_key_map.items(): dims = template[name].dims var_chunks = [] @@ -314,7 +315,7 @@ def _wrapper(func, obj, to_array, args, kwargs): data = dask.array.Array( graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) - result[name] = (dims, data) + result[name] = (dims, data, template[name].attrs) result = result.set_coords(template._coord_names) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 0d10bbd9ef2..4e91806c5ae 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -928,9 +928,10 @@ def make_da(): coords={"x": np.arange(10), "y": np.arange(100, 120)}, name="a", ).chunk({"x": 4, "y": 5}) + da.attrs["test"] = "test" da.coords["c2"] = 0.5 - da.coords["ndcoord"] = (da.x ** 2).compute() - da.coords["cxy"] = da.x * da.y + da.coords["ndcoord"] = da.x * 2 + da.coords["cxy"] = (da.x * da.y).chunk({"x": 4, "y": 5}) return da @@ -946,8 +947,9 @@ def make_ds(): map_ds["e"] = map_ds.x + map_ds.y map_ds.coords["c1"] = 0.5 map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) + map_ds.coords["cx"].attrs["test2"] = "test2" map_ds.attrs["test"] = "test" - map_ds["xx"] = map_ds["a"] * map_ds.y + map_ds.coords["xx"] = map_ds["a"] * map_ds.y return map_ds @@ -1022,7 +1024,7 @@ def func(obj): actual = xr.map_blocks(func, obj) expected = func(obj) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual.compute(), expected.compute()) + xr.testing.assert_identical(actual.compute(), expected.compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1031,7 +1033,7 @@ def test_map_blocks_convert_args_to_list(obj): with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual.compute(), expected.compute()) + xr.testing.assert_identical(actual.compute(), expected.compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) @@ -1046,6 +1048,8 @@ def test_map_blocks_kwargs(obj): @pytest.mark.parametrize( "func, obj", [ + [lambda x: x, map_da], + [lambda x: x, map_ds], [lambda x: x.to_dataset(), map_da], [lambda x: x.to_array(), map_ds], [lambda x: x.drop("cxy"), map_ds], @@ -1065,7 +1069,7 @@ def test_map_blocks_transformations(func, obj): with raise_if_dask_computes(): actual = xr.map_blocks(func, obj) - assert_equal(actual.compute(), func(obj).compute()) + assert_identical(actual.compute(), func(obj).compute()) @pytest.mark.parametrize("obj", [map_da, map_ds]) From 49969a7d98bc0ea130016c583c256df5f1406b66 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 09:23:35 -0600 Subject: [PATCH 58/76] Support attrs and name changes. 
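In particular, `func` may now rename the DataArray it receives; the result
picks the new name up from the inferred template. Sketch with toy data:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(6.0), dims="x", name="a").chunk({"x": 3})

    renamed = xr.map_blocks(lambda x: x.rename("b"), da)
    assert renamed.name == "b"
    xr.testing.assert_identical(renamed.compute(), da.rename("b").compute())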
--- xarray/core/parallel.py | 15 ++++++++++++--- xarray/tests/test_dask.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index af743d2fb00..104c8c01155 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -204,6 +204,7 @@ def _wrapper(func, obj, to_array, args, kwargs): ) # type: Union[DataArray, Dataset] if isinstance(template, DataArray): result_is_array = True + template_name = template.name template = template._to_temp_dataset() elif isinstance(template, Dataset): result_is_array = False @@ -249,7 +250,10 @@ def _wrapper(func, obj, to_array, args, kwargs): chunk = chunk[chunk_index_dict[dim]] chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v - graph[chunk_variable_task] = (tuple, [variable.dims, chunk]) + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) else: # non-dask array with possibly chunked dimensions # index into variable appropriately @@ -264,7 +268,10 @@ def _wrapper(func, obj, to_array, args, kwargs): subset = variable.isel(subsetter) chunk_variable_task = (name + dask.base.tokenize(subset),) + v - graph[chunk_variable_task] = (tuple, [subset.dims, subset]) + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset, subset.attrs], + ) # this task creates dict mapping variable name to above tuple if name in dataset._coord_names: @@ -320,5 +327,7 @@ def _wrapper(func, obj, to_array, args, kwargs): result = result.set_coords(template._coord_names) if result_is_array: - return dataset_to_dataarray(result) # type: ignore + da = dataset_to_dataarray(result) + da.name = template_name + return da # type: ignore return result # type: ignore diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 4e91806c5ae..9d34490a79e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1036,6 +1036,34 @@ def test_map_blocks_convert_args_to_list(obj): xr.testing.assert_identical(actual.compute(), expected.compute()) +@pytest.mark.parametrize("obj", [map_da, map_ds]) +def test_map_blocks_add_attrs(obj): + def add_attrs(obj): + obj = obj.copy(deep=True) + obj.attrs["new"] = "new" + return obj + + expected = add_attrs(obj) + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj) + + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +@pytest.mark.parametrize("obj", [map_da]) +def test_map_blocks_change_name(obj): + def change_name(obj): + obj = obj.copy(deep=True) + obj.name = "new" + return obj + + expected = change_name(obj) + with raise_if_dask_computes(): + actual = xr.map_blocks(change_name, obj) + + xr.testing.assert_identical(actual.compute(), expected.compute()) + + @pytest.mark.parametrize("obj", [map_da, map_ds]) def test_map_blocks_kwargs(obj): expected = xr.full_like(obj, fill_value=np.nan) @@ -1045,13 +1073,20 @@ def test_map_blocks_kwargs(obj): xr.testing.assert_equal(actual.compute(), expected.compute()) +def test_map_blocks_to_array(): + with raise_if_dask_computes(): + actual = xr.map_blocks(lambda x: x.to_array(), map_ds) + + # to_array does not preserve name, so cannot use assert_identical + assert_equal(actual.compute(), map_ds.to_array().compute()) + + @pytest.mark.parametrize( "func, obj", [ [lambda x: x, map_da], [lambda x: x, map_ds], [lambda x: x.to_dataset(), map_da], - [lambda x: x.to_array(), map_ds], [lambda x: x.drop("cxy"), map_ds], [lambda x: x.drop("a"), map_ds], [lambda x: x.drop("x"), map_da], From 
6faf79ef4c32e29308823fe30d3985f52e65935f Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 09:43:21 -0600 Subject: [PATCH 59/76] more assert_equal --- xarray/tests/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 9d34490a79e..a09dccf730b 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1070,7 +1070,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_equal(actual.compute(), expected.compute()) + xr.testing.assert_identical(actual.compute(), expected.compute()) def test_map_blocks_to_array(): @@ -1117,7 +1117,7 @@ def func(obj): expected = xr.map_blocks(func, obj) actual = obj.map_blocks(func) - assert_equal(expected.compute(), actual.compute()) + assert_identical(expected.compute(), actual.compute()) def test_make_meta(): From 47baf76f31cbf4e2d7783ba9922b9c948a8ae4ee Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 09:44:19 -0600 Subject: [PATCH 60/76] test changing coord attribute --- xarray/tests/test_dask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a09dccf730b..5d840bc4c09 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1041,6 +1041,7 @@ def test_map_blocks_add_attrs(obj): def add_attrs(obj): obj = obj.copy(deep=True) obj.attrs["new"] = "new" + obj.cxy.attrs["new2"] = "new2" return obj expected = add_attrs(obj) From ce252f2786cc8eb609a544d928ddadbbbd40504b Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 3 Oct 2019 09:45:13 -0600 Subject: [PATCH 61/76] fix whats new --- doc/whats-new.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4d091da4f4d..799ed23d3cb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,8 +26,6 @@ New functions/methods Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and :py:meth:`testing.assert_chunks_equal`. By `Deepak Cherian `_ and `Guido Imperiale `_. 
-New functions/methods -~~~~~~~~~~~~~~~~~~~~~ Enhancements ~~~~~~~~~~~~ From 50ae13fc3f4045df686acc6433aa4e7f6ddd0041 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 7 Oct 2019 09:16:36 -0600 Subject: [PATCH 62/76] rework tests to use fixtures (kind of) --- xarray/tests/test_dask.py | 94 ++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 5d840bc4c09..f703d74b4e6 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -946,7 +946,7 @@ def make_ds(): map_ds["z"] = [0, 1, 2, 3] map_ds["e"] = map_ds.x + map_ds.y map_ds.coords["c1"] = 0.5 - map_ds.coords["cx"] = ("x", np.arange(len(map_da.x))) + map_ds.coords["cx"] = ("x", np.arange(len(map_ds.x))) map_ds.coords["cx"].attrs["test2"] = "test2" map_ds.attrs["test"] = "test" map_ds.coords["xx"] = map_ds["a"] * map_ds.y @@ -954,12 +954,20 @@ def make_ds(): return map_ds -map_da = make_da() -map_ds = make_ds() +# fixtures cannot be used in parametrize statements +# instead use this workaround +# https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly +@pytest.fixture() +def map_da(): + return make_da() -@pytest.mark.parametrize("obj", [map_ds, map_da]) -def test_unify_chunks(obj): +@pytest.fixture() +def map_ds(): + return make_ds() + + +def test_unify_chunks(map_ds): ds_copy = map_ds.copy() ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) @@ -973,16 +981,17 @@ def test_unify_chunks(obj): assert_identical(map_ds, ds_copy.unify_chunks()) +@pytest.mark.parametrize("obj", [make_ds(), make_da()]) @pytest.mark.parametrize( - "obj", - [map_ds.compute(), map_da.compute(), map_ds.unify_chunks(), map_da.unify_chunks()], + "transform", [lambda x: x.compute(), lambda x: x.unify_chunks()] ) -def test_unify_chunks_shallow_copy(obj): +def test_unify_chunks_shallow_copy(obj, transform): + obj = transform(obj) unified = obj.unify_chunks() assert_identical(obj, unified) and obj is not obj.unify_chunks() -def test_map_blocks_error(): +def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] @@ -1014,7 +1023,7 @@ def really_bad_func(darray): xr.map_blocks(bad_func, ds_copy) -@pytest.mark.parametrize("obj", [map_da, map_ds]) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks(obj): def func(obj): result = obj + obj.x + 5 * obj.y @@ -1027,7 +1036,7 @@ def func(obj): xr.testing.assert_identical(actual.compute(), expected.compute()) -@pytest.mark.parametrize("obj", [map_da, map_ds]) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_convert_args_to_list(obj): expected = obj + 10 with raise_if_dask_computes(): @@ -1036,7 +1045,7 @@ def test_map_blocks_convert_args_to_list(obj): xr.testing.assert_identical(actual.compute(), expected.compute()) -@pytest.mark.parametrize("obj", [map_da, map_ds]) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): def add_attrs(obj): obj = obj.copy(deep=True) @@ -1051,21 +1060,20 @@ def add_attrs(obj): xr.testing.assert_identical(actual.compute(), expected.compute()) -@pytest.mark.parametrize("obj", [map_da]) -def test_map_blocks_change_name(obj): +def test_map_blocks_change_name(map_da): def change_name(obj): obj = obj.copy(deep=True) obj.name = "new" return obj - expected = change_name(obj) + expected = change_name(map_da) with raise_if_dask_computes(): - actual = xr.map_blocks(change_name, obj) + actual = xr.map_blocks(change_name, map_da) 
xr.testing.assert_identical(actual.compute(), expected.compute()) -@pytest.mark.parametrize("obj", [map_da, map_ds]) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_kwargs(obj): expected = xr.full_like(obj, fill_value=np.nan) with raise_if_dask_computes(): @@ -1074,7 +1082,7 @@ def test_map_blocks_kwargs(obj): xr.testing.assert_identical(actual.compute(), expected.compute()) -def test_map_blocks_to_array(): +def test_map_blocks_to_array(map_ds): with raise_if_dask_computes(): actual = xr.map_blocks(lambda x: x.to_array(), map_ds) @@ -1083,32 +1091,44 @@ def test_map_blocks_to_array(): @pytest.mark.parametrize( - "func, obj", + "func", [ - [lambda x: x, map_da], - [lambda x: x, map_ds], - [lambda x: x.to_dataset(), map_da], - [lambda x: x.drop("cxy"), map_ds], - [lambda x: x.drop("a"), map_ds], - [lambda x: x.drop("x"), map_da], - [lambda x: x.drop("x"), map_ds], - [lambda x: x.expand_dims(k=[1, 2, 3]), map_ds], - [lambda x: x.expand_dims(k=[1, 2, 3]), map_da], - # TODO: [lambda x: x.isel(x=1), map_ds], + lambda x: x, + lambda x: x.to_dataset(), + lambda x: x.drop("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.assign_coords(new_coord=("y", x.y * 2)), + lambda x: x.astype(np.int32), # TODO: [lambda x: x.isel(x=1).drop("x"), map_da], - [lambda x: x.assign_coords(new_coord=("y", x.y * 2)), map_da], - [lambda x: x.astype(np.int32), map_da], - [lambda x: x.rename({"a": "new1", "b": "new2"}), map_ds], ], ) -def test_map_blocks_transformations(func, obj): +def test_map_blocks_da_transformations(func, map_da): with raise_if_dask_computes(): - actual = xr.map_blocks(func, obj) + actual = xr.map_blocks(func, map_da) + + assert_identical(actual.compute(), func(map_da).compute()) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x.drop("cxy"), + lambda x: x.drop("a"), + lambda x: x.drop("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.rename({"a": "new1", "b": "new2"}), + # TODO: [lambda x: x.isel(x=1)], + ], +) +def test_map_blocks_ds_transformations(func, map_ds): + with raise_if_dask_computes(): + actual = xr.map_blocks(func, map_ds) - assert_identical(actual.compute(), func(obj).compute()) + assert_identical(actual.compute(), func(map_ds).compute()) -@pytest.mark.parametrize("obj", [map_da, map_ds]) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_object_method(obj): def func(obj): result = obj + obj.x + 5 * obj.y @@ -1121,7 +1141,7 @@ def func(obj): assert_identical(expected.compute(), actual.compute()) -def test_make_meta(): +def test_make_meta(map_ds): from ..core.parallel import make_meta meta = make_meta(map_ds) From cdcf2218712dab2e6addd148e41eda31e6636ee9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 7 Oct 2019 09:35:52 -0600 Subject: [PATCH 63/76] more review changes. --- xarray/core/dataarray.py | 4 +++- xarray/core/dataset.py | 4 +++- xarray/core/parallel.py | 21 ++++++++++++++++----- xarray/tests/test_dask.py | 6 ++++++ 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 106c3485946..95b7723419e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3057,7 +3057,7 @@ def integrate( return self._from_temp_dataset(ds) def unify_chunks(self) -> "DataArray": - """ Unifies chunksize along all chunked dimensions of this DataArray. + """ Unify chunk size along all chunked dimensions of this DataArray. 
Returns ------- @@ -3103,8 +3103,10 @@ def map_blocks( args: Sequence Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Passing dask objects will raise an error. kwargs: Mapping Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Passing dask objects will raise an error. Returns ------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6ee089eb913..588f9170c9c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5403,7 +5403,7 @@ def filter_by_attrs(self, **kwargs): return self[selection] def unify_chunks(self) -> "Dataset": - """ Unifies chunksize along all chunked dimensions of this Dataset. + """ Unify chunk size along all chunked dimensions of this Dataset. Returns ------- @@ -5480,8 +5480,10 @@ def map_blocks( args: Sequence Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Passing dask objects will raise an error. kwargs: Mapping Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. + Passing dask objects will raise an error. Returns ------- diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 104c8c01155..aa08191f147 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -135,10 +135,10 @@ def map_blocks( function will be run eagerly. args: tuple Passed on to func after unpacking. xarray objects, if any, will not be split by - chunks. + chunks. Passing dask objects will raise an error. kwargs: dict Passed on to func after unpacking. xarray objects, if any, will not be split by - chunks. + chunks. Passing dask objects will raise an error. Returns ------- @@ -181,6 +181,12 @@ def _wrapper(func, obj, to_array, args, kwargs): elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") + for value in list(args) + list(kwargs.values()): + if dask.is_dask_collection(value): + raise ValueError( + "Cannot pass dask variables in args or kwargs yet. 
Please compute or load values before passing to map_blocks" + ) + if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) @@ -221,7 +227,10 @@ def _wrapper(func, obj, to_array, args, kwargs): indexes.update({k: template.indexes[k] for k in new_indexes}) graph = {} # type: Dict[Any, Any] - gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset)) + gname = "%s-%s" % ( + dask.utils.funcname(func), + dask.base.tokenize(dataset, args, kwargs), + ) # map dims to list of chunk indexes ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} @@ -249,7 +258,7 @@ def _wrapper(func, obj, to_array, args, kwargs): for dim in variable.dims: chunk = chunk[chunk_index_dict[dim]] - chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v + chunk_variable_task = ("%s-%s" % (gname, chunk[0]),) + v graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], @@ -267,7 +276,9 @@ def _wrapper(func, obj, to_array, args, kwargs): ) subset = variable.isel(subsetter) - chunk_variable_task = (name + dask.base.tokenize(subset),) + v + chunk_variable_task = ( + "%s-%s" % (gname, dask.base.tokenize(subset)), + ) + v graph[chunk_variable_task] = ( tuple, [subset.dims, subset, subset.attrs], diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f703d74b4e6..6497b93a2c5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1022,6 +1022,12 @@ def really_bad_func(darray): with raises_regex(ValueError, "inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) + with raises_regex(ValueError, "Cannot pass dask variables"): + xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) + + with raises_regex(ValueError, "Cannot pass dask variables"): + xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks(obj): From f1675379bd6b1eb9ba6e641436a669ccbb89b0d8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 7 Oct 2019 09:43:22 -0600 Subject: [PATCH 64/76] cleanup --- xarray/tests/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 6497b93a2c5..46fd5c6a522 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -957,12 +957,12 @@ def make_ds(): # fixtures cannot be used in parametrize statements # instead use this workaround # https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly -@pytest.fixture() +@pytest.fixture def map_da(): return make_da() -@pytest.fixture() +@pytest.fixture def map_ds(): return make_ds() From 4390f73643f0f0f81892aaf1cb3ac9a765a662c8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 7 Oct 2019 13:38:58 -0600 Subject: [PATCH 65/76] more review feedback. --- xarray/core/dataset.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 588f9170c9c..d74ff9bc256 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5416,14 +5416,10 @@ def unify_chunks(self) -> "Dataset": dask.array.core.unify_chunks """ - import dask - - if not dask.is_dask_collection(self): - return self.copy() - try: - self.chunks - return self.copy() + # if pure numpy + if len(self.chunks) == 0: + return self.copy() except ValueError: # "inconsistent chunks" pass From c93655793fa0085b931f9161ecf5e75530f0ccb0 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 7 Oct 2019 19:04:36 -0600 Subject: [PATCH 66/76] fix unify_chunks. 
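The early return is now taken before importing dask, so `unify_chunks` keeps
working on pure-numpy objects even when dask is not installed. When chunks do
conflict, the common refinement is used, e.g. with toy data:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.ones(10)), "b": ("x", np.ones(10))})
    ds["a"] = ds.a.chunk({"x": 4})  # chunks (4, 4, 2)
    ds["b"] = ds.b.chunk({"x": 5})  # chunks (5, 5)

    unified = ds.unify_chunks()
    print(unified.chunks)  # x chunks become (4, 1, 3, 2), the common refinement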
--- xarray/core/dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d74ff9bc256..4c18969897f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5417,12 +5417,15 @@ def unify_chunks(self) -> "Dataset": """ try: - # if pure numpy - if len(self.chunks) == 0: - return self.copy() + self.chunks except ValueError: # "inconsistent chunks" pass + else: + # No variables with dask backend, or all chunks are already aligned + return self.copy() + # import dask is placed after the quick exit test above to allow + # running this method if dask isn't installed and there are no chunks import dask.array ds = self.copy() From 2c7938ab4b38426043e12e14085aefbfcfdc97be Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 8 Oct 2019 19:37:03 -0600 Subject: [PATCH 67/76] read dask_array_compat :) --- xarray/core/dask_array_compat.py | 92 ++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 xarray/core/dask_array_compat.py diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py new file mode 100644 index 00000000000..f3bb849322c --- /dev/null +++ b/xarray/core/dask_array_compat.py @@ -0,0 +1,92 @@ +from distutils.version import LooseVersion + +import dask.array as da +import numpy as np +from dask import __version__ as dask_version + + +if LooseVersion(dask_version) > LooseVersion("2.0.0"): + meta_from_array = da.utils.meta_from_array +else: + # Copied from dask v2.4.0 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import numbers + + def meta_from_array(x, ndim=None, dtype=None): + """ Normalize an array to appropriate meta object + + Parameters + ---------- + x: array-like, callable + Either an object that looks sufficiently like a Numpy array, + or a callable that accepts shape and dtype keywords + ndim: int + Number of dimensions of the array + dtype: Numpy dtype + A valid input for ``np.dtype`` + + Returns + ------- + array-like with zero elements of the correct dtype + """ + # If using x._meta, x must be a Dask Array, some libraries (e.g. 
zarr) + # implement a _meta attribute that are incompatible with Dask Array._meta + if hasattr(x, "_meta") and isinstance(x, da.Array): + x = x._meta + + if dtype is None and x is None: + raise ValueError("You must specify the meta or dtype of the array") + + if np.isscalar(x): + x = np.array(x) + + if x is None: + x = np.ndarray + + if isinstance(x, type): + x = x(shape=(0,) * (ndim or 0), dtype=dtype) + + if ( + not hasattr(x, "shape") + or not hasattr(x, "dtype") + or not isinstance(x.shape, tuple) + ): + return x + + if isinstance(x, list) or isinstance(x, tuple): + ndims = [ + 0 + if isinstance(a, numbers.Number) + else a.ndim + if hasattr(a, "ndim") + else len(a) + for a in x + ] + a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)] + return a if isinstance(x, list) else tuple(x) + + if ndim is None: + ndim = x.ndim + + try: + meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))] + if meta.ndim != ndim: + if ndim > x.ndim: + meta = meta[ + (Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim)) + ] + meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))] + elif ndim == 0: + meta = meta.sum() + else: + meta = meta.reshape((0,) * ndim) + except Exception: + meta = np.empty((0,) * ndim, dtype=dtype or x.dtype) + + if np.isscalar(meta): + meta = np.array(meta) + + if dtype and meta.dtype != dtype: + meta = meta.astype(dtype) + + return meta From 08ed8739f67f9650d5390d5d180ae78d88c216ac Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 10 Oct 2019 08:56:20 -0600 Subject: [PATCH 68/76] Dask 1.2.0 compat. --- xarray/core/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index aa08191f147..77ee1f1e17d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -53,7 +53,7 @@ def make_meta(obj): if isinstance(obj, Dataset): meta = Dataset() for name, variable in obj.variables.items(): - meta_obj = meta_from_array(variable.data) + meta_obj = meta_from_array(variable.data, ndim=variable.ndim) meta[name] = (variable.dims, meta_obj, variable.attrs) meta.attrs = obj.attrs else: From 99d61fca2e0bd93a7c6ffb99cefab33b5a2d60d3 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:18:06 +0100 Subject: [PATCH 69/76] documentation polish --- xarray/core/dataarray.py | 47 ++++++++++++++--------------- xarray/core/dataset.py | 49 +++++++++++++++--------------- xarray/core/parallel.py | 64 +++++++++++++++++++++------------------- 3 files changed, 82 insertions(+), 78 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 076bf70a052..31fd27c5671 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3064,50 +3064,51 @@ def map_blocks( kwargs: Mapping[str, Any] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this DataArray. This function is experimental + Apply a function to each chunk of this DataArray. This method is experimental and its signature may change. Parameters ---------- func: callable - User-provided function that should accept xarray objects. - This function will receive a subset of this DataArray, corresponding to one chunk along - each chunked dimension. - The function will be run on a small piece of data that looks like 'obj' to determine - properties of the returned object such as dtype, variable names, - new dimensions and new indexes (if any). + User-provided function that accepts a DataArray as its first parameter. 
The + function will receive a subset of this DataArray, corresponding to one chunk + along each chunked dimension. - This function must - - return either a single DataArray or a single Dataset + The function will be first run on mocked-up data, that looks like this array + but has sizes 0, to determine properties of the returned object such as + dtype, variable names, new dimensions and new indexes (if any). - This function cannot - - change size of existing dimensions. - - add new chunked dimensions. - - If your function expects numpy arrays, see `xarray.apply_ufunc` + This function must return either a single DataArray or a single Dataset. + This function cannot change size of existing dimensions, or add new chunked + dimensions. args: Sequence - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - Passing dask objects will raise an error. + Passed verbatim to func after unpacking, after the sliced DataArray. xarray + objects, if any, will not be split by chunks. Passing dask collections is + not allowed. kwargs: Mapping - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - Passing dask objects will raise an error. + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. Passing dask collections is not allowed. Returns ------- - A single DataArray or Dataset + A single DataArray or Dataset with dask backend, reassembled from the outputs of + the function. Notes ----- + This method is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, + it is recommended to use apply_ufunc. - This function is designed to work with dask-backed xarray objects. See apply_ufunc for - a similar function that works with numpy arrays. + If none of the variables in this DataArray is backed by dask, calling this + method is equivalent to calling ``func(self, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, xarray.Dataset.map_blocks + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + xarray.Dataset.map_blocks """ - from .parallel import map_blocks return map_blocks(func, self, args, kwargs) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4399051ab51..62bf7b9bc4f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5425,50 +5425,51 @@ def map_blocks( kwargs: Mapping[str, Any] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this Dataset. This function is experimental - and its signature may change. + Apply a function to each chunk of this Dataset. This method is experimental and + its signature may change. Parameters ---------- func: callable - User-provided function that should accept xarray objects. - This function will receive a subset of this dataset, corresponding to one chunk along - each chunked dimension. - The function will be run on a small piece of data that looks like 'obj' to determine - properties of the returned object such as dtype, variable names, - new dimensions and new indexes (if any). + User-provided function that accepts a Dataset as its first parameter. The + function will receive a subset of this Dataset, corresponding to one chunk + along each chunked dimension. 
- This function must - - return either a single DataArray or a single Dataset + The function will be first run on mocked-up data, that looks like this + Dataset but has sizes 0, to determine properties of the returned object such + as dtype, variable names, new dimensions and new indexes (if any). - This function cannot - - change size of existing dimensions. - - add new chunked dimensions. - - If your function expects numpy arrays, see `xarray.apply_ufunc` + This function must return either a single DataArray or a single Dataset. + This function cannot change size of existing dimensions, or add new chunked + dimensions. args: Sequence - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - Passing dask objects will raise an error. + Passed verbatim to func after unpacking, after the sliced DataArray. xarray + objects, if any, will not be split by chunks. Passing dask collections is + not allowed. kwargs: Mapping - Passed on to func after unpacking. xarray objects, if any, will not be split by chunks. - Passing dask objects will raise an error. + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. Passing dask collections is not allowed. Returns ------- - A single DataArray or Dataset + A single DataArray or Dataset with dask backend, reassembled from the outputs of + the function. Notes ----- + This method is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, + it is recommended to use apply_ufunc. - This function is designed to work with dask-backed xarray objects. See apply_ufunc for - a similar function that works with numpy arrays. + If none of the variables in this Dataset is backed by dask, calling this method + is equivalent to calling ``func(self, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, xarray.DataArray.map_blocks + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + xarray.DataArray.map_blocks """ - from .parallel import map_blocks return map_blocks(func, self, args, kwargs) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 77ee1f1e17d..0ffa0637b61 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -42,7 +42,11 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray: def make_meta(obj): - + """If obj is a DataArray or Dataset, return a new object of the same type and with + the same variables and dtypes, but where all variables have size 0 and numpy + backend. + If obj is neither a DataArray nor Dataset, return it unaltered. + """ if isinstance(obj, DataArray): to_array = True obj_array = obj.copy() @@ -71,7 +75,8 @@ def make_meta(obj): def infer_template( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs ) -> T_DSorDA: - """ Infer return object by running the function on meta objects. """ + """Infer return object by running the function on meta objects. + """ meta_args = [make_meta(arg) for arg in (obj,) + args] try: @@ -112,43 +117,40 @@ def map_blocks( Parameters ---------- func: callable - User-provided function that should accept xarray objects. This function will - receive a subset of this dataset, corresponding to one chunk along each chunked - dimension. 
To determine properties of the returned object such as type - (DataArray or Dataset), dtypes, and new/removed dimensions and/or variables, the - function will be run on dummy data with the same variables, dimension names, and - data types as this DataArray, but zero-sized dimensions. - - This function must - - return either a single DataArray or a single Dataset - - This function cannot - - change size of existing dimensions. - - add new chunked dimensions. - - This function should work with whole xarray objects. If your function can be - applied to numpy or dask arrays (e.g. it doesn't need additional metadata such - as dimension names, variable names, etc.), you should consider using - :py:func:`~xarray.apply_ufunc` instead. + User-provided function that accepts a DataArray or Dataset as its first + parameter. The function will receive a subset of 'obj' (see below), + corresponding to one chunk along each chunked dimension. + + The function will first be run on mocked-up data that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + + This function must return either a single DataArray or a single Dataset. + + This function cannot change the size of existing dimensions or add new chunked + dimensions. obj: DataArray, Dataset - Chunks of this object will be provided to 'func'. If passed a numpy object, the - function will be run eagerly. - args: tuple - Passed on to func after unpacking. xarray objects, if any, will not be split by - chunks. Passing dask objects will raise an error. - kwargs: dict - Passed on to func after unpacking. xarray objects, if any, will not be split by - chunks. Passing dask objects will raise an error. + Passed to the function as its first argument, one dask chunk at a time. + args: Sequence + Passed verbatim to func after unpacking, after the sliced obj. xarray objects, + if any, will not be split by chunks. Passing dask collections is not allowed. + kwargs: Mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. Passing dask collections is not allowed. Returns ------- - A single DataArray or Dataset + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- + This function is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, it is + recommended to use apply_ufunc. - This function is designed to work with dask-backed xarray objects. See apply_ufunc - for a similar function that works with numpy arrays. + If none of the variables in obj is backed by dask, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- From 687689e53d8f5d3d1349ce5dfee0e9a2c7b6c2f1 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:20:00 +0100 Subject: [PATCH 70/76] make_meta reflow --- xarray/core/parallel.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 0ffa0637b61..8420dd79ab1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -48,27 +48,22 @@ def make_meta(obj): If obj is neither a DataArray nor Dataset, return it unaltered.
""" if isinstance(obj, DataArray): - to_array = True - obj_array = obj.copy() + obj_array = obj obj = obj._to_temp_dataset() + elif isinstance(obj, Dataset): + obj_array = None else: - to_array = False - - if isinstance(obj, Dataset): - meta = Dataset() - for name, variable in obj.variables.items(): - meta_obj = meta_from_array(variable.data, ndim=variable.ndim) - meta[name] = (variable.dims, meta_obj, variable.attrs) - meta.attrs = obj.attrs - else: - meta = obj - - if isinstance(obj, Dataset): - meta = meta.set_coords(obj.coords) + return obj - if to_array: - meta = obj_array._from_temp_dataset(meta) + meta = Dataset() + for name, variable in obj.variables.items(): + meta_obj = meta_from_array(variable.data, ndim=variable.ndim) + meta[name] = (variable.dims, meta_obj, variable.attrs) + meta.attrs = obj.attrs + meta = meta.set_coords(obj.coords) + if obj_array is not None: + return obj_array._from_temp_dataset(meta) return meta @@ -88,8 +83,8 @@ def infer_template( if not isinstance(template, (Dataset, DataArray)): raise TypeError( - "Function must return an xarray DataArray or Dataset. Instead it returned %r" - % type(template) + "Function must return an xarray DataArray or Dataset. Instead it returned " + f"{type(template)}" ) return template From f588cb6f15bca321beeb4427cf2203e981f9419e Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:38:01 +0100 Subject: [PATCH 71/76] cosmetic --- xarray/core/dataset.py | 38 +++++++++++++++++++------------------- xarray/core/parallel.py | 7 +++---- xarray/testing.py | 4 ++-- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 62bf7b9bc4f..83895b19965 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1674,8 +1674,8 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: for dim, c in zip(v.dims, v.chunks): if dim in chunks and c != chunks[dim]: raise ValueError( - "Object has inconsistent chunks along dimension %r. This can be fixed by calling unify_chunks()." - % dim + f"Object has inconsistent chunks along dimension {dim}. " + "This can be fixed by calling unify_chunks()." ) chunks[dim] = c return Frozen(SortedKeysDict(chunks)) @@ -1861,7 +1861,7 @@ def isel( self, indexers: Mapping[Hashable, Any] = None, drop: bool = False, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed along the specified dimension(s). @@ -1944,7 +1944,7 @@ def sel( method: str = None, tolerance: Number = None, drop: bool = False, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -2017,7 +2017,7 @@ def sel( def head( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -2064,7 +2064,7 @@ def head( def tail( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the last `n` values of each array for the specified dimension(s). 
@@ -2114,7 +2114,7 @@ def tail( def thin( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed along every `n`th value for the specified dimension(s) @@ -2252,7 +2252,7 @@ def reindex( tolerance: Number = None, copy: bool = True, fill_value: Any = dtypes.NA, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2453,7 +2453,7 @@ def interp( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, - **coords_kwargs: Any + **coords_kwargs: Any, ) -> "Dataset": """ Multidimensional interpolation of Dataset. @@ -2679,7 +2679,7 @@ def rename( self, name_dict: Mapping[Hashable, Hashable] = None, inplace: bool = None, - **names: Hashable + **names: Hashable, ) -> "Dataset": """Returns a new object with renamed variables and dimensions. @@ -2882,7 +2882,7 @@ def expand_dims( self, dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, axis: Union[None, int, Sequence[int]] = None, - **dim_kwargs: Any + **dim_kwargs: Any, ) -> "Dataset": """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a @@ -3028,7 +3028,7 @@ def set_index( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, inplace: bool = None, - **indexes_kwargs: Union[Hashable, Sequence[Hashable]] + **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> "Dataset": """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -3130,7 +3130,7 @@ def reorder_levels( self, dim_order: Mapping[Hashable, Sequence[int]] = None, inplace: bool = None, - **dim_order_kwargs: Sequence[int] + **dim_order_kwargs: Sequence[int], ) -> "Dataset": """Rearrange index levels using input order. @@ -3196,7 +3196,7 @@ def _stack_once(self, dims, new_dim): def stack( self, dimensions: Mapping[Hashable, Sequence[Hashable]] = None, - **dimensions_kwargs: Sequence[Hashable] + **dimensions_kwargs: Sequence[Hashable], ) -> "Dataset": """ Stack any number of existing dimensions into a single new dimension. @@ -3897,7 +3897,7 @@ def interpolate_na( method: str = "linear", limit: int = None, use_coordinate: Union[bool, Hashable] = True, - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Interpolate values according to different methods. @@ -3948,7 +3948,7 @@ def interpolate_na( method=method, limit=limit, use_coordinate=use_coordinate, - **kwargs + **kwargs, ) return new @@ -4029,7 +4029,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, allow_lazy: bool = False, - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Reduce this dataset by applying `func` along some dimension(s). @@ -4104,7 +4104,7 @@ def reduce( keep_attrs=keep_attrs, keepdims=keepdims, allow_lazy=allow_lazy, - **kwargs + **kwargs, ) coord_names = {k for k in self.coords if k in variables} @@ -4119,7 +4119,7 @@ def apply( func: Callable, keep_attrs: bool = None, args: Iterable[Any] = (), - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Apply a function over the data variables in this dataset. 
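The parallel.py hunk just ahead tightens the args/kwargs validation to a TypeError: dask collections must be computed or loaded before being passed through ``args`` or ``kwargs``, since they are forwarded whole rather than split into chunks. A small sketch of both sides of that rule, where ``add`` and ``arr`` are invented for the example::

    import dask.array as da
    import xarray as xr

    arr = xr.DataArray(da.ones(6, chunks=3), dims="x")

    def add(block, offset):
        return block + offset

    xr.map_blocks(add, arr, args=[10])  # fine: plain values pass through whole
    try:
        xr.map_blocks(add, arr, args=[arr])  # arr is itself a dask collection
    except TypeError as err:
        print(err)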
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8420dd79ab1..2c130f3cc9d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -171,17 +171,16 @@ def _wrapper(func, obj, to_array, args, kwargs): if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list or tuple).") - if kwargs is None: kwargs = {} - elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") for value in list(args) + list(kwargs.values()): if dask.is_dask_collection(value): - raise ValueError( - "Cannot pass dask variables in args or kwargs yet. Please compute or load values before passing to map_blocks" + raise TypeError( + "Cannot pass dask collections in args or kwargs yet. Please compute or " + "load values before passing to map_blocks." ) if not dask.is_dask_collection(obj): diff --git a/xarray/testing.py b/xarray/testing.py index 5bd5c3676ab..95e41cfb10c 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -148,9 +148,9 @@ def assert_chunks_equal(a, b): Parameters ---------- - a : xarray.Dataset, xarray.DataArray or xarray.Variable + a : xarray.Dataset or xarray.DataArray The first object to compare. - b : xarray.Dataset, xarray.DataArray or xarray.Variable + b : xarray.Dataset or xarray.DataArray The second object to compare. """ From d476e2f90ffc1cf718f7e5d8d9ff1b7648f6968e Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:38:41 +0100 Subject: [PATCH 72/76] polish --- xarray/core/dask_array_compat.py | 2 +- xarray/core/parallel.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index f3bb849322c..a7a69c4edf0 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -5,7 +5,7 @@ from dask import __version__ as dask_version -if LooseVersion(dask_version) > LooseVersion("2.0.0"): +if LooseVersion(dask_version) >= LooseVersion("2.0.0"): meta_from_array = da.utils.meta_from_array else: # Copied from dask v2.4.0 diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2c130f3cc9d..a381a9115b4 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -201,9 +201,7 @@ def _wrapper(func, obj, to_array, args, kwargs): input_chunks = dataset.chunks - template = infer_template( - func, obj, *args, **kwargs - ) # type: Union[DataArray, Dataset] + template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template_name = template.name @@ -212,17 +210,17 @@ def _wrapper(func, obj, to_array, args, kwargs): result_is_array = False else: raise TypeError( - "func output must be DataArray or Dataset; got %s" % type(template) + f"func output must be DataArray or Dataset; got {type(template)}" ) template_indexes = set(template.indexes) dataset_indexes = set(dataset.indexes) - preserved_indexes = template_indexes.intersection(dataset_indexes) - new_indexes = set(template_indexes) - set(dataset_indexes) + preserved_indexes = template_indexes & dataset_indexes + new_indexes = template_indexes - dataset_indexes indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) - graph = {} # type: Dict[Any, Any] + graph: Dict[Any, Any] = {} gname = "%s-%s" % ( dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs), @@ -297,14 +295,14 @@ def _wrapper(func, obj, to_array, args, kwargs): ) # 
mapping from variable name to dask graph key - var_key_map = {} # type: Dict[Hashable, str] + var_key_map: Dict[Hashable, str] = {} for name, variable in template.variables.items(): if name in indexes: continue gname_l = "%s-%s" % (gname, name) var_key_map[name] = gname_l - key = (gname_l,) # type: Tuple[Any, ...] + key: Tuple[Any, ...] = (gname_l,) for dim in variable.dims: if dim in chunk_index_dict: key += (chunk_index_dict[dim],) From 26a6a0dad8718784d67c918736c2b35ca386e45f Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:42:11 +0100 Subject: [PATCH 73/76] Fix tests --- xarray/tests/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d8e6b5dc953..af0985ea394 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -999,10 +999,10 @@ def really_bad_func(darray): with raises_regex(ValueError, "inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) - with raises_regex(ValueError, "Cannot pass dask variables"): + with raises_regex(TypeError, "Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) - with raises_regex(ValueError, "Cannot pass dask variables"): + with raises_regex(TypeError, "Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) From 64917537cb14ff3b341ee7a72f5573ec0ef70d40 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:44:22 +0100 Subject: [PATCH 74/76] isort --- xarray/tests/test_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index af0985ea394..3e2e132825f 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -12,8 +12,8 @@ import xarray as xr import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable -from xarray.tests import mock from xarray.testing import assert_chunks_equal +from xarray.tests import mock from . 
import ( assert_allclose, From b227beac8acbcb3d016f6a424f174c5ad1041630 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 10 Oct 2019 23:45:48 +0100 Subject: [PATCH 75/76] isort --- xarray/coding/times.py | 1 - xarray/core/dask_array_compat.py | 1 - xarray/core/parallel.py | 10 +++++----- xarray/core/pdcompat.py | 1 - xarray/core/variable.py | 2 +- xarray/tests/__init__.py | 3 +-- xarray/tests/test_backends.py | 2 +- xarray/tests/test_coding_times.py | 2 -- 8 files changed, 8 insertions(+), 14 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 1508fb50b38..ed6908117a2 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,7 +22,6 @@ unpack_for_encoding, ) - # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index a7a69c4edf0..c3dbdd27098 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -4,7 +4,6 @@ import numpy as np from dask import __version__ as dask_version - if LooseVersion(dask_version) >= LooseVersion("2.0.0"): meta_from_array = da.utils.meta_from_array else: diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a381a9115b4..fdf2445a233 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -8,12 +8,7 @@ pass import itertools -import numpy as np import operator - -from .dataarray import DataArray -from .dataset import Dataset - from typing import ( Any, Callable, @@ -26,6 +21,11 @@ Union, ) +import numpy as np + +from .dataarray import DataArray +from .dataset import Dataset + T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index 7591fff3abe..f2e4518e0dc 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -41,7 +41,6 @@ import pandas as pd - # allow ourselves to type checks for Panel even after it's removed if LooseVersion(pd.__version__) < "0.25.0": Panel = pd.Panel diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6d7a07c6791..24865d62666 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3,7 +3,7 @@ from collections import OrderedDict, defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Hashable, Mapping, Union, TypeVar +from typing import Any, Hashable, Mapping, TypeVar, Union import numpy as np import pandas as pd diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 8b4d3073e1c..acf8b67effa 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -9,6 +9,7 @@ import numpy as np import pytest from numpy.testing import assert_array_equal # noqa: F401 +from pandas.testing import assert_frame_equal # noqa: F401 import xarray.testing from xarray.core import utils @@ -17,8 +18,6 @@ from xarray.core.options import set_options from xarray.plot.utils import import_seaborn -from pandas.testing import assert_frame_equal # noqa: F401 - # import mpl and change the backend before other mpl imports try: import matplotlib as mpl diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0120e2ca0fe..b5421a6bc9f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -14,8 +14,8 @@ import numpy as np import pandas as pd -from pandas.errors import OutOfBoundsDatetime import pytest +from pandas.errors import OutOfBoundsDatetime import xarray as xr from xarray import ( diff --git 
a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 406b9c1ba69..33a409e6f45 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -6,7 +6,6 @@ import pytest from pandas.errors import OutOfBoundsDatetime - from xarray import DataArray, Dataset, Variable, coding, decode_cf from xarray.coding.times import ( _import_cftime, @@ -30,7 +29,6 @@ requires_cftime_or_netCDF4, ) - _NON_STANDARD_CALENDARS_SET = { "noleap", "365_day", From 2a41906970ae720bcad04c0b7abd60655ac7fb5f Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 10 Oct 2019 17:18:57 -0600 Subject: [PATCH 76/76] Add func call to docstrings. --- xarray/core/dataarray.py | 3 ++- xarray/core/dataset.py | 3 ++- xarray/core/parallel.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 31fd27c5671..1b1d23bc2fc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3072,7 +3072,8 @@ def map_blocks( func: callable User-provided function that accepts a DataArray as its first parameter. The function will receive a subset of this DataArray, corresponding to one chunk - along each chunked dimension. + along each chunked dimension. ``func`` will be executed as + ``func(obj_subset, *args, **kwargs)``. The function will first be run on mocked-up data that looks like this array but has sizes 0, to determine properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any). diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 83895b19965..42990df6f65 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5433,7 +5433,8 @@ def map_blocks( func: callable User-provided function that accepts a Dataset as its first parameter. The function will receive a subset of this Dataset, corresponding to one chunk - along each chunked dimension. + along each chunked dimension. ``func`` will be executed as + ``func(obj_subset, *args, **kwargs)``. The function will first be run on mocked-up data that looks like this Dataset but has sizes 0, to determine properties of the returned object such diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fdf2445a233..48bb9ccfc3d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -114,7 +114,8 @@ def map_blocks( func: callable User-provided function that accepts a DataArray or Dataset as its first parameter. The function will receive a subset of 'obj' (see below), - corresponding to one chunk along each chunked dimension. + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(obj_subset, *args, **kwargs)``. The function will first be run on mocked-up data that looks like 'obj' but has sizes 0, to determine properties of the returned object such as dtype,