map_blocks leads to downstream TypeError: can not serialize 'function' object #4574

Closed
m-albert opened this issue Mar 9, 2021 · 5 comments

m-albert commented Mar 9, 2021

What happened:

This report stems from the observation that using dask.array.map_blocks with a function def func(block_info=None), in combination with a distributed.Client, to visualize arrays in napari raises TypeError: can not serialize 'function' object.

The error occurs only with the dask.distributed backend; it works with the plain dask backend.

The error occurs only when mapping a function def func(block_info=None) and not when using def func() or def func(x, block_info=None).

For more context, see dask/dask-image#194.

What you expected to happen:

That passing dask arrays to napari (which slices the array and calls np.asarray) should work independently of whether a dask or dask.distributed backend is used.

Minimal Complete Verifiable Example:

The original error occurring in combination with napari:

%gui qt

import dask.array as da
import numpy as np

import napari
from dask.distributed import Client
client = Client()

xn = np.random.randint(0, 100, (20, 30, 30))
xd = da.from_array(xn, chunks=(1, 5, 5))

def func(block_info=None):
    return np.random.randint(0, 100, (1, 5, 5))

# this instead works (no `block_info` argument)
# def func():
#    return np.random.randint(0, 100, (1, 5, 5))

xm = da.map_blocks(func, chunks=xd.chunks, dtype=xd.dtype)
napari.view_image(xm)

Likely, this minimal example reproduces the error without napari:

import dask.array as da
import numpy as np

from dask.core import flatten
from dask.distributed import Client
client = Client()

xn = np.random.randint(0, 100, (2, 4, 4))
xd = da.from_array(xn, chunks=(1, 2, 2))

# fails
def func(block_info=None):
    return np.random.randint(0, 100, (1, 2, 2))

# works
# def func():
#     return np.random.randint(0, 100, (1, 2, 2))

xm = da.map_blocks(func, chunks=xd.chunks, dtype=xd.dtype)

xm.dask.__dask_distributed_pack__(client, set(flatten(xm.__dask_keys__())))

Anything else we need to know?:

The original error does not occur with dask and dask.distributed version 2.30.0 and lower (the minimal example is incompatible with older versions).

Environment:

  • Dask version: 2021.03.0
  • Python version: 3.9.2
  • Operating System: macOS
  • Install method (conda, pip, source): pip
@jakirkham
Member

@madsbk, do you have any thoughts here? 🙂

@jrbourbeau
Member

Thanks for reporting @m-albert! I've not looked deeply into this, but dask.array.map_blocks goes down some additional code paths when the function being mapped has a block_info= keyword. Based on the code snippets above, that might be a good place to start looking.
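To make the "function has a block_info= keyword" condition concrete, here is a minimal sketch of how such a check can be written with the standard library. The helper name accepts_keyword and the three example functions are hypothetical and only mirror the signatures from the report; this is not Dask's actual implementation.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `name` is a named parameter of `func`.

    Hypothetical helper for illustration only; it does not
    consider functions that take **kwargs.
    """
    return name in inspect.signature(func).parameters


# The three signatures discussed in the report.
def only_block_info(block_info=None):
    return None


def no_arguments():
    return None


def array_and_block_info(x, block_info=None):
    return None


for f in (only_block_info, no_arguments, array_and_block_info):
    print(f.__name__, accepts_keyword(f, "block_info"))
```

Of the three, only the first signature triggers the error reported here, so a check like this narrows down where to look but does not by itself explain the failure.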

@jakirkham
Member

Just to clarify, the issue is that HLG (HighLevelGraph) serialization uses MsgPack, which encounters a function in the Dask graph and fails. Please see this traceback for context (dask/dask-image#194 (comment)). This came up when using napari with Dask, but the issue appears to trace back to Dask alone.
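The same class of failure can be illustrated without Dask. The sketch below uses the standard library's json encoder purely as a stand-in for MsgPack (an assumption made so the snippet runs without extra dependencies): both encoders handle plain data types and raise TypeError when they meet a Python function. The graph_fragment dict is a hypothetical placeholder, not the real packed graph layout.

```python
import json


def func(block_info=None):
    return 0


# Hypothetical stand-in for a graph layer that still holds a raw function.
graph_fragment = {"task": func}

try:
    json.dumps(graph_fragment)
    error = None
except TypeError as exc:
    # A data-only encoder cannot represent a function object.
    error = exc

print(type(error).__name__, error)
```

A data-only encoder needs any callable to be converted to bytes or another plain type before it is packed, which is why a raw function left in the graph surfaces as a TypeError at pack time.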

@jakirkham
Member

Also, as @jni points out, a critical piece here seems to be whether array fusion is enabled or disabled:

import dask
import distributed

import dask.array as da
import numpy as np


c = distributed.Client()

xn = np.random.randint(0, 100, (20, 30, 30))
xd = da.from_array(xn, chunks=(1, 5, 5))

def func(block_info=None):
    return np.random.randint(0, 100, (1, 5, 5))

xm = da.map_blocks(func, chunks=xd.chunks, dtype=xd.dtype)


dask.config.set({"optimization.fuse.active": True})
xm[10:17, 4:6, 4:6].compute()  # works

dask.config.set({"optimization.fuse.active": False})
xm[10:17, 4:6, 4:6].compute()  # fails
Traceback:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-b2e0d1fb182c> in <module>
     20 
     21 dask.config.set({"optimization.fuse.active": False})
---> 22 xm[10:17, 4:6, 4:6].compute()  # fails

~/Developer/dask/dask/base.py in compute(self, **kwargs)
    281         dask.base.compute
    282         """
--> 283         (result,) = compute(self, traverse=False, **kwargs)
    284         return result
    285 

~/Developer/dask/dask/base.py in compute(*args, **kwargs)
    563         postcomputes.append(x.__dask_postcompute__())
    564 
--> 565     results = schedule(dsk, keys, **kwargs)
    566     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    567 

~/Developer/distributed/distributed/client.py in get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2640         Client.compute : Compute asynchronous collections
   2641         """
-> 2642         futures = self._graph_to_futures(
   2643             dsk,
   2644             keys=set(flatten([keys])),

~/Developer/distributed/distributed/client.py in _graph_to_futures(self, dsk, keys, workers, allow_other_workers, priority, user_priority, resources, retries, fifo_timeout, actors)
   2548                 dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
   2549 
-> 2550             dsk = dsk.__dask_distributed_pack__(self, keyset)
   2551 
   2552             annotations = {}

~/Developer/dask/dask/highlevelgraph.py in __dask_distributed_pack__(self, client, client_keys)
    942                 }
    943             )
--> 944         return dumps_msgpack({"layers": layers})
    945 
    946     @staticmethod

~/Developer/distributed/distributed/protocol/core.py in dumps_msgpack(msg, compression)
    120     """
    121     header = {}
--> 122     payload = msgpack.dumps(msg, default=msgpack_encode_default, use_bin_type=True)
    123 
    124     fmt, payload = maybe_compress(payload, compression=compression)

~/miniconda/envs/dask/lib/python3.8/site-packages/msgpack/__init__.py in packb(o, **kwargs)
     33     See :class:`Packer` for options.
     34     """
---> 35     return Packer(**kwargs).pack(o)
     36 
     37 

msgpack/_packer.pyx in msgpack._cmsgpack.Packer.pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer.pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer.pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

msgpack/_packer.pyx in msgpack._cmsgpack.Packer._pack()

TypeError: can not serialize 'function' object

@jrbourbeau
Member

Closed via dask/dask#7353. Thanks @m-albert for reporting and @madsbk for fixing!

@m-albert
Author

Wow that was so quick :) Thanks @madsbk @jakirkham @jrbourbeau
