Reduce graph size through writing indexes directly into graph for ``map_blocks`` (#9658)

* Reduce graph size through writing indexes directly into graph for map_blocks

* Update xarray/core/parallel.py

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
phofl and dcherian authored Oct 22, 2024
1 parent 863184d commit 5632c8e
Showing 2 changed files with 34 additions and 10 deletions.
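Background on the change: ``map_blocks`` previously embedded the sliced index objects for every output chunk inside each task's ``expected`` dict, so the serialized graph carried one copy of the indexes per task. After this commit, each new or modified index slice is written into the graph once, under a key derived from ``dask.base.tokenize`` (a deterministic content hash), and tasks merely reference those keys; identical slices collapse to a single graph entry. A minimal sketch of the pattern, independent of xarray (the ``use``/``out`` names and graph layout are illustrative; only ``dask.base.tokenize`` and ``dask.threaded.get`` are real APIs):

import numpy as np
from dask.base import tokenize
from dask.threaded import get

index = np.arange(1_000)  # stand-in for a large index object

def use(idx, i):
    # toy per-chunk work that needs the whole index
    return idx[i::4].sum()

# Before: each task embeds its own copy of `index` as a literal argument,
# so the graph grows by one copy of the index per chunk.
graph_before = {("out", i): (use, index, i) for i in range(4)}

# After: store `index` once under a content-derived key; tasks reference
# the key and the scheduler substitutes the shared value at run time.
key = f"index-{tokenize(index)}"
graph_after = {key: index}
graph_after.update({("out", i): (use, key, i) for i in range(4)})

assert get(graph_before, ("out", 0)) == get(graph_after, ("out", 0))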
xarray/core/parallel.py (26 additions, 10 deletions)
@@ -25,7 +25,6 @@ class ExpectedDict(TypedDict):
     shapes: dict[Hashable, int]
     coords: set[Hashable]
     data_vars: set[Hashable]
-    indexes: dict[Hashable, Index]
 
 
 def unzip(iterable):

@@ -337,6 +336,7 @@ def _wrapper(
         kwargs: dict,
         arg_is_array: Iterable[bool],
         expected: ExpectedDict,
+        expected_indexes: dict[Hashable, Index],
     ):
         """
         Wrapper function that receives datasets in args; converts to dataarrays when necessary;

@@ -372,7 +372,7 @@ def _wrapper(
 
         # ChainMap wants MutableMapping, but xindexes is Mapping
         merged_indexes = collections.ChainMap(
-            expected["indexes"],
+            expected_indexes,
             merged_coordinates.xindexes,  # type: ignore[arg-type]
         )
         expected_index = merged_indexes.get(name, None)

@@ -412,6 +412,7 @@ def _wrapper(
     try:
         import dask
         import dask.array
+        from dask.base import tokenize
         from dask.highlevelgraph import HighLevelGraph
 
     except ImportError:

@@ -551,6 +552,20 @@ def _wrapper(
             for isxr, arg in zip(is_xarray, npargs, strict=True)
         ]
 
+        # only include new or modified indexes to minimize duplication of data
+        indexes = {
+            dim: coordinates.xindexes[dim][
+                _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
+            ]
+            for dim in (new_indexes | modified_indexes)
+        }
+
+        tokenized_indexes: dict[Hashable, str] = {}
+        for k, v in indexes.items():
+            tokenized_v = tokenize(v)
+            graph[f"{k}-coordinate-{tokenized_v}"] = v
+            tokenized_indexes[k] = f"{k}-coordinate-{tokenized_v}"
+
         # raise nice error messages in _wrapper
         expected: ExpectedDict = {
             # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
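The block added above relies on ``dask.base.tokenize`` being deterministic and content-based: two index objects with equal content yield the same token, so chunks that end up with identical index slices all point at a single ``{dim}-coordinate-{token}`` entry rather than each carrying a copy. A quick sketch (values illustrative; ``tokenize`` is the real dask API):

import pandas as pd
from dask.base import tokenize

a = pd.Index([1, 2, 3], name="x")
b = pd.Index([1, 2, 3], name="x")
assert tokenize(a) == tokenize(b)  # equal content -> same token -> one graph key
print(f"x-coordinate-{tokenize(a)}")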

@@ -562,17 +577,18 @@
             },
             "data_vars": set(template.data_vars.keys()),
             "coords": set(template.coords.keys()),
-            # only include new or modified indexes to minimize duplication of data, and graph size.
-            "indexes": {
-                dim: coordinates.xindexes[dim][
-                    _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
-                ]
-                for dim in (new_indexes | modified_indexes)
-            },
         }
 
         from_wrapper = (gname,) + chunk_tuple
-        graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
+        graph[from_wrapper] = (
+            _wrapper,
+            func,
+            blocked_args,
+            kwargs,
+            is_array,
+            expected,
+            (dict, [[k, v] for k, v in tokenized_indexes.items()]),
+        )
 
         # mapping from variable name to dask graph key
         var_key_map: dict[Hashable, str] = {}
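Note the shape of the new final task argument: ``(dict, [[k, v], ...])`` rather than a plain dict literal. In a dask task graph, a tuple whose first element is callable is a task, and lists inside tasks are traversed so that strings naming graph keys get substituted with their computed values; dict literals are passed through untouched. Rebuilding the mapping at execution time is therefore what lets each ``{dim}-coordinate-{token}`` key resolve to the index object stored above. A small sketch (key names made up; ``dask.threaded.get`` is the real API):

from dask.threaded import get

dsk = {
    "x-coordinate-abc123": [10, 20, 30],  # index value stored once in the graph
    "rebuilt": (dict, [["x", "x-coordinate-abc123"]]),
}
assert get(dsk, "rebuilt") == {"x": [10, 20, 30]}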
xarray/tests/test_dask.py (8 additions, 0 deletions)
@@ -14,6 +14,7 @@
 from xarray import DataArray, Dataset, Variable
 from xarray.core import duck_array_ops
 from xarray.core.duck_array_ops import lazy_array_equiv
+from xarray.core.indexes import PandasIndex
 from xarray.testing import assert_chunks_equal
 from xarray.tests import (
     assert_allclose,

@@ -1375,6 +1376,13 @@ def test_map_blocks_da_ds_with_template(obj):
     actual = xr.map_blocks(func, obj, template=template)
     assert_identical(actual, template)
 
+    # Check that indexes are written into the graph directly
+    dsk = dict(actual.__dask_graph__())
+    assert len({k for k in dsk if "x-coordinate" in k})
+    assert all(
+        isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k
+    )
+
     with raise_if_dask_computes():
         actual = obj.map_blocks(func, template=template)
         assert_identical(actual, template)
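To observe the effect interactively, the graph of any ``map_blocks`` result can be inspected the same way the new test does. A hedged example (the dataset, ``first_point``, and the template construction are made up for illustration):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": ("x", np.arange(100.0))},
    coords={"x": np.arange(100)},
).chunk({"x": 10})

def first_point(block):
    return block.isel(x=[0])  # modifies the x index, like the test's func

template = ds.isel(x=slice(None, None, 10))  # one point per block
mapped = ds.map_blocks(first_point, template=template)

dsk = dict(mapped.__dask_graph__())
# After this commit, the modified index slices appear under dedicated
# "x-coordinate-<token>" keys instead of inside every task tuple.
print([k for k in dsk if isinstance(k, str) and "x-coordinate" in k])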
