Skip to content

Commit e36b592

Browse files
committed
Mark CUDA frames as neither readonly no writeable
This special case of `None` will bypass both cases.
1 parent cc78d30 commit e36b592

File tree

4 files changed

+4
-2
lines changed

4 files changed

+4
-2
lines changed

distributed/protocol/cuda.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ def cuda_dumps(x):
1818
header, frames = dumps(x)
1919
header["type-serialized"] = pickle.dumps(type(x))
2020
header["serializer"] = "cuda"
21-
header["compression"] = (False,) * len(frames) # no compression for gpu data
22-
header["writeable"] = (False,) * len(frames) # no need to enforce writeable frames
2321
return header, frames
2422

2523

distributed/protocol/cupy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def cuda_serialize_cupy_ndarray(x):
2323
header = x.__cuda_array_interface__.copy()
2424
header["strides"] = tuple(x.strides)
2525
header["lengths"] = [x.nbytes]
26+
header["compression"] = (False,) * len(frames)
2627
frames = [
2728
cupy.ndarray(
2829
shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,)
@@ -47,6 +48,7 @@ def cuda_deserialize_cupy_ndarray(header, frames):
4748
@dask_serialize.register(cupy.ndarray)
4849
def dask_serialize_cupy_ndarray(x):
4950
header, frames = cuda_serialize_cupy_ndarray(x)
51+
header["writeable"] = (None,) * len(frames)
5052
frames = [memoryview(cupy.asnumpy(f)) for f in frames]
5153
return header, frames
5254

distributed/protocol/numba.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def cuda_deserialize_numba_ndarray(header, frames):
5151
@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
5252
def dask_serialize_numba_ndarray(x):
5353
header, frames = cuda_serialize_numba_ndarray(x)
54+
header["writeable"] = (None,) * len(frames)
5455
frames = [memoryview(f.copy_to_host()) for f in frames]
5556
return header, frames
5657

distributed/protocol/rmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def cuda_deserialize_rmm_device_buffer(header, frames):
3030
@dask_serialize.register(rmm.DeviceBuffer)
3131
def dask_serialize_rmm_device_buffer(x):
3232
header, frames = cuda_serialize_rmm_device_buffer(x)
33+
header["writeable"] = (None,) * len(frames)
3334
frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
3435
return header, frames
3536

0 commit comments

Comments
 (0)