Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize transitions #4451

Merged
merged 22 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
63c60d0
Annotate `keys` & `new`
jakirkham Jan 29, 2021
a9648a0
Combine `recommendations` annotation with others
jakirkham Jan 29, 2021
ead596f
Annotate `dependents` & `dependencies`
jakirkham Jan 29, 2021
6e5bcab
Annotate `start` & `finish`
jakirkham Jan 29, 2021
23b34ba
Create empty `dict` for `recommendations` once
jakirkham Jan 29, 2021
20e0dfb
Use `.get(...)` to retrieve `TaskState`
jakirkham Jan 29, 2021
dbeb0cd
Assign `start, finish` to a variable
jakirkham Jan 29, 2021
9c8b820
Just use `.get(...)` to retrieve transition func
jakirkham Jan 29, 2021
2907628
Annotate `a` & `b`
jakirkham Jan 29, 2021
5d7bd6e
Use `.get(...)` to get `key` from `a`
jakirkham Jan 29, 2021
b26e527
Just `update` `recommendations` with `a` & `b`
jakirkham Jan 29, 2021
b8e09fe
Drop unneeded `KeyError` handling
jakirkham Jan 29, 2021
43a09bf
Annotate `finish2`
jakirkham Jan 29, 2021
f47b9b6
Replace generator with simple `for`-loop
jakirkham Jan 29, 2021
a8282dd
Bind `tuple` results to typed variable
jakirkham Jan 29, 2021
3d7ad9c
Collect `list` of messages for clients and workers
jakirkham Jan 29, 2021
6b62ef1
Extend `BatchedSend`'s `send` to take many msgs
jakirkham Jan 29, 2021
6f46e72
Add `send_all` method and use in `transition`
jakirkham Jan 29, 2021
e2c849f
Deliver all messages to batched send
jakirkham Jan 29, 2021
8a39818
Refactor out private `_transition` function
jakirkham Jan 29, 2021
6da1d47
Send all messages after processing all transitions
jakirkham Jan 29, 2021
cdcfcf2
`declare` `ALL_TASK_STATES` a `set`
jakirkham Jan 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions distributed/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ def _background_send(self):
self.stopped.set()
self.abort()

def send(self, msg):
def send(self, *msgs):
"""Schedule a message for sending to the other side

This completes quickly and synchronously
"""
if self.comm is not None and self.comm.closed():
raise CommClosedError

self.message_count += 1
self.buffer.append(msg)
self.message_count += len(msgs)
self.buffer.extend(msgs)
Comment on lines -130 to +139
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is cleaner than I expected :)

# Avoid spurious wakeups if possible
if self.next_deadline is None:
self.waker.set()
Expand Down
225 changes: 171 additions & 54 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def nogil(func):
EventExtension,
]

ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"}
ALL_TASK_STATES = declare(
set, {"released", "waiting", "no-worker", "processing", "erred", "memory"}
)
globals()["ALL_TASK_STATES"] = ALL_TASK_STATES


@final
Expand Down Expand Up @@ -1961,7 +1964,7 @@ def transition_waiting_processing(self, key):

# logger.debug("Send job to worker: %s, %s", worker, key)

worker_msgs[worker] = _task_to_msg(self, ts)
worker_msgs[worker] = [_task_to_msg(self, ts)]

return {}, worker_msgs, client_msgs
except Exception as e:
Expand Down Expand Up @@ -2168,11 +2171,13 @@ def transition_memory_released(self, key, safe: bint = False):
ws._has_what.remove(ts)
ws._nbytes -= ts.get_nbytes()
ts._group._nbytes_in_memory -= ts.get_nbytes()
worker_msgs[ws._address] = {
"op": "delete-data",
"keys": [key],
"report": False,
}
worker_msgs[ws._address] = [
{
"op": "delete-data",
"keys": [key],
"report": False,
}
]

ts._who_has.clear()

Expand All @@ -2181,7 +2186,7 @@ def transition_memory_released(self, key, safe: bint = False):
report_msg = {"op": "lost-data", "key": key}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

