Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Forget erred tasks // Fix deadlocks on worker #4784

Merged
merged 11 commits into from
Jun 11, 2021
1 change: 1 addition & 0 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def failing(x):
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down
64 changes: 50 additions & 14 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,10 @@ class TaskState:
failed task is stored here (possibly itself). Otherwise this
is ``None``.

.. attribute:: erred_on: set(str)

Worker addresses on which errors appeared, causing this task to be in an error state.

.. attribute:: suspicious: int

The number of times this task has been involved in a worker death.
Expand Down Expand Up @@ -1293,6 +1297,7 @@ class TaskState:
_exception: object
_traceback: object
_exception_blame: object
_erred_on: set
_suspicious: Py_ssize_t
_host_restrictions: set
_worker_restrictions: set
Expand Down Expand Up @@ -1343,6 +1348,7 @@ class TaskState:
"_who_wants",
"_exception",
"_traceback",
"_erred_on",
"_exception_blame",
"_suspicious",
"_retries",
Expand Down Expand Up @@ -1381,6 +1387,7 @@ def __init__(self, key: str, run_spec: object):
self._group = None
self._metadata = {}
self._annotations = {}
self._erred_on = set()

def __hash__(self):
return self._hash
Expand Down Expand Up @@ -1528,6 +1535,10 @@ def group_key(self):
def prefix_key(self):
return self._prefix._name

@property
def erred_on(self):
return self._erred_on
Comment on lines +1538 to +1540
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the reasons why the erred task state is never forgotten is that keys are only "remembered" in processing_on (processing) or who_has (memory), depending on the state of the task. This attribute is the equivalent for erred tasks, allowing us to tell the worker to forget the task.

@jakirkham you have been involved in the scheduler state machine a lot recently. Just pinging in case you have thoughts about adding more state here or if you see other options


@ccall
def add_dependency(self, other: "TaskState"):
"""Add another task as a dependency of this task"""
Expand Down Expand Up @@ -1842,7 +1853,6 @@ def __init__(
("no-worker", "waiting"): self.transition_no_worker_waiting,
("released", "forgotten"): self.transition_released_forgotten,
("memory", "forgotten"): self.transition_memory_forgotten,
("erred", "forgotten"): self.transition_released_forgotten,
("erred", "released"): self.transition_erred_released,
("memory", "released"): self.transition_memory_released,
("released", "erred"): self.transition_released_erred,
Expand Down Expand Up @@ -2629,9 +2639,9 @@ def transition_memory_released(self, key, safe: bint = False):
# XXX factor this out?
ts_nbytes: Py_ssize_t = ts.get_nbytes()
worker_msg = {
"op": "delete-data",
"op": "free-keys",
"keys": [key],
"report": False,
"reason": f"Memory->Released {key}",
}
for ws in ts._who_has:
del ws._has_what[ts]
Expand Down Expand Up @@ -2722,7 +2732,6 @@ def transition_erred_released(self, key):

if self._validate:
with log_errors(pdb=LOG_PDB):
assert all([dts._state != "erred" for dts in ts._dependencies])
assert ts._exception_blame
assert not ts._who_has
assert not ts._waiting_on
Expand All @@ -2736,6 +2745,11 @@ def transition_erred_released(self, key):
if dts._state == "erred":
recommendations[dts._key] = "waiting"

w_msg = {"op": "free-keys", "keys": [key], "reason": "Erred->Released"}
for ws_addr in ts._erred_on:
worker_msgs[ws_addr] = [w_msg]
ts._erred_on.clear()

report_msg = {"op": "task-retried", "key": key}
cs: ClientState
for cs in ts._who_wants:
Expand Down Expand Up @@ -2805,7 +2819,9 @@ def transition_processing_released(self, key):

w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = [{"op": "release-task", "key": key}]
worker_msgs[w] = [
{"op": "free-keys", "keys": [key], "reason": "Processing->Released"}
]

ts.state = "released"

Expand Down Expand Up @@ -2835,7 +2851,7 @@ def transition_processing_released(self, key):
raise

def transition_processing_erred(
self, key, cause=None, exception=None, traceback=None, **kwargs
self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs
):
ws: WorkerState
try:
Expand All @@ -2856,8 +2872,9 @@ def transition_processing_erred(
ws = ts._processing_on
ws._actors.remove(ts)

_remove_from_processing(self, ts)
w = _remove_from_processing(self, ts)

ts._erred_on.add(w or worker)
if exception is not None:
ts._exception = exception
if traceback is not None:
Expand Down Expand Up @@ -4456,7 +4473,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
ts._who_has,
)
if ws not in ts._who_has:
worker_msgs[worker] = [{"op": "release-task", "key": key}]
worker_msgs[worker] = [
{"op": "free-keys", "keys": [key], "reason": "Stimulus Finished"}
]

return recommendations, client_msgs, worker_msgs

Expand Down Expand Up @@ -5113,7 +5132,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
def release_worker_data(self, comm=None, keys=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv[worker]
tasks: set = {parent._tasks[k] for k in keys}
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)

ts: TaskState
Expand Down Expand Up @@ -5519,8 +5538,11 @@ async def _delete_worker_data(self, worker_address, keys):
List of keys to delete on the specified worker
"""
parent: SchedulerState = cast(SchedulerState, self)

await retry_operation(
self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False
self.rpc(addr=worker_address).free_keys,
keys=list(keys),
reason="rebalance/replicate",
)

ws: WorkerState = parent._workers_dv[worker_address]
Expand Down Expand Up @@ -6271,6 +6293,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
Expand All @@ -6279,9 +6302,16 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
self.worker_send(
worker, {"op": "delete-data", "keys": [key], "report": False}
)
superfluous_data.append(key)
if superfluous_data:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
},
)

return "OK"

Expand Down Expand Up @@ -7308,7 +7338,13 @@ def _propagate_forgotten(
ws._nbytes -= ts_nbytes
w: str = ws._address
if w in state._workers_dv: # in case worker has died
worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]
worker_msgs[w] = [
{
"op": "free-keys",
"keys": [key],
"reason": f"propagate-forgotten {ts.key}",
}
]
ts._who_has.clear()


Expand Down
Loading