diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py index 5f0944722..ccd55583a 100644 --- a/dask_cuda/device_host_file.py +++ b/dask_cuda/device_host_file.py @@ -13,7 +13,6 @@ from distributed.utils import nbytes from distributed.worker import weight -import numpy from zict import Buffer, File, Func from zict.common import ZictBase @@ -31,37 +30,35 @@ class DeviceSerialized: that are in host memory """ - def __init__(self, header, parts): + def __init__(self, header, frames): self.header = header - self.parts = parts + self.frames = frames def __sizeof__(self): - return sum(map(nbytes, self.parts)) + return sum(map(nbytes, self.frames)) @dask_serialize.register(DeviceSerialized) def device_serialize(obj): - headers, frames = serialize(obj.parts) - header = {"sub-headers": headers, "main-header": obj.header} + header = {"obj-header": dict(obj.header)} + frames = list(obj.frames) return header, frames @dask_deserialize.register(DeviceSerialized) def device_deserialize(header, frames): - parts = deserialize(header["sub-headers"], frames) - return DeviceSerialized(header["main-header"], parts) + return DeviceSerialized(header["obj-header"], frames) @nvtx_annotate("SPILL_D2H", color="red", domain="dask_cuda") def device_to_host(obj: object) -> DeviceSerialized: header, frames = serialize(obj, serializers=["dask", "pickle"]) - frames = [numpy.asarray(f) for f in frames] return DeviceSerialized(header, frames) @nvtx_annotate("SPILL_H2D", color="green", domain="dask_cuda") def host_to_device(s: DeviceSerialized) -> object: - return deserialize(s.header, s.parts) + return deserialize(s.header, s.frames) class DeviceHostFile(ZictBase):