77except ImportError :
88 pass
99
10+ import collections
1011import itertools
1112import operator
1213from typing import (
1314 Any ,
1415 Callable ,
1516 Dict ,
17+ DefaultDict ,
1618 Hashable ,
1719 Mapping ,
1820 Sequence ,
@@ -221,7 +223,12 @@ def _wrapper(func, obj, to_array, args, kwargs):
221223 indexes = {dim : dataset .indexes [dim ] for dim in preserved_indexes }
222224 indexes .update ({k : template .indexes [k ] for k in new_indexes })
223225
226+ # We're building a new HighLevelGraph hlg. We'll have one new layer
227+ # for each variable in the dataset, which is the result of the
228+ # func applied to the values.
229+
224230 graph : Dict [Any , Any ] = {}
231+ new_layers : DefaultDict [str , Dict [Any , Any ]] = collections .defaultdict (dict )
225232 gname = "{}-{}" .format (
226233 dask .utils .funcname (func ), dask .base .tokenize (dataset , args , kwargs )
227234 )
@@ -310,9 +317,20 @@ def _wrapper(func, obj, to_array, args, kwargs):
310317 # unchunked dimensions in the input have one chunk in the result
311318 key += (0 ,)
312319
313- graph [key ] = (operator .getitem , from_wrapper , name )
320+ # We're adding multiple new layers to the graph:
321+ # The first new layer is the result of the computation on
322+ # the array.
323+ # Then we add one layer per variable, which extracts the
324+ # result for that variable, and depends on just the first new
325+ # layer.
326+ new_layers [gname_l ][key ] = (operator .getitem , from_wrapper , name )
327+
328+ hlg = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
314329
315- graph = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
330+ for gname_l , layer in new_layers .items ():
331+ # This adds in the getitems for each variable in the dataset.
332+ hlg .dependencies [gname_l ] = {gname }
333+ hlg .layers [gname_l ] = layer
316334
317335 result = Dataset (coords = indexes , attrs = template .attrs )
318336 for name , gname_l in var_key_map .items ():
@@ -325,7 +343,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325343 var_chunks .append ((len (indexes [dim ]),))
326344
327345 data = dask .array .Array (
328- graph , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
346+ hlg , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
329347 )
330348 result [name ] = (dims , data , template [name ].attrs )
331349
0 commit comments