Skip to content

Commit f6ec885

Browse files
committed
ENH: simplify serialization protocol
1 parent f15abc5 commit f6ec885

File tree

1 file changed

+66
-52
lines changed

1 file changed

+66
-52
lines changed

distributed/protocol/core.py

+66-52
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,32 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22+
# Reserved dict key: a one-entry dict keyed by this string is treated as an
# offload placeholder by _make_offload_value / _extract_offload_value.
OFFLOAD_KEY = "_$"
23+
24+
25+
def _make_offload_value(value):
    """Wrap ``value`` in a one-entry dict keyed by ``OFFLOAD_KEY``.

    The resulting dict has the placeholder shape that
    ``_extract_offload_value`` recognises.
    """
    placeholder = {}
    placeholder[OFFLOAD_KEY] = value
    return placeholder
27+
28+
29+
def _extract_offload_value(value):
    """Return the payload stored under ``OFFLOAD_KEY``, or ``None``.

    ``None`` is returned when ``value`` is not a dict, does not have exactly
    one entry, or has no ``OFFLOAD_KEY`` entry. A payload that is itself
    ``None`` is therefore indistinguishable from "not a placeholder".
    """
    if not isinstance(value, dict) or len(value) != 1:
        return None
    # dict.get already yields None for a missing key, so no separate
    # None branch is needed.
    return value.get(OFFLOAD_KEY)
36+
37+
2238
def dumps(msg, serializers=None, on_error="message", context=None):
2339
""" Transform Python message to bytestream suitable for communication """
2440
try:
2541
data = {}
2642
# Only lists and dicts can contain serialized values
2743
if isinstance(msg, (list, dict)):
2844
msg, data, bytestrings = extract_serialize(msg)
29-
small_header, small_payload = dumps_msgpack(msg)
3045

3146
if not data: # fast path without serialized data
32-
return small_header, small_payload
47+
return dumps_msgpack(msg)
3348

3449
pre = {
3550
key: (value.header, value.frames)
@@ -45,11 +60,17 @@ def dumps(msg, serializers=None, on_error="message", context=None):
4560
if type(value) is Serialize
4661
}
4762

48-
header = {"headers": {}, "keys": [], "bytestrings": list(bytestrings)}
49-
5063
out_frames = []
5164

65+
def patch_offload_header(path, header, frame_index, context):
66+
accessor, key = path[:-1], path[-1]
67+
holder = reduce(operator.getitem, accessor, context)
68+
header["deserialize"] = path in bytestrings
69+
header["frame_index"] = frame_index
70+
holder[key] = _make_offload_value(header)
71+
5272
for key, (head, frames) in data.items():
73+
head = dict(head)
5374
if "lengths" not in head:
5475
head["lengths"] = tuple(map(nbytes, frames))
5576
if "compression" not in head:
@@ -60,16 +81,15 @@ def dumps(msg, serializers=None, on_error="message", context=None):
6081
compression = []
6182
head["compression"] = compression
6283
head["count"] = len(frames)
63-
header["headers"][key] = head
64-
header["keys"].append(key)
84+
patch_offload_header(key, head, len(out_frames), msg)
6585
out_frames.extend(frames)
6686

6787
for key, (head, frames) in pre.items():
88+
head = dict(head)
6889
if "lengths" not in head:
6990
head["lengths"] = tuple(map(nbytes, frames))
7091
head["count"] = len(frames)
71-
header["headers"][key] = head
72-
header["keys"].append(key)
92+
patch_offload_header(key, head, len(out_frames), msg)
7393
out_frames.extend(frames)
7494

7595
for i, frame in enumerate(out_frames):
@@ -80,66 +100,60 @@ def dumps(msg, serializers=None, on_error="message", context=None):
80100
frame = frame.tobytes()
81101
out_frames[i] = frame
82102

83-
return [
84-
small_header,
85-
small_payload,
86-
msgpack.dumps(header, use_bin_type=True),
87-
] + out_frames
103+
return dumps_msgpack(msg) + out_frames
88104
except Exception:
89105
logger.critical("Failed to Serialize", exc_info=True)
90106
raise
91107

92108

93109
def loads(frames, deserialize=True, deserializers=None):
94110
""" Transform bytestream back into Python value """
95-
frames = frames[::-1] # reverse order to improve pop efficiency
96111
if not isinstance(frames, list):
97112
frames = list(frames)
98113
try:
99-
small_header = frames.pop()
100-
small_payload = frames.pop()
114+
small_header = frames[0]
115+
small_payload = frames[1]
101116
msg = loads_msgpack(small_header, small_payload)
102-
if not frames:
117+
if len(frames) < 3:
103118
return msg
104119

105-
header = frames.pop()
106-
header = msgpack.loads(header, use_list=False, **msgpack_opts)
107-
keys = header["keys"]
108-
headers = header["headers"]
109-
bytestrings = set(header["bytestrings"])
110-
111-
for key in keys:
112-
head = headers[key]
113-
count = head["count"]
114-
if count:
115-
fs = frames[-count::][::-1]
116-
del frames[-count:]
117-
else:
118-
fs = []
119-
120-
if deserialize or key in bytestrings:
121-
if "compression" in head:
122-
fs = decompress(head, fs)
123-
fs = merge_frames(head, fs)
124-
value = _deserialize(head, fs, deserializers=deserializers)
125-
else:
126-
value = Serialized(head, fs)
127-
128-
def put_in(keys, coll, val):
129-
"""Inverse of get_in, but does type promotion in the case of lists"""
130-
if keys:
131-
holder = reduce(operator.getitem, keys[:-1], coll)
132-
if isinstance(holder, tuple):
133-
holder = list(holder)
134-
coll = put_in(keys[:-1], coll, holder)
135-
holder[keys[-1]] = val
120+
out_frames_start = 2
121+
122+
def _traverse(item):
123+
placeholder = _extract_offload_value(item)
124+
if placeholder is not None:
125+
header = placeholder
126+
deserialize_key = header["deserialize"]
127+
frame_index = header["frame_index"]
128+
count = header["count"]
129+
if count:
130+
start_index = out_frames_start + frame_index
131+
end_index = start_index + count
132+
fs = frames[start_index:end_index]
133+
frames[start_index:end_index] = [None] * count # free memory
136134
else:
137-
coll = val
138-
return coll
135+
fs = []
139136

140-
msg = put_in(key, msg, value)
137+
if deserialize or deserialize_key:
138+
if "compression" in header:
139+
fs = decompress(header, fs)
140+
fs = merge_frames(header, fs)
141+
value = _deserialize(header, fs, deserializers=deserializers)
142+
else:
143+
value = Serialized(header, fs)
144+
return value
145+
146+
if isinstance(item, (list, tuple)):
147+
return type(item)(_traverse(i) for i in item)
148+
elif isinstance(item, dict):
149+
return {
150+
key: _traverse(val)
151+
for (key, val) in item.items()
152+
}
153+
else:
154+
return item
141155

142-
return msg
156+
return _traverse(msg)
143157
except Exception:
144158
logger.critical("Failed to deserialize", exc_info=True)
145159
raise

0 commit comments

Comments
 (0)