if not ts._run_spec: # pure data
recommendations[key] = "forgotten"
Expand Down Expand Up @@ -2234,7 +2239,7 @@ def transition_released_erred(self, key):
}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "erred"

Expand Down Expand Up @@ -2276,7 +2281,7 @@ def transition_erred_released(self, key):
report_msg = {"op": "task-retried", "key": key}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "released"

Expand Down Expand Up @@ -2343,7 +2348,7 @@ def transition_processing_released(self, key):

w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = {"op": "release-task", "key": key}
worker_msgs[w] = [{"op": "release-task", "key": key}]

ts.state = "released"

Expand Down Expand Up @@ -2432,7 +2437,7 @@ def transition_processing_erred(
}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

cs = self._clients["fire-and-forget"]
if ts in cs._wants_what:
Expand Down Expand Up @@ -4706,6 +4711,29 @@ def client_send(self, client, msg):
if self.status == Status.running:
logger.critical("Tried writing to closed comm: %s", msg)

def send_all(self, client_msgs: dict, worker_msgs: dict):
"""Send messages to client and workers"""
stream_comms: dict = self.stream_comms
client_comms: dict = self.client_comms
msgs: list

for worker, msgs in worker_msgs.items():
try:
w = stream_comms[worker]
w.send(*msgs)
except (CommClosedError, AttributeError):
self.loop.add_callback(self.remove_worker, address=worker)

for client, msgs in client_msgs.items():
c = client_comms.get(client)
if c is None:
continue
try:
c.send(*msgs)
except CommClosedError:
if self.status == Status.running:
logger.critical("Tried writing to closed comm: %s", msgs)

############################
# Less common interactions #
############################
Expand Down Expand Up @@ -5814,12 +5842,12 @@ async def register_worker_plugin(self, comm, plugin, name=None):
# State Transitions #
#####################

def transition(self, key, finish, *args, **kwargs):
def _transition(self, key, finish: str, *args, **kwargs):
"""Transition a key from its current state to the finish state

Examples
--------
>>> self.transition('x', 'waiting')
>>> self._transition('x', 'waiting')
{'x': 'processing'}

Returns
Expand All @@ -5832,47 +5860,85 @@ def transition(self, key, finish, *args, **kwargs):
"""
parent: SchedulerState = cast(SchedulerState, self)
ts: TaskState
start: str
start_finish: tuple
finish2: str
recommendations: dict
worker_msgs: dict
client_msgs: dict
msgs: list
new_msgs: list
dependents: set
dependencies: set
try:
try:
ts = parent._tasks[key]
except KeyError:
return {}
recommendations = {}
worker_msgs = {}
client_msgs = {}

ts = parent._tasks.get(key)
if ts is None:
return recommendations, worker_msgs, client_msgs
start = ts._state
if start == finish:
return {}
return recommendations, worker_msgs, client_msgs

if self.plugins:
dependents = set(ts._dependents)
dependencies = set(ts._dependencies)

recommendations: dict = {}
worker_msgs = {}
client_msgs = {}
if (start, finish) in self._transitions:
func = self._transitions[start, finish]
recommendations, worker_msgs, client_msgs = func(key, *args, **kwargs)
elif "released" not in (start, finish):
start_finish = (start, finish)
func = self._transitions.get(start_finish)
if func is not None:
a: tuple = func(key, *args, **kwargs)
recommendations, worker_msgs, client_msgs = a
elif "released" not in start_finish:
func = self._transitions["released", finish]
assert not args and not kwargs
a = self.transition(key, "released")
if key in a:
func = self._transitions["released", a[key]]
b, worker_msgs, client_msgs = func(key)
a = a.copy()
a.update(b)
recommendations = a
a_recs: dict
a_wmsgs: dict
a_cmsgs: dict
a: tuple = self._transition(key, "released")
a_recs, a_wmsgs, a_cmsgs = a
v = a_recs.get(key)
if v is not None:
func = self._transitions["released", v]
b_recs: dict
b_wmsgs: dict
b_cmsgs: dict
b: tuple = func(key)
b_recs, b_wmsgs, b_cmsgs = b

recommendations.update(a_recs)
for w, new_msgs in a_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in a_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

recommendations.update(b_recs)
for w, new_msgs in b_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in b_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

start = "released"
else:
raise RuntimeError(
"Impossible transition from %r to %r" % (start, finish)
)

for worker, msg in worker_msgs.items():
self.worker_send(worker, msg)
for client, msg in client_msgs.items():
self.client_send(client, msg)
raise RuntimeError("Impossible transition from %r to %r" % start_finish)

finish2 = ts._state
self.transition_log.append((key, start, finish2, recommendations, time()))
Expand All @@ -5888,11 +5954,8 @@ def transition(self, key, finish, *args, **kwargs):
if self.plugins:
# Temporarily put back forgotten key for plugin to retrieve it
if ts._state == "forgotten":
try:
ts._dependents = dependents
ts._dependencies = dependencies
except KeyError:
pass
ts._dependents = dependents
ts._dependencies = dependencies
parent._tasks[ts._key] = ts
for plugin in list(self.plugins):
try:
Expand All @@ -5905,11 +5968,16 @@ def transition(self, key, finish, *args, **kwargs):
tg: TaskGroup = ts._group
if ts._state == "forgotten" and tg._name in parent._task_groups:
# Remove TaskGroup if all tasks are in the forgotten state
if not any([tg._states.get(s) for s in ALL_TASK_STATES]):
all_forgotten: bint = True
for s in ALL_TASK_STATES:
if tg._states.get(s):
all_forgotten = False
break
if all_forgotten:
ts._prefix._groups.remove(tg)
del parent._task_groups[tg._name]

return recommendations
return recommendations, worker_msgs, client_msgs
except Exception as e:
logger.exception("Error transitioning %r from %r to %r", key, start, finish)
if LOG_PDB:
Expand All @@ -5918,20 +5986,69 @@ def transition(self, key, finish, *args, **kwargs):
pdb.set_trace()
raise

def transition(self, key, finish: str, *args, **kwargs):
"""Transition a key from its current state to the finish state

Examples
--------
>>> self.transition('x', 'waiting')
{'x': 'processing'}

Returns
-------
Dictionary of recommendations for future transitions

See Also
--------
Scheduler.transitions: transitive version of this function
"""
recommendations: dict
worker_msgs: dict
client_msgs: dict
a: tuple = self._transition(key, finish, *args, **kwargs)
recommendations, worker_msgs, client_msgs = a
self.send_all(client_msgs, worker_msgs)
return recommendations

def transitions(self, recommendations: dict):
"""Process transitions until none are left

This includes feedback from previous transitions and continues until we
reach a steady state
"""
parent: SchedulerState = cast(SchedulerState, self)
keys = set()
keys: set = set()
recommendations = recommendations.copy()
worker_msgs: dict = {}
client_msgs: dict = {}
msgs: list
new_msgs: list
new: tuple
new_recs: dict
new_wmsgs: dict
new_cmsgs: dict
while recommendations:
key, finish = recommendations.popitem()
keys.add(key)
new = self.transition(key, finish)
recommendations.update(new)

new = self._transition(key, finish)
new_recs, new_wmsgs, new_cmsgs = new

recommendations.update(new_recs)
for w, new_msgs in new_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in new_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

self.send_all(client_msgs, worker_msgs)

if parent._validate:
for key in keys:
Expand Down Expand Up @@ -6513,7 +6630,7 @@ def _add_to_memory(
report_msg["type"] = type

for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "memory"
ts._type = typename
Expand Down Expand Up @@ -6567,7 +6684,7 @@ def _propagate_forgotten(
ws._nbytes -= ts.get_nbytes()
w: str = ws._address
if w in state._workers_dv: # in case worker has died
worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False}
worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]
ts._who_has.clear()


Expand Down Expand Up @@ -6674,7 +6791,7 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:

client_msgs: dict = {}
for k in client_keys:
client_msgs[k] = report_msg
client_msgs[k] = [report_msg]

return client_msgs

Expand Down