diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index e4fba2b7ba9..4ed27bf278b 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -9,12 +9,13 @@ import msgpack from . import pickle -from ..utils import has_keyword, typename +from ..utils import has_keyword, nbytes, typename from .compression import maybe_compress, decompress from .utils import ( unpack_frames, pack_frames_prelude, frame_split_size, + merge_frames, ensure_bytes, msgpack_opts, ) @@ -473,6 +474,8 @@ def replace_inner(x): def serialize_bytelist(x, **kwargs): header, frames = serialize(x, **kwargs) + if "lengths" not in header: + header["lengths"] = tuple(map(nbytes, frames)) frames = sum(map(frame_split_size, frames), []) if frames: compression, frames = zip(*map(maybe_compress, frames)) @@ -499,6 +502,7 @@ def deserialize_bytes(b): else: header = {} frames = decompress(header, frames) + frames = merge_frames(header, frames) return deserialize(header, frames) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 4cad5a3653b..57fceaea0c9 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -190,9 +190,19 @@ def test_empty_loads_deep(): assert isinstance(e2[0][0][0], Empty) -def test_serialize_bytes(): - for x in [1, "abc", np.arange(5), b"ab" * int(40e6)]: - b = serialize_bytes(x) +@pytest.mark.parametrize( + "kwargs", [{}, {"serializers": ["pickle"]},], +) +def test_serialize_bytes(kwargs): + for x in [ + 1, + "abc", + np.arange(5), + b"ab" * int(40e6), + int(2 ** 26) * b"ab", + (int(2 ** 25) * b"ab", int(2 ** 25) * b"ab"), + ]: + b = serialize_bytes(x, **kwargs) assert isinstance(b, bytes) y = deserialize_bytes(b) assert str(x) == str(y) diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index e58732b881c..fa020dae909 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -59,9 +59,6 @@ def merge_frames(header, frames): """ lengths = list(header["lengths"]) - if not frames: - return frames - assert sum(lengths) == sum(map(nbytes, frames)) if all(len(f) == l for f, l in zip(frames, lengths)): diff --git a/distributed/utils.py b/distributed/utils.py index f43b2f7acc0..dec1b6b79d3 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -931,7 +931,9 @@ def ensure_bytes(s): >>> ensure_bytes(b'123') b'123' """ - if hasattr(s, "encode"): + if isinstance(s, bytes): + return s + elif hasattr(s, "encode"): return s.encode() else: try: