diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 84728007b42..a0dfe56807b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -25,7 +25,6 @@ class ExpectedDict(TypedDict): shapes: dict[Hashable, int] coords: set[Hashable] data_vars: set[Hashable] - indexes: dict[Hashable, Index] def unzip(iterable): @@ -337,6 +336,7 @@ def _wrapper( kwargs: dict, arg_is_array: Iterable[bool], expected: ExpectedDict, + expected_indexes: dict[Hashable, Index], ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -372,7 +372,7 @@ def _wrapper( # ChainMap wants MutableMapping, but xindexes is Mapping merged_indexes = collections.ChainMap( - expected["indexes"], + expected_indexes, merged_coordinates.xindexes, # type: ignore[arg-type] ) expected_index = merged_indexes.get(name, None) @@ -412,6 +412,7 @@ def _wrapper( try: import dask import dask.array + from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph except ImportError: @@ -551,6 +552,20 @@ def _wrapper( for isxr, arg in zip(is_xarray, npargs, strict=True) ] + # only include new or modified indexes to minimize duplication of data + indexes = { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + } + + tokenized_indexes: dict[Hashable, str] = {} + for k, v in indexes.items(): + tokenized_v = tokenize(v) + graph[f"{k}-coordinate-{tokenized_v}"] = v + tokenized_indexes[k] = f"{k}-coordinate-{tokenized_v}" + # raise nice error messages in _wrapper expected: ExpectedDict = { # input chunk 0 along a dimension maps to output chunk 0 along the same dimension @@ -562,17 +577,18 @@ def _wrapper( }, "data_vars": set(template.data_vars.keys()), "coords": set(template.coords.keys()), - # only include new or modified indexes to minimize duplication of data, and graph size. - "indexes": { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in (new_indexes | modified_indexes) - }, } from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) + graph[from_wrapper] = ( + _wrapper, + func, + blocked_args, + kwargs, + is_array, + expected, + (dict, [[k, v] for k, v in tokenized_indexes.items()]), + ) # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a46a9d43c4c..cc795b75118 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -14,6 +14,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops from xarray.core.duck_array_ops import lazy_array_equiv +from xarray.core.indexes import PandasIndex from xarray.testing import assert_chunks_equal from xarray.tests import ( assert_allclose, @@ -1375,6 +1376,13 @@ def test_map_blocks_da_ds_with_template(obj): actual = xr.map_blocks(func, obj, template=template) assert_identical(actual, template) + # Check that indexes are written into the graph directly + dsk = dict(actual.__dask_graph__()) + assert len({k for k in dsk if "x-coordinate" in k}) + assert all( + isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k + ) + with raise_if_dask_computes(): actual = obj.map_blocks(func, template=template) assert_identical(actual, template)