diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index cc41bd79d0..86eb11e49e 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -131,6 +131,7 @@ def failing(x): {"key": "task", "start": "waiting", "finish": "ready"}, {"key": "task", "start": "ready", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "error"}, + {"key": "task", "state": "error"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f5ad740b52..45e38b24a5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1205,6 +1205,10 @@ class TaskState: failed task is stored here (possibly itself). Otherwise this is ``None``. + .. attribute:: erred_on: set(str) + + Worker addresses on which errors appeared causing this task to be in an error state. + .. attribute:: suspicious: int The number of times this task has been involved in a worker death. @@ -1293,6 +1297,7 @@ class TaskState: _exception: object _traceback: object _exception_blame: object + _erred_on: set _suspicious: Py_ssize_t _host_restrictions: set _worker_restrictions: set @@ -1343,6 +1348,7 @@ class TaskState: "_who_wants", "_exception", "_traceback", + "_erred_on", "_exception_blame", "_suspicious", "_retries", @@ -1381,6 +1387,7 @@ def __init__(self, key: str, run_spec: object): self._group = None self._metadata = {} self._annotations = {} + self._erred_on = set() def __hash__(self): return self._hash @@ -1528,6 +1535,10 @@ def group_key(self): def prefix_key(self): return self._prefix._name + @property + def erred_on(self): + return self._erred_on + @ccall def add_dependency(self, other: "TaskState"): """Add another task as a dependency of this task""" @@ -1842,7 +1853,6 @@ def __init__( ("no-worker", "waiting"): self.transition_no_worker_waiting, ("released", "forgotten"): self.transition_released_forgotten, ("memory", "forgotten"): self.transition_memory_forgotten, - ("erred", "forgotten"): self.transition_released_forgotten, ("erred", "released"): self.transition_erred_released, ("memory", "released"): self.transition_memory_released, ("released", "erred"): self.transition_released_erred, @@ -2629,9 +2639,9 @@ def transition_memory_released(self, key, safe: bint = False): # XXX factor this out? 
ts_nbytes: Py_ssize_t = ts.get_nbytes() worker_msg = { - "op": "delete-data", + "op": "free-keys", "keys": [key], - "report": False, + "reason": f"Memory->Released {key}", } for ws in ts._who_has: del ws._has_what[ts] @@ -2722,7 +2732,6 @@ def transition_erred_released(self, key): if self._validate: with log_errors(pdb=LOG_PDB): - assert all([dts._state != "erred" for dts in ts._dependencies]) assert ts._exception_blame assert not ts._who_has assert not ts._waiting_on @@ -2736,6 +2745,11 @@ def transition_erred_released(self, key): if dts._state == "erred": recommendations[dts._key] = "waiting" + w_msg = {"op": "free-keys", "keys": [key], "reason": "Erred->Released"} + for ws_addr in ts._erred_on: + worker_msgs[ws_addr] = [w_msg] + ts._erred_on.clear() + report_msg = {"op": "task-retried", "key": key} cs: ClientState for cs in ts._who_wants: @@ -2805,7 +2819,9 @@ def transition_processing_released(self, key): w: str = _remove_from_processing(self, ts) if w: - worker_msgs[w] = [{"op": "release-task", "key": key}] + worker_msgs[w] = [ + {"op": "free-keys", "keys": [key], "reason": "Processing->Released"} + ] ts.state = "released" @@ -2835,7 +2851,7 @@ def transition_processing_released(self, key): raise def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, **kwargs + self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs ): ws: WorkerState try: @@ -2856,8 +2872,9 @@ def transition_processing_erred( ws = ts._processing_on ws._actors.remove(ts) - _remove_from_processing(self, ts) + w = _remove_from_processing(self, ts) + ts._erred_on.add(w or worker) if exception is not None: ts._exception = exception if traceback is not None: @@ -4456,7 +4473,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._who_has, ) if ws not in ts._who_has: - worker_msgs[worker] = [{"op": "release-task", "key": key}] + worker_msgs[worker] = [ + {"op": "free-keys", "keys": [key], "reason": "Stimulus Finished"} + ] return recommendations, client_msgs, worker_msgs @@ -5113,7 +5132,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): def release_worker_data(self, comm=None, keys=None, worker=None): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState = parent._workers_dv[worker] - tasks: set = {parent._tasks[k] for k in keys} + tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks} removed_tasks: set = tasks.intersection(ws._has_what) ts: TaskState @@ -5519,8 +5538,11 @@ async def _delete_worker_data(self, worker_address, keys): List of keys to delete on the specified worker """ parent: SchedulerState = cast(SchedulerState, self) + await retry_operation( - self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False + self.rpc(addr=worker_address).free_keys, + keys=list(keys), + reason="rebalance/replicate", ) ws: WorkerState = parent._workers_dv[worker_address] @@ -6271,6 +6293,7 @@ def add_keys(self, comm=None, worker=None, keys=()): if worker not in parent._workers_dv: return "not found" ws: WorkerState = parent._workers_dv[worker] + superfluous_data = [] for key in keys: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": @@ -6279,9 +6302,16 @@ def add_keys(self, comm=None, worker=None, keys=()): ws._has_what[ts] = None ts._who_has.add(ws) else: - self.worker_send( - worker, {"op": "delete-data", "keys": [key], "report": False} - ) + superfluous_data.append(key) + if superfluous_data: + self.worker_send( + worker, + { + "op": "superfluous-data", 
+ "keys": superfluous_data, + "reason": f"Add keys which are not in-memory {superfluous_data}", + }, + ) return "OK" @@ -7308,7 +7338,13 @@ def _propagate_forgotten( ws._nbytes -= ts_nbytes w: str = ws._address if w in state._workers_dv: # in case worker has died - worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}] + worker_msgs[w] = [ + { + "op": "free-keys", + "keys": [key], + "reason": f"propagate-forgotten {ts.key}", + } + ] ts._who_has.clear() diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index fa571d0f37..32c872dec6 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -3,6 +3,7 @@ import random from contextlib import suppress from time import sleep +from unittest import mock import pytest from tlz import first, partition_all @@ -384,7 +385,26 @@ async def test_restart_during_computation(c, s, a, b): assert not s.tasks -@gen_cluster(client=True, timeout=60) +class SlowTransmitData: + def __init__(self, data, delay=0.1): + self.delay = delay + self.data = data + + def __reduce__(self): + import time + + time.sleep(self.delay) + return (SlowTransmitData, (self.delay,)) + + def __sizeof__(self) -> int: + # Ensure this is offloaded to avoid blocking loop + import dask + from dask.utils import parse_bytes + + return parse_bytes(dask.config.get("distributed.comm.offload")) + 1 + + +@gen_cluster(client=True) async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): n = await Nanny(s.address, nthreads=2, loop=s.loop) @@ -393,23 +413,32 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): await asyncio.sleep(0.01) assert time() < start + 5 - futures = c.map(slowinc, range(20), delay=0.01, key=["f%d" % i for i in range(20)]) - await wait(futures) - - result = await c.submit(sum, futures, workers=a.address) - deps = [dep for dep in a.tasks.values() if dep.key not in a.data_needed] - for dep in deps: - a.release_key(dep.key, report=True) + def slow_ser(x, delay): + return SlowTransmitData(x, delay=delay) n_worker_address = n.worker_address + futures = c.map( + slow_ser, + range(20), + delay=0.1, + key=["f%d" % i for i in range(20)], + workers=[n_worker_address], + allow_other_workers=True, + ) + + def sink(*args): + pass + + await wait(futures) + result_fut = c.submit(sink, futures, workers=a.address) + with suppress(CommClosedError): await c._run(os._exit, 1, workers=[n_worker_address]) while len(s.workers) > 2: await asyncio.sleep(0.01) - total = c.submit(sum, futures, workers=a.address) - await total + await result_fut assert not a.has_what.get(n_worker_address) assert not any(n_worker_address in s for ts in a.tasks.values() for s in ts.who_has) @@ -417,6 +446,51 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): await n.close() +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 3)], +) +async def test_worker_same_host_replicas_missing(c, s, a, b, x): + # See GH4784 + def mock_address_host(addr): + # act as if A and X are on the same host + nonlocal a, b, x + if addr in [a.address, x.address]: + return "A" + else: + return "B" + + with mock.patch("distributed.worker.get_address_host", mock_address_host): + futures = c.map( + slowinc, + range(20), + delay=0.1, + key=["f%d" % i for i in range(20)], + workers=[a.address], + allow_other_workers=True, + ) + await wait(futures) + + # replicate data to avoid the scheduler retriggering the computation + # retriggering cleans 
up the state nicely but doesn't reflect real world + # scenarios where there may be replicas on the cluster, e.g. they are + # replicated as a dependency somewhere else + await c.replicate(futures, n=2, workers=[a.address, b.address]) + + def sink(*args): + pass + + # Since A and X are mocked to be co-located, X will consistently pick A + # to fetch data from. It will never succeed since we're removing data + # artificially, without notifying the scheduler. + # This can only succeed if B handles the missing data properly by + # removing A from the known sources of keys + a.handle_free_keys(keys=["f1"], reason="Am I evil?") # Yes, I am! + result_fut = c.submit(sink, futures, workers=x.address) + + await result_fut + + @pytest.mark.slow @gen_cluster(client=True, timeout=60, Worker=Nanny, nthreads=[("127.0.0.1", 1)]) async def test_restart_timeout_on_long_running_task(c, s, a): @@ -448,3 +522,101 @@ async def test_worker_time_to_live(c, s, a, b): assert time() < start + interval + 0.1 set(s.workers) == {b.address} + + +class SlowDeserialize: + def __init__(self, data, delay=0.1): + self.delay = delay + self.data = data + + def __getstate__(self): + return self.delay + + def __setstate__(self, state): + delay = state + import time + + time.sleep(delay) + return SlowDeserialize(delay) + + def __sizeof__(self) -> int: + # Ensure this is offloaded to avoid blocking loop + import dask + from dask.utils import parse_bytes + + return parse_bytes(dask.config.get("distributed.comm.offload")) + 1 + + +@gen_cluster(client=True, timeout=None) +async def test_handle_superfluous_data(c, s, a, b): + """ + See https://github.com/dask/distributed/pull/4784#discussion_r649210094 + """ + + def slow_deser(x, delay): + return SlowDeserialize(x, delay=delay) + + futA = c.submit( + slow_deser, 1, delay=1, workers=[a.address], key="A", allow_other_workers=True + ) + futB = c.submit(inc, 1, workers=[b.address], key="B") + await wait([futA, futB]) + + def reducer(*args): + return + + assert len(a.tasks) == 1 + assert futA.key in a.tasks + + assert len(b.tasks) == 1 + assert futB.key in b.tasks + + red = c.submit(reducer, [futA, futB], workers=[b.address], key="reducer") + + dep_key = futA.key + + # Wait for the connection to be established + while dep_key not in b.tasks or not b.tasks[dep_key].state == "flight": + await asyncio.sleep(0.001) + + # Wait for the connection to be returned to the pool. this signals that + # worker B is done with the communication and is about to deserialize the + # result + while a.address not in b.rpc.available and not b.rpc.available[a.address]: + await asyncio.sleep(0.001) + + assert b.tasks[dep_key].state == "flight" + # After the comm is finished and the deserialization starts, Worker B + # wouldn't notice that A dies. + await a.close() + # However, while B is busy deserializing a third worker might notice that A + # is dead and issues a handle-missing signal to the scheduler. Since at this + # point in time, A was the only worker with a verified replica, the + # scheduler reschedules the computation by transitioning it to released. The + # released transition has the side effect that it purges all data which is + # in memory which exposes us to a race condition on B if B also receives the + # signal to compute that task in the meantime. 
+ s.handle_missing_data(key=dep_key, errant_worker=a.address) + await red + + +@gen_cluster() +async def test_forget_data_not_supposed_to_have(s, a, b): + """ + If a dependency fetch finishes on a worker after the scheduler already + released everything, the worker might be stuck with a redundant replica + which is never cleaned up. + """ + # FIXME: Replace with "blackbox test" which shows an actual example where + # this situation is provoked if this is even possible. + # If this cannot be constructed, the entire superfluous_data handler and its + # corresponding pieces on the scheduler side may be removed + from distributed.worker import TaskState + + ts = TaskState("key") + ts.state = "flight" + a.tasks["key"] = ts + a.transition_flight_memory(ts, value=123) + assert a.data + while a.data: + await asyncio.sleep(0.001) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 18c05b849f..1746be6d4f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1890,11 +1890,25 @@ async def test_task_groups(c, s, a, b): assert tg.states["released"] == 5 assert tp.states["memory"] == 0 assert tp.states["released"] == 5 + assert tp.groups == [tg] assert tg.prefix is tp - assert tg in tp.groups + # these must be true since in this simple case there is a 1-to-1 mapping + # between prefix and group assert tg.duration == tp.duration assert tg.nbytes_in_memory == tp.nbytes_in_memory assert tg.nbytes_total == tp.nbytes_total + # It should map down to individual tasks + assert tg.nbytes_total == sum( + [ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg] + ) + in_memory_ts = sum( + [ + ts.get_nbytes() + for ts in s.tasks.values() + if ts.group is tg and ts.state == "memory" + ] + ) + assert tg.nbytes_in_memory == in_memory_ts tg = s.task_groups[y.name] assert tg.states["memory"] == 5 @@ -1902,6 +1916,7 @@ async def test_task_groups(c, s, a, b): assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]} await c.replicate(y) + # TODO: Are we supposed to track replicated memory here?
See also Scheduler.add_keys assert tg.nbytes_in_memory == y.nbytes assert "array" in str(tg.types) assert "array" in str(tp.types) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 41b0e171c7..ba4de5199d 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -748,7 +748,6 @@ async def test_log_exception_on_failed_task(c, s, a, b): logger.removeHandler(fh) -@pytest.mark.flaky(reruns=10, reruns_delay=5) @gen_cluster(client=True) async def test_clean_up_dependencies(c, s, a, b): x = delayed(inc)(1) @@ -760,15 +759,12 @@ async def test_clean_up_dependencies(c, s, a, b): zz = c.persist(z) await wait(zz) - start = time() while len(a.data) + len(b.data) > 1: await asyncio.sleep(0.01) - assert time() < start + 2 assert set(a.data) | set(b.data) == {zz.key} -@pytest.mark.flaky(reruns=10, reruns_delay=5) @gen_cluster(client=True) async def test_hold_onto_dependents(c, s, a, b): x = c.submit(inc, 1, workers=a.address) @@ -778,9 +774,8 @@ async def test_hold_onto_dependents(c, s, a, b): assert x.key in b.data await c._cancel(y) - await asyncio.sleep(0.1) - - assert x.key in b.data + while x.key not in b.data: + await asyncio.sleep(0.1) @pytest.mark.slow @@ -1840,8 +1835,8 @@ async def test_story_with_deps(c, s, a, b): ), (key, "waiting", "ready"), (key, "ready", "executing"), - (key, "put-in-memory"), (key, "executing", "memory"), + (key, "put-in-memory"), ] assert story == expected_story @@ -2035,3 +2030,312 @@ def get_thread_name(): default_result, gpu_result = await c.gather(futures) assert "Dask-Default-Threads" in default_result assert "Dask-GPU-Threads" in gpu_result + + +def assert_task_states_on_worker(expected, worker): + for dep_key, expected_state in expected.items(): + assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks) + dep_ts = worker.tasks[dep_key] + assert dep_ts.state == expected_state, (worker.name, dep_ts, expected_state) + assert set(expected) == set(worker.tasks) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_last(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free error last + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref via `f` + f.key: "memory", + # This was fetched from another worker. While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + f.release() + g.release() + + # We no longer hold any refs to f or g and B didn't have any errors.
It + # releases everything as expected + while b.tasks: + await asyncio.sleep(0.01) + + expected_states = { + # We currently don't have a good way to actually release this memory as + # long as the tasks still have a dependent. We'll need to live with this + # memory for now + f.key: "memory", + g.key: "memory", + res.key: "error", + } + + assert_task_states_on_worker(expected_states, a) + + res.release() + + # We no longer hold any refs. Cluster should reset completely + # This is not happening + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_first(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free error first + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref + # via `f` + f.key: "memory", + # This was fetched from another worker. While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + + res.release() + # We no longer hold any refs to f or g and B didn't have any erros. It + # releases everything as expected + while res.key in a.tasks: + await asyncio.sleep(0.01) + + expected_states = { + f.key: "memory", + } + + assert_task_states_on_worker(expected_states, a) + + f.release() + g.release() + + # This is not happening + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_int(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free one successful task, then error, then last task + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref via `f` + f.key: "memory", + # This was fetched from another worker. 
While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + + f.release() + res.release() + # We no longer hold any refs to f or g and B didn't have any errors. It + # releases everything as expected + while a.tasks: + await asyncio.sleep(0.01) + + expected_states = { + g.key: "memory", + } + + assert_task_states_on_worker(expected_states, b) + + g.release() + + # We no longer hold any refs. Cluster should reset completely + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_long_chain(c, s, a, b): + def raise_exc(*args): + raise RuntimeError() + + # f (A) --------> res (B) + # / + # g (B) -> h (A) + + f = c.submit(inc, 1, workers=[a.address], key="f", allow_other_workers=False) + g = c.submit(inc, 1, workers=[b.address], key="g", allow_other_workers=False) + h = c.submit(inc, g, workers=[a.address], key="h", allow_other_workers=False) + res = c.submit( + raise_exc, f, h, workers=[b.address], allow_other_workers=False, key="res" + ) + + with pytest.raises(RuntimeError): + await res.result() + + expected_states_A = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + expected_states_B = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + res.key: "error", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_B, b) + + f.release() + + expected_states_A = { + g.key: "memory", + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + expected_states_B = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + res.key: "error", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_B, b) + + g.release() + + expected_states_A = { + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + # B must not forget a task since all still have a valid dependent + expected_states_B = { + f.key: "memory", + # We actually cannot hold on to G even though the graph would suggest + # otherwise. This is because H was only introduced as a dependency and + # the scheduler never told the worker how H fits into the big picture. + # Therefore, it thinks that G does not have any dependents anymore and + # releases it. Too bad. Once we have speculative task assignments this + # should be more exact since we should always tell the worker what's + # going on + # g.key: released, + h.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states_B, b) + h.release() + await asyncio.sleep(0.05) + + expected_states_A = {} + assert_task_states_on_worker(expected_states_A, a) + expected_states_B = { + f.key: "memory", + # See above + # g.key: released, + h.key: "memory", + res.key: "error", + } + + assert_task_states_on_worker(expected_states_B, b) + res.release() + + # We no longer hold any refs.
Cluster should reset completely + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) diff --git a/distributed/worker.py b/distributed/worker.py index 0735043355..e802af68c7 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -14,10 +14,9 @@ from collections.abc import MutableMapping from contextlib import suppress from datetime import timedelta -from functools import partial from inspect import isawaitable from pickle import PicklingError -from typing import Dict, Iterable +from typing import Dict, Iterable, Optional from tlz import first, keymap, merge, pluck # noqa: F401 from tornado import gen @@ -188,6 +187,7 @@ def __init__(self, key, runspec=None): self.metadata = {} self.nbytes = None self.annotations = None + self.scheduler_holds_ref = False def __repr__(self): return "<Task %r %s>" % (self.key, self.state) @@ -685,7 +685,7 @@ def __init__( "run_coroutine": self.run_coroutine, "get_data": self.get_data, "update_data": self.update_data, - "delete_data": self.delete_data, + "free_keys": self.handle_free_keys, "terminate": self.close, "ping": pingpong, "upload_file": self.upload_file, @@ -706,8 +706,8 @@ def __init__( stream_handlers = { "close": self.close, "compute-task": self.add_task, - "release-task": partial(self.release_key, report=False), - "delete-data": self.delete_data, + "free-keys": self.handle_free_keys, + "superfluous-data": self.handle_superfluous_data, "steal-request": self.steal_request, } @@ -1477,21 +1477,62 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): self.put_key_in_memory(ts, value) ts.priority = None ts.duration = None + ts.scheduler_holds_ref = True self.log.append((key, "receive-from-scatter")) if report: + + self.log.append( + ("Notifying scheduler about in-memory in update-data", list(data)) + ) self.batched_stream.send({"op": "add-keys", "keys": list(data)}) info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info - def delete_data(self, comm=None, keys=None, report=True): - if keys: - for key in list(keys): - self.log.append((key, "delete")) - self.release_key(key, cause="delete data") + def handle_free_keys(self, comm=None, keys=None, reason=None): + """ + Handler to be called by the scheduler. + + The given keys are no longer referred to or required by the scheduler. + The worker is now allowed to release the key, if applicable. + + This does not guarantee that the memory is released since the worker may + still decide to hold on to the data and task if it is required by an + upstream dependency. + """ + self.log.append(("free-keys", keys, reason)) + for key in keys: + ts = self.tasks.get(key) + if ts: + ts.scheduler_holds_ref = False + self.release_key(key, report=False, reason=reason) + + def handle_superfluous_data(self, keys=(), reason=None): + """Stream handler notifying the worker that it might be holding unreferenced, superfluous data. + + This should not actually happen during ordinary operations and is only + intended to correct any erroneous state. An example where this is + necessary is if a worker fetches data for a downstream task but that + task is released before the data arrives. + In this case, the scheduler will notify the worker that it may be + holding this unnecessary data, if the worker hasn't already released the data itself. + + This handler does not guarantee that either the task or the data will actually be + released but only asks the worker to release the data on a best-effort + basis.
This protects from race conditions where the given keys may + already have been rescheduled for compute in which case the compute + would win and this handler is ignored. + + For stronger guarantees, see handler free_keys + """ + self.log.append(("Handle superfluous data", keys, reason)) + for key in list(keys): + ts = self.tasks.get(key) + if ts and not ts.scheduler_holds_ref: + self.release_key(key, reason=f"delete data: {reason}", report=False) - logger.debug("Worker %s -- Deleted %d keys", self.name, len(keys)) + logger.debug("Worker %s -- Deleted %d keys", self.name, len(keys)) return "OK" async def set_resources(self, **resources): @@ -1532,6 +1573,7 @@ def add_task( runspec = SerializedTask(function, args, kwargs, task) if key in self.tasks: ts = self.tasks[key] + ts.scheduler_holds_ref = True if ts.state == "memory": assert key in self.data or key in self.actors logger.debug( @@ -1555,7 +1597,6 @@ def add_task( key=key, runspec=SerializedTask(function, args, kwargs, task) ) self.transition(ts, "waiting") - # TODO: move transition of `ts` to end of `add_task` # This will require a chained recommendation transition system like # the scheduler @@ -1567,6 +1608,7 @@ def add_task( if actor: self.actors[ts.key] = None + ts.scheduler_holds_ref = True ts.runspec = runspec ts.priority = priority ts.duration = duration @@ -1593,10 +1635,23 @@ def add_task( # transition from new -> fetch handles adding dependency # to waiting_for_data - self.transition(dep_ts, state) + discarded_self = False + if self.address in workers and state == "fetch": + discarded_self = True + workers = set(workers) + workers.discard(self.address) + who_has[dependency] = tuple(workers) + + self.transition(dep_ts, state, who_has=workers) self.log.append( - (dependency, "new-dep", dep_ts.state, f"requested by {ts.key}") + ( + dependency, + "new-dep", + dep_ts.state, + f"requested by {ts.key}", + discarded_self, + ) ) else: @@ -1610,13 +1665,8 @@ def add_task( if dep_ts.state in ("fetch", "flight"): # if we _need_ to grab data or are in the process ts.waiting_for_data.add(dep_ts.key) - # Ensure we know which workers to grab data from - dep_ts.who_has.update(workers) - - for worker in workers: - self.has_what[worker].add(dep_ts.key) - self.pending_data_per_worker[worker].append(dep_ts.key) + self.update_who_has(who_has=who_has) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value @@ -1650,8 +1700,10 @@ def transition(self, ts, finish, **kwargs): if start == finish: return func = self._transitions[start, finish] + self.log.append((ts.key, start, finish)) state = func(ts, **kwargs) - self.log.append((ts.key, start, state or finish)) + if state and finish != state: + self.log.append((ts.key, start, finish, state)) ts.state = state or finish if self.validate: self.validate_task(ts) @@ -1671,15 +1723,21 @@ def transition_new_waiting(self, ts): pdb.set_trace() raise - def transition_new_fetch(self, ts): + def transition_new_fetch(self, ts, who_has): try: if self.validate: assert ts.state == "new" assert ts.runspec is None + assert who_has for dependent in ts.dependents: dependent.waiting_for_data.add(ts.key) + ts.who_has.update(who_has) + for w in who_has: + self.has_what[w].add(ts.key) + self.pending_data_per_worker[w].append(ts.key) + except Exception as e: logger.exception(e) if LOG_PDB: @@ -1767,19 +1825,22 @@ def transition_fetch_flight(self, ts, worker=None): pdb.set_trace() raise - def transition_flight_fetch(self, ts, worker=None, runspec=None): + def transition_flight_fetch(self, ts): 
try: if self.validate: assert ts.state == "flight" self.in_flight_tasks -= 1 ts.coming_from = None - ts.runspec = runspec or ts.runspec + ts.runspec = None if not ts.who_has: if ts.key not in self._missing_dep_flight: self._missing_dep_flight.add(ts.key) + logger.info("Task %s does not know who has", ts) self.loop.add_callback(self.handle_missing_dep, ts) + for w in ts.who_has: + self.pending_data_per_worker[w].append(ts.key) for dependent in ts.dependents: dependent.waiting_for_data.add(ts.key) if dependent.state == "waiting": @@ -1808,6 +1869,7 @@ def transition_flight_memory(self, ts, value=None): except KeyError: pass + self.log.append(("Notifying scheduler about in-memory", ts.key)) self.batched_stream.send({"op": "add-keys", "keys": [ts.key]}) except Exception as e: @@ -2050,10 +2112,9 @@ def ensure_communicating(self): else: dependencies_fetch.add(dependency_ts) - del dependencies + del dependencies, dependency_ts if dependencies_missing: - logger.info("Can't find dependencies for key %s", key) missing_deps2 = { dep for dep in dependencies_missing @@ -2061,8 +2122,13 @@ def ensure_communicating(self): } for dep in missing_deps2: self._missing_dep_flight.add(dep.key) - self.loop.add_callback(self.handle_missing_dep, *missing_deps2) - + if missing_deps2: + logger.info( + "Can't find dependencies %s for key %s", + missing_deps2.copy(), + key, + ) + self.loop.add_callback(self.handle_missing_dep, *missing_deps2) dependencies_fetch -= dependencies_missing self.log.append( @@ -2099,6 +2165,7 @@ def ensure_communicating(self): for d in to_gather: dependencies_fetch.discard(self.tasks.get(d)) self.transition(self.tasks[d], "flight", worker=worker) + assert not worker == self.address self.loop.add_callback( self.gather_dep, worker=worker, @@ -2238,6 +2305,9 @@ async def gather_dep( cause : TaskState Task we want to gather dependencies for """ + + if self.validate: + self.validate_state() if self.status != Status.running: return with log_errors(): @@ -2250,10 +2320,17 @@ async def gather_dep( dependency_ts = self.tasks.get(dependency_key) if dependency_ts and dependency_ts.state == "flight": to_gather_keys.add(dependency_key) - del to_gather + # Keep namespace clean since this func is long and has many + # dep*, *ts* variables + del to_gather, dependency_key, dependency_ts self.log.append(("request-dep", cause.key, worker, to_gather_keys)) - logger.debug("Request %d keys for task %s", len(to_gather_keys), cause) + logger.debug( + "Request %d keys for task %s from %s", + len(to_gather_keys), + cause, + worker, + ) start = time() response = await get_data_from_worker( @@ -2321,9 +2398,12 @@ async def gather_dep( self.log.append(("receive-dep", worker, list(response["data"]))) except EnvironmentError: logger.exception("Worker stream died during communication: %s", worker) - self.log.append(("receive-dep-failed", worker)) - for d in self.has_what.pop(worker): - self.tasks[d].who_has.remove(worker) + has_what = self.has_what.pop(worker) + self.pending_data_per_worker.pop(worker) + self.log.append(("receive-dep-failed", worker, has_what)) + for d in has_what: + ts = self.tasks[d] + ts.who_has.remove(worker) except Exception as e: logger.exception(e) @@ -2337,6 +2417,10 @@ async def gather_dep( busy = response.get("status", "") == "busy" data = response.get("data", {}) + # FIXME: We should not handle keys which were skipped by this coro. 
to_gather_keys is only a subset + assert set(to_gather_keys).issubset( + set(self.in_flight_workers.get(worker)) + ) for d in self.in_flight_workers.pop(worker): ts = self.tasks.get(d) @@ -2344,16 +2428,32 @@ async def gather_dep( if not busy and d in data: self.transition(ts, "memory", value=data[d]) elif ts is None or ts.state == "executing": - self.release_key(d, cause="already executing at gather") - continue - elif ts.state not in ("ready", "memory"): - self.transition(ts, "fetch", worker=worker) - - if not busy and d not in data and ts.dependents: + self.log.append(("already-executing", d)) + self.release_key(d, reason="already executing at gather") + elif ts.state == "flight" and not ts.dependents: + self.log.append(("flight no-dependents", d)) + self.release_key( + d, reason="In-flight task no longer has dependents." + ) + elif ( + not busy + and d not in data + and ts.dependents + and ts.state != "memory" + ): + ts.who_has.discard(worker) + self.has_what[worker].discard(ts.key) self.log.append(("missing-dep", d)) self.batched_stream.send( {"op": "missing-data", "errant_worker": worker, "key": d} ) + self.transition(ts, "fetch") + elif ts.state not in ("ready", "memory"): + self.transition(ts, "fetch") + else: + logger.debug( + "Unexpected task state encountered for %s after gather_dep" + ) if self.validate: self.validate_state() @@ -2380,7 +2480,7 @@ def bad_dep(self, dep): ts.exception = msg["exception"] ts.traceback = msg["traceback"] self.transition(ts, "error") - self.release_key(dep.key, cause="bad dep") + self.release_key(dep.key, reason="bad dep") async def handle_missing_dep(self, *deps, **kwargs): self.log.append(("handle-missing", deps)) @@ -2408,18 +2508,40 @@ async def handle_missing_dep(self, *deps, **kwargs): ) who_has = {k: v for k, v in who_has.items() if v} self.update_who_has(who_has) + still_missing = set() for dep in deps: dep.suspicious_count += 1 if not who_has.get(dep.key): + logger.info( + "No workers found for %s", + dep.key, + ) self.log.append((dep.key, "no workers found", dep.dependents)) - self.release_key(dep.key) + self.release_key(dep.key, reason="Handle missing no workers") + elif self.address in who_has and dep.state != "memory": + + still_missing.add(dep) + self.batched_stream.send( + { + "op": "release-worker-data", + "keys": [dep.key], + "worker": self.address, + } + ) else: + logger.debug("New workers found for %s", dep.key) self.log.append((dep.key, "new workers found")) for dependent in dep.dependents: if dependent.key in dep.waiting_for_data: self.data_needed.append(dependent.key) - + if still_missing: + logger.debug( + "Found self referencing who has response from scheduler for keys %s.\n" + "Trying again handle_missing", + deps, + ) + await self.handle_missing_dep(*deps) except Exception: logger.error("Handle missing dep failed, retrying", exc_info=True) retries = kwargs.get("retries", 5) @@ -2450,6 +2572,14 @@ def update_who_has(self, who_has): continue if dep in self.tasks: + if self.address in workers and self.tasks[dep].state != "memory": + logger.debug( + "Scheduler claims worker %s holds data for task %s which is not true.", + self.name, + dep, + ) + # Do not mutate the input dict. 
That's rude + workers = set(workers) - {self.address} self.tasks[dep].who_has.update(workers) for worker in workers: @@ -2479,20 +2609,41 @@ def steal_request(self, key): # If task is marked as "constrained" we haven't yet assigned it an # `available_resources` to run on, that happens in # `transition_constrained_executing` - self.release_key(ts.key, cause="stolen") + ts.scheduler_holds_ref = False + self.release_key(ts.key, reason="stolen") if self.validate: assert ts.key not in self.tasks - def release_key(self, key, cause=None, reason=None, report=True): + def release_key( + self, + key: str, + cause: Optional[TaskState] = None, + reason: Optional[str] = None, + report: bool = True, + ): try: + if self.validate: assert isinstance(key, str) ts = self.tasks.get(key, TaskState(key=key)) - + # If the scheduler holds a reference, which is usually the + # case when it instructed the task to be computed here or when + # data was scattered, we must not release it unless the + # scheduler allows us to. See also handle_delete_data and + if ts and ts.scheduler_holds_ref: + return + logger.debug( + "Release key %s", + { + "key": key, + "cause": cause, + "reason": reason, + }, + ) if cause: - self.log.append((key, "release-key", {"cause": cause})) + self.log.append((key, "release-key", {"cause": cause}, reason)) else: - self.log.append((key, "release-key")) + self.log.append((key, "release-key", reason)) if key in self.data and not ts.dependents: try: del self.data[key] @@ -2504,14 +2655,16 @@ def release_key(self, key, cause=None, reason=None, report=True): # for any dependencies of key we are releasing remove task as dependent for dependency in ts.dependencies: dependency.dependents.discard(ts) - # don't boot keys that are in flight - # we don't know if they're already queued up for transit - # in a gather_dep callback - if not dependency.dependents and dependency.state in ( - "waiting", - "fetch", + + if not dependency.dependents and dependency.state not in ( + # don't boot keys that are in flight + # we don't know if they're already queued up for transit + # in a gather_dep callback + "flight", + # The same is true for already executing keys.
+ "executing", ): - self.release_key(dependency.key, cause=f"Dependent {ts} released") + self.release_key(dependency.key, reason=f"Dependent {ts} released") for worker in ts.who_has: self.has_what[worker].discard(ts.key) @@ -2528,10 +2681,18 @@ def release_key(self, key, cause=None, reason=None, report=True): for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - # Inform the scheduler of keys which will have gone missing - # We are releasing them before they have completed - if report and ts.state in PROCESSING: - self.batched_stream.send({"op": "release", "key": key, "cause": cause}) + if report: + # Inform the scheduler of keys which will have gone missing + # We are releasing them before they have completed + if ts.state in PROCESSING: + msg = {"op": "release", "key": key, "cause": cause} + else: + msg = { + "op": "release-worker-data", + "keys": [key], + "worker": self.address, + } + self.batched_stream.send(msg) self._notify_plugins("release_key", key, ts.state, cause, reason, report) if key in self.tasks and not ts.dependents: @@ -2880,7 +3041,7 @@ async def execute(self, key, report=False): elif isinstance(result.pop("actual-exception"), Reschedule): self.batched_stream.send({"op": "reschedule", "key": ts.key}) self.transition(ts, "rescheduled", report=False) - self.release_key(ts.key, report=False) + self.release_key(ts.key, report=False, reason="Reschedule") else: ts.exception = result["exception"] ts.traceback = result["traceback"] @@ -3184,11 +3345,33 @@ def validate_task_waiting(self, ts): def validate_task_flight(self, ts): assert ts.key not in self.data assert not any(dep.key in self.ready for dep in ts.dependents) + assert ts.coming_from + assert ts.coming_from in self.in_flight_workers assert ts.key in self.in_flight_workers[ts.coming_from] def validate_task_fetch(self, ts): assert ts.runspec is None assert ts.key not in self.data + assert self.address not in ts.who_has #!!!!!!!! + # FIXME This is currently not an invariant since upon comm failure we + # remove the erroneous worker from all who_has and correct the state + # upon the next ensure_communicate + + # if not ts.who_has: + # # If we do not know who_has for a fetch task, it must be logged in + # # the missing dep. There should be a handle_missing_dep running for + # # all of these keys + + # assert ts.key in self._missing_dep_flight, ( + # ts.key, + # self.story(ts), + # self._missing_dep_flight.copy(), + # self.in_flight_workers.copy(), + # ) + assert ts.dependents + + for w in ts.who_has: + assert ts.key in self.has_what[w] def validate_task(self, ts): try: @@ -3229,7 +3412,7 @@ def validate_state(self): # dependency can still be in `memory` before GC grabs it...? # Might need better bookkeeping assert dep.state is not None - assert ts in dep.dependents + assert ts in dep.dependents, ts for key in ts.waiting_for_data: ts_wait = self.tasks[key] assert ( @@ -3251,6 +3434,7 @@ def validate_state(self): self.validate_task(ts) except Exception as e: + self.loop.add_callback(self.close) logger.exception(e) if LOG_PDB: import pdb