Skip to content

Commit

Permalink
Skip 2nd serialization pass of DeviceSerialized
Browse files Browse the repository at this point in the history
As `"dask"` serialization already converts a CUDA object into headers
and frames that Dask is able to work with, drop code that tries to
serialize frames on host further (as they are already as simple as they
can be). Cuts a fair bit of boilerplate from the spilling path, which
should simplify things a bit.
  • Loading branch information
jakirkham committed Jun 5, 2020
1 parent 2d902ea commit f36593e
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions dask_cuda/device_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from distributed.utils import nbytes
from distributed.worker import weight

import numpy
from zict import Buffer, File, Func
from zict.common import ZictBase

Expand All @@ -31,37 +30,35 @@ class DeviceSerialized:
that are in host memory
"""

def __init__(self, header, parts):
def __init__(self, header, frames):
self.header = header
self.parts = parts
self.frames = frames

def __sizeof__(self):
return sum(map(nbytes, self.parts))
return sum(map(nbytes, self.frames))


@dask_serialize.register(DeviceSerialized)
def device_serialize(obj):
headers, frames = serialize(obj.parts)
header = {"sub-headers": headers, "main-header": obj.header}
header = {"obj-header": dict(obj.header)}
frames = list(obj.frames)
return header, frames


@dask_deserialize.register(DeviceSerialized)
def device_deserialize(header, frames):
parts = deserialize(header["sub-headers"], frames)
return DeviceSerialized(header["main-header"], parts)
return DeviceSerialized(header["obj-header"], frames)


@nvtx_annotate("SPILL_D2H", color="red", domain="dask_cuda")
def device_to_host(obj: object) -> DeviceSerialized:
header, frames = serialize(obj, serializers=["dask", "pickle"])
frames = [numpy.asarray(f) for f in frames]
return DeviceSerialized(header, frames)


@nvtx_annotate("SPILL_H2D", color="green", domain="dask_cuda")
def host_to_device(s: DeviceSerialized) -> object:
return deserialize(s.header, s.parts)
return deserialize(s.header, s.frames)


class DeviceHostFile(ZictBase):
Expand Down

0 comments on commit f36593e

Please sign in to comment.