Track mutable frames #4004

Merged (34 commits, Aug 4, 2020)

Commits
27ec2c7
Flag mutable frames as part of serialization
jakirkham Jul 31, 2020
7dda1b7
Just use `bytes` in `merge_frames`
jakirkham Jul 31, 2020
64c58ca
Only copy mutable frames
jakirkham Jul 31, 2020
78f301f
Mark all CUDA frames as non-writeable
jakirkham Jul 31, 2020
2d75629
Test NumPy array preserves writeability
jakirkham Jul 31, 2020
7ffa2c2
Fix-up `merge_frames` test with `writeable` header
jakirkham Jul 31, 2020
3ecaff8
Relax the `memoryview` requirement of frames
jakirkham Jul 31, 2020
95a2f9f
Drop test for frames being `memoryview`s
jakirkham Jul 31, 2020
a581f01
Assert `writeable` and `lengths` have same # items
jakirkham Jul 31, 2020
934a55e
Optionally determine `writeable`
jakirkham Jul 31, 2020
0c12ee1
Rename `m` to `w` for clarity
jakirkham Jul 31, 2020
c0427cb
Fix Pandas test name
jakirkham Jul 31, 2020
5e712be
Test writing to Pandas Series after serialization
jakirkham Jul 31, 2020
f5f2f12
Force Pandas serialized frames to be readonly
jakirkham Jul 31, 2020
58e3b03
Test `merge_frames` with other `writeable` values
jakirkham Jul 31, 2020
90bf603
Use read-only frames for serialized NumPy array
jakirkham Jul 31, 2020
73f34b0
Use `list` to compare `writeable`
jakirkham Jul 31, 2020
2473f6d
Drop unneeded CUDA array interface check
jakirkham Jul 31, 2020
10abc22
Use a `bytearray` to join writeable frames
jakirkham Jul 31, 2020
01d41d7
Handle singleton frame case as well
jakirkham Aug 1, 2020
4b25ada
Handle fast-path at the beginning
jakirkham Aug 1, 2020
cc78d30
Fix-up readonly case in fast path
jakirkham Aug 1, 2020
09c1ce7
Mark CUDA frames as neither readonly nor writeable
jakirkham Aug 1, 2020
2ef5de3
Fix-up logic for copying singleton frame
jakirkham Aug 1, 2020
913cf8a
Test a few more `merge_frames` cases
jakirkham Aug 1, 2020
040ef75
Drop unneeded CuPy customization
jakirkham Aug 1, 2020
00bc781
Always use `join` path when copying is needed
jakirkham Aug 1, 2020
7012d85
Go back to `.extend()`
jakirkham Aug 1, 2020
41dcec4
Skip assignment and return `out`
jakirkham Aug 1, 2020
f40d91c
One more test with `bytearray`s
jakirkham Aug 2, 2020
8a624ad
Add function to check whether frame is writeable
jakirkham Aug 3, 2020
942d7ea
Use `is_writeable` to detect writeable frames
jakirkham Aug 3, 2020
60190e9
Merge dask/master into jakirkham/track_mutable_frames
jakirkham Aug 3, 2020
87252c9
Explain return results from `is_writeable`
jakirkham Aug 3, 2020
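
Taken together, these commits make the serialization header record which frames the receiver is allowed to mutate. A minimal sketch of what ends up in a header, using the `is_writeable` helper this PR adds to `distributed/utils.py` (frame values made up for illustration; assumes a version of distributed that includes this change):

from distributed.utils import is_writeable, nbytes

frames = [bytearray(b"abcd"), b"efgh"]                   # one mutable frame, one read-only frame
header = {}
header["writeable"] = tuple(map(is_writeable, frames))   # (True, False)
header["lengths"] = tuple(map(nbytes, frames))           # (4, 4)
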
Files changed
6 changes: 5 additions & 1 deletion distributed/protocol/core.py
@@ -7,7 +7,7 @@
from .compression import compressions, maybe_compress, decompress
from .serialize import serialize, deserialize, Serialize, Serialized, extract_serialize
from .utils import frame_split_size, merge_frames, msgpack_opts
from ..utils import nbytes
from ..utils import is_writeable, nbytes

_deserialize = deserialize

@@ -46,6 +46,8 @@ def dumps(msg, serializers=None, on_error="message", context=None):
out_frames = []

for key, (head, frames) in data.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))

@@ -71,6 +73,8 @@ def dumps(msg, serializers=None, on_error="message", context=None):
out_frames.extend(_out_frames)

for key, (head, frames) in pre.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))
head["count"] = len(frames)
1 change: 1 addition & 0 deletions distributed/protocol/cupy.py
@@ -47,6 +47,7 @@ def cuda_deserialize_cupy_ndarray(header, frames):
@dask_serialize.register(cupy.ndarray)
def dask_serialize_cupy_ndarray(x):
header, frames = cuda_serialize_cupy_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(cupy.asnumpy(f)) for f in frames]
return header, frames

1 change: 1 addition & 0 deletions distributed/protocol/numba.py
@@ -51,6 +51,7 @@ def cuda_deserialize_numba_ndarray(header, frames):
@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
header, frames = cuda_serialize_numba_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(f.copy_to_host()) for f in frames]
return header, frames

1 change: 1 addition & 0 deletions distributed/protocol/rmm.py
@@ -30,6 +30,7 @@ def cuda_deserialize_rmm_device_buffer(header, frames):
@dask_serialize.register(rmm.DeviceBuffer)
def dask_serialize_rmm_device_buffer(x):
header, frames = cuda_serialize_rmm_device_buffer(x)
header["writeable"] = (None,) * len(frames)
frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
return header, frames

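
The cupy, numba, and rmm serializers above all set `writeable` to `None` (undetermined) for their frames before copying device memory to the host. A hedged sketch of what that buys (no GPU needed; a plain memoryview stands in for a copied CUDA frame, and the post-PR `merge_frames` is assumed):

from distributed.protocol.utils import merge_frames

frame = memoryview(bytearray(b"device-bytes"))  # stand-in for a frame copied off the device
header = {"lengths": [len(frame)], "writeable": [None]}
(out,) = merge_frames(header, [frame])
assert out is frame  # None means: pass the frame through without forcing an extra copy
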
7 changes: 4 additions & 3 deletions distributed/protocol/serialize.py
@@ -9,7 +9,7 @@
import msgpack

from . import pickle
from ..utils import has_keyword, nbytes, typename, ensure_bytes
from ..utils import has_keyword, nbytes, typename, ensure_bytes, is_writeable
from .compression import maybe_compress, decompress
from .utils import (
unpack_frames,
@@ -473,6 +473,8 @@ def replace_inner(x):

def serialize_bytelist(x, **kwargs):
header, frames = serialize(x, **kwargs)
if "writeable" not in header:
header["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in header:
header["lengths"] = tuple(map(nbytes, frames))
if frames:
@@ -501,8 +503,7 @@ def deserialize_bytes(b):
else:
header = {}
frames = decompress(header, frames)
if not any(hasattr(f, "__cuda_array_interface__") for f in frames):
frames = merge_frames(header, frames)
frames = merge_frames(header, frames)
return deserialize(header, frames)


14 changes: 9 additions & 5 deletions distributed/protocol/tests/test_numpy.py
@@ -17,7 +17,7 @@
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.compression import maybe_compress
from distributed.system import MEMORY_LIMIT
from distributed.utils import tmpfile, nbytes
from distributed.utils import ensure_bytes, tmpfile, nbytes
from distributed.utils_test import gen_cluster


@@ -100,12 +100,16 @@ def test_dumps_serialize_numpy(x):
np.testing.assert_equal(x, y)


def test_dumps_numpy_writable():
@pytest.mark.parametrize("writeable", [True, False])
def test_dumps_numpy_writable(writeable):
a1 = np.arange(1000)
a1.flags.writeable = False
(a2,) = loads(dumps([to_serialize(a1)]))
a1.flags.writeable = writeable
fs = dumps([to_serialize(a1)])
# Make all frames read-only
fs = list(map(ensure_bytes, fs))
(a2,) = loads(fs)
assert (a1 == a2).all()
assert a2.flags.writeable
assert a2.flags.writeable == a1.flags.writeable


@pytest.mark.parametrize(
23 changes: 21 additions & 2 deletions distributed/protocol/tests/test_pandas.py
@@ -4,7 +4,15 @@

from dask.dataframe.utils import assert_eq

from distributed.protocol import serialize, deserialize, decompress
from distributed.protocol import (
serialize,
deserialize,
decompress,
dumps,
loads,
to_serialize,
)
from distributed.utils import ensure_bytes


dfs = [
@@ -63,10 +71,21 @@


@pytest.mark.parametrize("df", dfs)
def test_dumps_serialize_numpy(df):
def test_dumps_serialize_pandas(df):
header, frames = serialize(df)
if "compression" in header:
frames = decompress(header, frames)
df2 = deserialize(header, frames)

assert_eq(df, df2)


def test_dumps_pandas_writable():
a1 = np.arange(1000)
s1 = pd.Series(a1)
fs = dumps([to_serialize(s1)])
# Make all frames read-only
fs = list(map(ensure_bytes, fs))
(s2,) = loads(fs)
assert (s1 == s2).all()
s2[...] = 0
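
The trailing `s2[...] = 0` is the real assertion here: it verifies that the deserialized Series is backed by writeable memory. If it were instead backed by a read-only buffer, the write would raise, as this standalone NumPy sketch (independent of distributed) shows:

import numpy as np

buf = bytes(8)                      # a read-only buffer
a = np.frombuffer(buf, dtype="i8")  # the array merely views the buffer, so it is read-only too
try:
    a[...] = 0
except ValueError as err:
    print(err)                      # assignment destination is read-only
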
31 changes: 20 additions & 11 deletions distributed/protocol/tests/test_protocol_utils.py
@@ -1,21 +1,30 @@
import pytest

from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames
from distributed.utils import ensure_bytes
from distributed.utils import ensure_bytes, is_writeable


@pytest.mark.parametrize(
"lengths,frames",
"lengths,writeable,frames",
[
([3], [b"123"]),
([3, 3], [b"123", b"456"]),
([2, 3, 2], [b"12345", b"67"]),
([5, 2], [b"123", b"45", b"67"]),
([3, 4], [b"12", b"34", b"567"]),
([3], [False], [b"123"]),
([3], [True], [b"123"]),
([3], [None], [b"123"]),
([3], [False], [bytearray(b"123")]),
([3], [True], [bytearray(b"123")]),
([3], [None], [bytearray(b"123")]),
([3, 3], [False, False], [b"123", b"456"]),
([2, 3, 2], [False, True, None], [b"12345", b"67"]),
([2, 3, 2], [False, True, None], [bytearray(b"12345"), bytearray(b"67")]),
([5, 2], [False, True], [b"123", b"45", b"67"]),
([3, 4], [None, False], [b"12", b"34", b"567"]),
],
)
def test_merge_frames(lengths, frames):
header = {"lengths": lengths}
def test_merge_frames(lengths, writeable, frames):
header = {
"lengths": lengths,
"writeable": writeable,
}
result = merge_frames(header, frames)

data = b"".join(frames)
@@ -24,8 +33,8 @@ def test_merge_frames(lengths, frames):
expected.append(data[:i])
data = data[i:]

assert all(isinstance(f, memoryview) for f in result)
assert all(not f.readonly for f in result)
writeables = list(map(is_writeable, result))
assert (r == e for r, e in zip(writeables, header["writeable"]) if e is not None)
assert list(map(ensure_bytes, result)) == expected


60 changes: 34 additions & 26 deletions distributed/protocol/utils.py
@@ -51,36 +51,44 @@ def merge_frames(header, frames):
[b'123456']
"""
lengths = list(header["lengths"])
frames = list(map(memoryview, frames))
writeables = list(header["writeable"])

assert len(lengths) == len(writeables)
assert sum(lengths) == sum(map(nbytes, frames))

if not all(len(f) == l for f, l in zip(frames, lengths)):
frames = frames[::-1]
lengths = lengths[::-1]

out = []
while lengths:
l = lengths.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
else:
L.append(frame[:l])
frames.append(frame[l:])
l = 0
if len(L) == 1: # no work necessary
out.append(L[0])
else:
out.append(memoryview(bytearray().join(L)))
frames = out

frames = [memoryview(bytearray(f)) if f.readonly else f for f in frames]
if all(len(f) == l for f, l in zip(frames, lengths)):
return [
(bytearray(f) if w else bytes(f)) if w == memoryview(f).readonly else f
for w, f in zip(header["writeable"], frames)
]

return frames
frames = frames[::-1]
lengths = lengths[::-1]
writeables = writeables[::-1]

out = []
while lengths:
l = lengths.pop()
w = writeables.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
else:
frame = memoryview(frame)
L.append(frame[:l])
frames.append(frame[l:])
l = 0
if len(L) == 1 and w != memoryview(L[0]).readonly: # no work necessary
out.extend(L)
elif w:
out.append(bytearray().join(L))
else:
out.append(bytes().join(L))

return out


def pack_frames_prelude(frames):
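
A worked example of the new fast path above, which is taken when every frame already matches its stated length (values illustrative; assumes the post-PR `merge_frames`):

from distributed.protocol.utils import merge_frames

header = {"lengths": [3, 3, 3], "writeable": [True, False, None]}
frames = [b"abc", bytearray(b"def"), bytearray(b"ghi")]
a, b, c = merge_frames(header, frames)

assert isinstance(a, bytearray)  # read-only frame copied so the recipient may write to it
assert isinstance(b, bytes)      # writeable frame copied into read-only bytes
assert c is frames[2]            # None: mutability undetermined, frame passed through untouched
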
13 changes: 13 additions & 0 deletions distributed/utils.py
@@ -1107,6 +1107,19 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)):
return len(frame)


def is_writeable(frame):
"""
Check whether frame is writeable

Will return ``True`` if writeable, ``False`` if readonly, and
``None`` if undetermined.
"""
try:
return not memoryview(frame).readonly
except TypeError:
return None


@contextmanager
def time_warn(duration, text):
start = time()
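
A brief usage sketch of the `is_writeable` helper added above:

from distributed.utils import is_writeable

assert is_writeable(bytearray(b"abc")) is True  # mutable buffer
assert is_writeable(b"abc") is False            # read-only buffer
assert is_writeable(42) is None                 # no buffer protocol, so undetermined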