From 9479d54bb5f5bf3c23e3966b65d9c12857d4bcf8 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 11:11:39 +0100 Subject: [PATCH 1/9] Overhaul update_who_has --- distributed/_stories.py | 17 +- distributed/scheduler.py | 6 +- distributed/tests/test_cancelled_state.py | 4 +- distributed/tests/test_stories.py | 1 + distributed/tests/test_worker.py | 3 +- .../tests/test_worker_state_machine.py | 240 ++++++++++++++++-- distributed/worker.py | 153 ++++++++--- 7 files changed, 350 insertions(+), 74 deletions(-) diff --git a/distributed/_stories.py b/distributed/_stories.py index d17e54df53f..e461ae31ca2 100644 --- a/distributed/_stories.py +++ b/distributed/_stories.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Iterable @@ -19,26 +21,29 @@ def scheduler_story(keys: set, transition_log: Iterable) -> list: return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])] -def worker_story(keys: set, log: Iterable) -> list: +def worker_story(keys_or_tags: set[str], log: Iterable) -> list: """Creates a story from the worker log given a set of keys describing tasks or stimuli. Parameters ---------- - keys : set - A set of task `keys` or `stimulus_id`'s + keys_or_tags : set[str] + A set of task `keys` or arbitrary tags from the event log, e.g. `stimulus_id`'s log : iterable The worker log Returns ------- - story : list + story : list[str] """ return [ msg for msg in log - if any(key in msg for key in keys) + if any(key in msg for key in keys_or_tags) or any( - key in c for key in keys for c in msg if isinstance(c, (tuple, list, set)) + key in c + for key in keys_or_tags + for c in msg + if isinstance(c, (tuple, list, set)) ) ] diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c08a79d73a8..cb5318a8231 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4141,7 +4141,9 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) @log_errors - async def remove_worker(self, address, stimulus_id, safe=False, close=True): + async def remove_worker( + self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True + ) -> Literal["OK", "already-removed"]: """ Remove worker from cluster @@ -4150,7 +4152,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True): state. """ if self.status == Status.closed: - return + return "already-removed" address = self.coerce_address(address) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a4dca9e2872..409d530277c 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -300,7 +300,7 @@ async def get_data(self, comm, *args, **kwargs): s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. 
get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B - await s.remove_worker(a.address, "foo", close=False, safe=True) + await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True) while not b.tasks[fut1.key].state == "resumed": await asyncio.sleep(0.01) @@ -438,7 +438,7 @@ async def get_data(self, comm, *args, **kwargs): f3.key: {w2.address}, } ) - await s.remove_worker(w1.address, "stim-id") + await s.remove_worker(w1.address, stimulus_id="stim-id") await wait_for_state(f3.key, "resumed", w2) assert_story( diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py index ec81ddaff8b..0500bc44a48 100644 --- a/distributed/tests/test_stories.py +++ b/distributed/tests/test_stories.py @@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b): assert stimulus_ids == {"compute-task"} expected = [ ("dep", "ensure-task-exists", "released"), + ("dep", "update-who-has", [a.address], []), ("dep", "released", "fetch", "fetch", {}), ("gather-dependencies", a.address, {"dep"}), ("dep", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9e9330e45e9..6d882f0b2a9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1741,7 +1741,7 @@ async def close(self): ) as wlogger, captured_logger( "distributed.scheduler", level=logging.WARNING ) as slogger: - await s.remove_worker(a.address, "foo") + await s.remove_worker(a.address, stimulus_id="foo") assert not s.workers # Wait until the close signal reaches the worker and it starts shutting down. @@ -2765,6 +2765,7 @@ def __getstate__(self): story[b], [ ("x", "ensure-task-exists", "released"), + ("x", "update-who-has", [a], []), ("x", "released", "fetch", "fetch", {}), ("gather-dependencies", a, {"x"}), ("x", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1b6f2e3f80b..3da0ecfe3ab 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,11 +1,21 @@ import asyncio +import logging +from contextlib import contextmanager from itertools import chain import pytest +from distributed import Worker +from distributed.core import Status from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict -from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc +from distributed.utils_test import ( + _LockedCommPool, + assert_story, + captured_logger, + gen_cluster, + inc, +) from distributed.worker_state_machine import ( ExecuteFailureEvent, ExecuteSuccessEvent, @@ -16,12 +26,13 @@ SendMessageToScheduler, StateMachineEvent, TaskState, + TaskStateState, UniqueTaskHeap, merge_recs_instructions, ) -async def wait_for_state(key, state, dask_worker): +async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None: while key not in dask_worker.tasks or dask_worker.tasks[key].state != state: await asyncio.sleep(0.005) @@ -245,28 +256,36 @@ def test_executefailure_to_dict(): assert ev3.traceback_text == "tb text" -@gen_cluster(client=True) -async def test_fetch_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 +@contextmanager 
+def freeze_data_fetching(w: Worker): + """Prevent any task from transitioning from fetch to flight on the worker while + inside the context. - f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) - f2 = c.submit(inc, f1, workers=[b.address], key="f2") + This is not the same as setting the worker to Status=paused, which would also + inform the Scheduler and prevent further tasks to be enqueued on the worker. + """ + old_out_connections = w.total_out_connections + old_comm_threshold = w.comm_threshold_bytes + w.total_out_connections = 0 + w.comm_threshold_bytes = 0 + yield + w.total_out_connections = old_out_connections + w.comm_threshold_bytes = old_comm_threshold - await wait_for_state(f1.key, "fetch", b) - await a.close() - b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold +@gen_cluster(client=True) +async def test_fetch_to_compute(c, s, a, b): + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + f2 = c.submit(inc, f1, workers=[b.address], key="f2") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f2 assert_story( b.log, - # FIXME: This log should be replaced with an - # StateMachineEvent/Instruction log + # FIXME: This log should be replaced with a StateMachineEvent log [ (f2.key, "compute-task", "released"), # This is a "please fetch" request. We don't have anything like @@ -285,23 +304,188 @@ async def test_fetch_to_compute(c, s, a, b): @gen_cluster(client=True) async def test_fetch_via_amm_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 - - f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + await f1 + s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f1 - s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") - await wait_for_state(f1.key, "fetch", b) - await a.close() + assert_story( + b.log, + # FIXME: This log should be replaced with a StateMachineEvent log + [ + (f1.key, "ensure-task-exists", "released"), + (f1.key, "released", "fetch", "fetch", {}), + (f1.key, "compute-task", "fetch"), + (f1.key, "put-in-memory"), + ], + ) + - b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold +@pytest.mark.parametrize("as_deps", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): + """ + as_deps=True + 0. task x is a dependency of y1 and y2 + 1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1 + 5. when x finally reaches the top of the data_needed heap, the w1 will not try + contacting w2 + + as_deps=False + 1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. 
scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1 + 5. when x finally reaches the top of the data_needed heap, the w1 will not try + contacting w2 + """ + x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[ + "x" + ] + with freeze_data_fetching(w1): + if as_deps: + y1 = c.submit(inc, x, key="y1", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") - await f1 + await wait_for_state("x", "fetch", w1) + assert w1.tasks["x"].who_has == {w2.address, w3.address} + + assert len(s.tasks["x"].who_has) == 2 + await w2.close() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + + if as_deps: + y2 = c.submit(inc, x, key="y2", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + + while w1.tasks["x"].who_has != {w3.address}: + await asyncio.sleep(0.01) + + # Jump-start ensure_communicating after freeze_data_fetching. + # This simulates an unrelated third worker moving out of in_flight_workers. + w1.status = Status.paused + w1.status = Status.running + + await wait_for_state("x", "memory", w1) + + assert_story( + w1.story("request-dep"), + [("request-dep", w3.address, {"x"})], + # This tests that there has been no attempt to contact w2. + # If the assumption being tested breaks, this will fail 50% of the times. + strict=True, + ) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_fetch_to_missing(c, s, a, b): + """ + 1. task x is a dependency of y + 2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a + 3. x transitions released -> fetch -> flight; a connects to b + 4. b responds it's busy. x transitions flight -> fetch + 5. The busy state triggers an RPC call to Scheduler.who_has + 6. the scheduler responds {"x": []}, because w1 in the meantime has lost the key. + 7. x is transitioned fetch -> missing + """ + x = await c.scatter({"x": 1}, workers=[b.address]) + b.total_in_connections = 0 + # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler + # doesn't keep track of acquire-replicas requests, so it won't proactively inform a + # when we call remove_worker later on + s.request_acquire_replicas(a.address, ["x"], stimulus_id="test") + + # state will flip-flop between fetch and flight every 150ms, which is the retry + # period for busy workers. + await wait_for_state("x", "fetch", a) + assert b.address in a.busy_workers + + # Sever connection between b and s, but not between b and a. + # If a tries fetching from b after this, b will keep responding {status: busy}. + b.periodic_callbacks["heartbeat"].stop() + await s.remove_worker(b.address, close=False, stimulus_id="test") + + await wait_for_state("x", "missing", a) + + assert_story( + a.story("x"), + [ + ("x", "ensure-task-exists", "released"), + ("x", "update-who-has", [b.address], []), + ("x", "released", "fetch", "fetch", {}), + ("gather-dependencies", b.address, {"x"}), + ("x", "fetch", "flight", "flight", {}), + ("request-dep", b.address, {"x"}), + ("busy-gather", b.address, {"x"}), + ("x", "flight", "fetch", "fetch", {}), + ("x", "update-who-has", [], [b.address]), # Called Scheduler.who_has + ("x", "fetch", "missing", "missing", {}), + ], + # There may be a round of find_missing() after this. + # Due to timings, there also may be multiple attempts to connect from a to b. 
+ strict=False, + ) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_self_denounce_missing_data(c, s, a): + x = c.submit(inc, 1, key="x") + await x + + y = c.submit(inc, x, key="y") + # Wait until the scheduler is forwarding the compute-task to the worker, but before + # the worker has received it + while "y" not in s.tasks or s.tasks["y"].state != "processing": + await asyncio.sleep(0) + + # Wipe x from the worker. This simulates the scheduler calling + # delete_worker_data(a.address, ["x"]), but the RPC call has not returned yet. + a.handle_free_keys(keys=["x"], stimulus_id="test") + + # At the same time, + # a->s: {"op": "release-worker-data", keys=["x"]} + # s->a: {"op": "compute-task", key="y", who_has={"x": [a.address]}} + + with captured_logger("distributed.worker", level=logging.DEBUG) as logger: + # The compute-task request for y gets stuck in waiting state, because x is not + # in memory. However, moments later the scheduler is informed that a is missing + # x; so it releases y and then recomputes both x and y. + assert await y == 3 + + assert ( + f"Scheduler claims worker {a.address} holds data for task " + ", which is not true." + ) in logger.getvalue() + + assert_story( + a.story("x", "y"), + # Note: omitted uninteresting events + [ + # {"op": "compute-task", key="y", who_has={"x": [a.address]}} + # reaches the worker + ("y", "compute-task", "released"), + ("x", "ensure-task-exists", "released"), + ("x", "released", "missing", "missing", {}), + # {"op": "release-worker-data", keys=["x"]} reaches the scheduler, which + # reacts by releasing both keys and then recomputing them + ("y", "release-key"), + ("x", "release-key"), + ("x", "compute-task", "released"), + ("x", "executing", "memory", "memory", {}), + ("y", "compute-task", "released"), + ("y", "executing", "memory", "memory", {}), + ], + ) @gen_cluster(client=True) diff --git a/distributed/worker.py b/distributed/worker.py index 1795bb80c59..1accfa5f2dd 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1887,8 +1887,12 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has) + recommendations, instructions = merge_recs_instructions( + (recommendations, []), + self._update_who_has(who_has, stimulus_id=stimulus_id), + ) self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) if self.validate: for key in keys: @@ -1991,7 +1995,10 @@ def handle_compute_task( for dep_key, value in nbytes.items(): self.tasks[dep_key].nbytes = value - self.update_who_has(who_has) + recommendations, instructions = merge_recs_instructions( + (recommendations, instructions), + self._update_who_has(who_has, stimulus_id=stimulus_id), + ) else: # pragma: nocover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") @@ -2947,10 +2954,14 @@ def stateof(self, key: str) -> dict[str, Any]: "data": key in self.data, } - def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: - """Return all transitions involving one or more tasks""" - keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} - return worker_story(keys, self.log) + def story(self, *keys_or_tasks_or_tags: str | TaskState) -> list[tuple]: + """Return all records from the transitions log involving one or more tasks; + it can also be used for arbitrary non-transition tags. 
+ """ + keys_or_tags = { + e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_tags + } + return worker_story(keys_or_tags, self.log) async def get_story(self, keys=None): return self.story(*keys) @@ -3143,8 +3154,14 @@ def _select_keys_for_gather( while tasks: ts = tasks.peek() - if ts.state != "fetch" or ts.key in all_keys_to_gather: + if ( + ts.state != "fetch" # Do not acquire the same key twice if multiple workers holds replicas + or ts.key in all_keys_to_gather + # A replica is still available (otherwise status would not be 'fetch' + # anymore), but not on this worker. See _update_who_has(). + or worker not in ts.who_has + ): tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3370,7 +3387,12 @@ def done_event(): who_has = await retry_operation( self.scheduler.who_has, keys=refresh_who_has ) - self.update_who_has(who_has) + refresh_stimulus_id = f"gather-dep-busy-{time()}" + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=refresh_stimulus_id + ) + self.transitions(recommendations, stimulus_id=refresh_stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3394,12 +3416,15 @@ async def find_missing(self) -> None: self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], ) - self.update_who_has(who_has) - recommendations: Recs = {} + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=stimulus_id + ) for ts in self._missing_dep_flight: if ts.who_has: + assert ts not in recommendations recommendations[ts] = "fetch" self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) finally: self._find_missing_running = False @@ -3408,34 +3433,90 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: - try: - for dep, workers in who_has.items(): - if not workers: - continue + def _update_who_has( + self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str + ) -> RecsInstrs: + recs: Recs = {} + instructions: Instructions = [] - if dep in self.tasks: - dep_ts = self.tasks[dep] - if self.address in workers and self.tasks[dep].state != "memory": - logger.debug( - "Scheduler claims worker %s holds data for task %s which is not true.", - self.name, - dep, - ) - # Do not mutate the input dict. That's rude - workers = set(workers) - {self.address} - dep_ts.who_has.update(workers) + for key, workers in who_has.items(): + ts = self.tasks.get(key) + if not ts: + # The worker sent a refresh-who-has request to the scheduler but, by the + # time the answer comes back, some of the keys have been forgotten. + continue + workers = set(workers) + + if self.address in workers: + workers.remove(self.address) + if ts.state != "memory": + # This can happen if the worker recently released a key, but the + # scheduler hasn't processed the release-worker-data message yet. + # It is not necessary to send a missing-data message back to the + # scheduler here, because it is guaranteed to reach the scheduler + # after the release-worker-data from the same BatchedSend channel. 
+ logger.debug( + "Scheduler claims worker %s holds data for task %s, " + "which is not true.", + self.address, + ts, + ) - for worker in workers: - self.has_what[worker].add(dep) - self.data_needed_per_worker[worker].push(dep_ts) - except Exception as e: # pragma: no cover - logger.exception(e) - if LOG_PDB: - import pdb + if ts.who_has == workers: + continue - pdb.set_trace() - raise + new_workers = workers - ts.who_has + del_workers = ts.who_has - workers + + def max2(workers: set[str]) -> list[str]: + if len(workers) < 3: + return list(workers) + it = iter(workers) + return [next(it), next(it), f"({len(workers) - 2} more)"] + + self.log.append( + ( + key, + "update-who-has", + max2(new_workers), + max2(del_workers), + stimulus_id, + time(), + ) + ) + + for worker in del_workers: + self.has_what[worker].discard(key) + # Can't remove from self.data_needed_per_worker; there is logic + # in _select_keys_for_gather to deal with this + + for worker in new_workers: + self.has_what[worker].add(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].push(ts) + # All workers which previously held a replica of the key may either + # be in flight or busy. Kick off ensure_communicating to try + # fetching the data from the new worker. + # There are other reasons why a task may be sitting in 'fetch' state + # - e.g. if we're over the total_out_connections or the + # comm_threshold_bytes limit, or if we're paused. We're deliberately + # NOT testing the full gamut of use cases for the sake of simplicity + # and robustness and just stating that ensure_communicating *may* + # return new GatherDep events now. + instructions.append( + EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id) + ) + + ts.who_has = workers + # currently fetching -> can no longer be fetched -> transition to missing + # any other state -> eventually, possibly, the task may transition to fetch + # or missing, at which point the relevant transitions will test who_has that + # we just updated. e.g. see the various transitions to fetch, which + # instead recommend transitioning to missing if who_has is empty. + if not workers and ts.state == "fetch": + recs[ts] = "missing" + + return recs, instructions def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task. 
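
The hunk above carries the heart of this PR; the following minimal, self-contained sketch shows the same bookkeeping in isolation. It is an illustration under simplifying assumptions, not dask code: TaskState, WhoHasBookkeeping, has_what and data_needed_per_worker below are hypothetical stand-ins, and the real _update_who_has additionally logs every change, keeps data_needed_per_worker as heaps, and emits EnsureCommunicatingAfterTransitions instructions when new replicas appear.

# Sketch only: reconcile a task's replica set with fresh scheduler information,
# keep the per-worker indexes in sync, never record ourselves as a source, and
# recommend "missing" for a fetch-state task that has lost all replicas.
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing so instances can live in sets
class TaskState:
    key: str
    state: str = "fetch"  # "fetch", "flight", "memory", "missing", ...
    who_has: set[str] = field(default_factory=set)


class WhoHasBookkeeping:
    def __init__(self, own_address: str) -> None:
        self.address = own_address
        self.tasks: dict[str, TaskState] = {}
        # worker address -> keys we believe that worker holds
        self.has_what: defaultdict[str, set[str]] = defaultdict(set)
        # worker address -> tasks we could fetch from that worker
        self.data_needed_per_worker: defaultdict[str, set[TaskState]] = defaultdict(set)

    def update_who_has(self, who_has: dict[str, set[str]]) -> dict[str, str]:
        """Return {key: "missing"} recommendations (analogous to Recs)."""
        recommendations: dict[str, str] = {}
        for key, workers in who_has.items():
            ts = self.tasks.get(key)
            if ts is None:
                continue  # key was forgotten while the RPC answer was in flight
            workers = set(workers) - {self.address}  # never fetch from ourselves

            for worker in ts.who_has - workers:  # replicas that disappeared
                self.has_what[worker].discard(key)
                # Simplification: the real code leaves stale entries in
                # data_needed_per_worker and filters them out lazily in
                # _select_keys_for_gather.
                self.data_needed_per_worker[worker].discard(ts)
            for worker in workers - ts.who_has:  # newly available replicas
                self.has_what[worker].add(key)
                if ts.state == "fetch":
                    self.data_needed_per_worker[worker].add(ts)

            ts.who_has = workers
            if not workers and ts.state == "fetch":
                recommendations[key] = "missing"
        return recommendations


if __name__ == "__main__":
    w = WhoHasBookkeeping("tcp://a:1234")
    w.tasks["x"] = TaskState("x", who_has={"tcp://b:1234"})
    w.has_what["tcp://b:1234"].add("x")
    print(w.update_who_has({"x": {"tcp://c:1234"}}))  # {} - replica moved b -> c
    print(w.update_who_has({"x": set()}))  # {'x': 'missing'} - no replicas left

Later commits in this series simplify the real method further: it ends up returning None, and the fetch-to-missing recommendation moves into _ensure_communicating, so the recommendations plumbing sketched here matches only this first commit.
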
@@ -4259,6 +4340,7 @@ def validate_state(self): assert ts.state is not None # check that worker has task for worker in ts.who_has: + assert worker != self.address assert ts.key in self.has_what[worker] # check that deps have a set state and that dependency<->dependent links # are there @@ -4283,6 +4365,7 @@ def validate_state(self): # FIXME https://github.com/dask/distributed/issues/6319 # assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): + assert worker != self.address for k in keys: assert worker in self.tasks[k].who_has From e1c439093dba2e92d2baf74e06aa45cf05247ebb Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 11:19:46 +0100 Subject: [PATCH 2/9] Moved to #6441 --- distributed/scheduler.py | 6 ++---- distributed/tests/test_cancelled_state.py | 4 ++-- distributed/tests/test_worker.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cb5318a8231..c08a79d73a8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4141,9 +4141,7 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) @log_errors - async def remove_worker( - self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True - ) -> Literal["OK", "already-removed"]: + async def remove_worker(self, address, stimulus_id, safe=False, close=True): """ Remove worker from cluster @@ -4152,7 +4150,7 @@ async def remove_worker( state. """ if self.status == Status.closed: - return "already-removed" + return address = self.coerce_address(address) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 409d530277c..a4dca9e2872 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -300,7 +300,7 @@ async def get_data(self, comm, *args, **kwargs): s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B - await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True) + await s.remove_worker(a.address, "foo", close=False, safe=True) while not b.tasks[fut1.key].state == "resumed": await asyncio.sleep(0.01) @@ -438,7 +438,7 @@ async def get_data(self, comm, *args, **kwargs): f3.key: {w2.address}, } ) - await s.remove_worker(w1.address, stimulus_id="stim-id") + await s.remove_worker(w1.address, "stim-id") await wait_for_state(f3.key, "resumed", w2) assert_story( diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6d882f0b2a9..938d2679a12 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1741,7 +1741,7 @@ async def close(self): ) as wlogger, captured_logger( "distributed.scheduler", level=logging.WARNING ) as slogger: - await s.remove_worker(a.address, stimulus_id="foo") + await s.remove_worker(a.address, "foo") assert not s.workers # Wait until the close signal reaches the worker and it starts shutting down. 
From 4b3adf8a83b23adc96ce7bba22f0d8c0820e893b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 11:43:48 +0100 Subject: [PATCH 3/9] Moved to #6442 --- distributed/_stories.py | 17 ++++++----------- distributed/worker.py | 12 ++++-------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/distributed/_stories.py b/distributed/_stories.py index e461ae31ca2..d17e54df53f 100644 --- a/distributed/_stories.py +++ b/distributed/_stories.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Iterable @@ -21,29 +19,26 @@ def scheduler_story(keys: set, transition_log: Iterable) -> list: return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])] -def worker_story(keys_or_tags: set[str], log: Iterable) -> list: +def worker_story(keys: set, log: Iterable) -> list: """Creates a story from the worker log given a set of keys describing tasks or stimuli. Parameters ---------- - keys_or_tags : set[str] - A set of task `keys` or arbitrary tags from the event log, e.g. `stimulus_id`'s + keys : set + A set of task `keys` or `stimulus_id`'s log : iterable The worker log Returns ------- - story : list[str] + story : list """ return [ msg for msg in log - if any(key in msg for key in keys_or_tags) + if any(key in msg for key in keys) or any( - key in c - for key in keys_or_tags - for c in msg - if isinstance(c, (tuple, list, set)) + key in c for key in keys for c in msg if isinstance(c, (tuple, list, set)) ) ] diff --git a/distributed/worker.py b/distributed/worker.py index 1accfa5f2dd..bc5cac2dc25 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2954,14 +2954,10 @@ def stateof(self, key: str) -> dict[str, Any]: "data": key in self.data, } - def story(self, *keys_or_tasks_or_tags: str | TaskState) -> list[tuple]: - """Return all records from the transitions log involving one or more tasks; - it can also be used for arbitrary non-transition tags. 
- """ - keys_or_tags = { - e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_tags - } - return worker_story(keys_or_tags, self.log) + def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: + """Return all transitions involving one or more tasks""" + keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} + return worker_story(keys, self.log) async def get_story(self, keys=None): return self.story(*keys) From 912b75eef6e4c979675f35a4b9b62ffa3c650b0e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 12:11:58 +0100 Subject: [PATCH 4/9] Refactor helper functions --- distributed/tests/test_worker.py | 26 +----- .../tests/test_worker_state_machine.py | 27 +----- distributed/utils_test.py | 86 +++++++++++++++++++ 3 files changed, 90 insertions(+), 49 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 938d2679a12..8f385d3105a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -46,6 +46,8 @@ from distributed.protocol import pickle from distributed.scheduler import Scheduler from distributed.utils_test import ( + BlockedGatherDep, + BlockedGetData, TaskStateMetadataPlugin, _LockedCommPool, assert_story, @@ -3138,30 +3140,6 @@ async def test_task_flight_compute_oserror(c, s, a, b): assert_story(sum_story, expected_sum_story, strict=True) -class BlockedGatherDep(Worker): - def __init__(self, *args, **kwargs): - self.in_gather_dep = asyncio.Event() - self.block_gather_dep = asyncio.Event() - super().__init__(*args, **kwargs) - - async def gather_dep(self, *args, **kwargs): - self.in_gather_dep.set() - await self.block_gather_dep.wait() - return await super().gather_dep(*args, **kwargs) - - -class BlockedGetData(Worker): - def __init__(self, *args, **kwargs): - self.in_get_data = asyncio.Event() - self.block_get_data = asyncio.Event() - super().__init__(*args, **kwargs) - - async def get_data(self, comm, *args, **kwargs): - self.in_get_data.set() - await self.block_get_data.wait() - return await super().get_data(comm, *args, **kwargs) - - @gen_cluster(client=True, nthreads=[]) async def test_gather_dep_cancelled_rescheduled(c, s): """At time of writing, the gather_dep implementation filtered tasks again diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 3da0ecfe3ab..200cb3d1316 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,18 +1,17 @@ import asyncio import logging -from contextlib import contextmanager from itertools import chain import pytest from distributed import Worker -from distributed.core import Status from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict from distributed.utils_test import ( _LockedCommPool, assert_story, captured_logger, + freeze_data_fetching, gen_cluster, inc, ) @@ -256,23 +255,6 @@ def test_executefailure_to_dict(): assert ev3.traceback_text == "tb text" -@contextmanager -def freeze_data_fetching(w: Worker): - """Prevent any task from transitioning from fetch to flight on the worker while - inside the context. - - This is not the same as setting the worker to Status=paused, which would also - inform the Scheduler and prevent further tasks to be enqueued on the worker. 
- """ - old_out_connections = w.total_out_connections - old_comm_threshold = w.comm_threshold_bytes - w.total_out_connections = 0 - w.comm_threshold_bytes = 0 - yield - w.total_out_connections = old_out_connections - w.comm_threshold_bytes = old_comm_threshold - - @gen_cluster(client=True) async def test_fetch_to_compute(c, s, a, b): with freeze_data_fetching(b): @@ -349,7 +331,7 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[ "x" ] - with freeze_data_fetching(w1): + with freeze_data_fetching(w1, jump_start=True): if as_deps: y1 = c.submit(inc, x, key="y1", workers=[w1.address]) else: @@ -371,11 +353,6 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): while w1.tasks["x"].who_has != {w3.address}: await asyncio.sleep(0.01) - # Jump-start ensure_communicating after freeze_data_fetching. - # This simulates an unrelated third worker moving out of in_flight_workers. - w1.status = Status.paused - w1.status = Status.running - await wait_for_state("x", "memory", w1) assert_story( diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 87d5c83958b..3f1e5e746ae 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2153,3 +2153,89 @@ def raises_with_cause( assert re.search( match_cause, str(exc.__cause__) ), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``" + + +class BlockedGatherDep(Worker): + """A Worker that sets event `in_gather_dep` the first time it enters the gather_dep + method and then does not initiate any comms, thus leaving the task(s) in flight + indefinitely, until the test sets `block_gather_dep` + + Example + ------- + .. code-block:: python + + @gen_test() + async def test1(s, a, b): + async with BlockedGatherDep(s.address) as x: + # [do something to cause x to fetch data from a or b] + await x.in_gather_dep.wait() + # [do something that must happen while the tasks are in flight] + x.block_gather_dep.set() + # [from this moment on, x is a regular worker] + + See also + -------- + BlockedGetData + """ + + def __init__(self, *args, **kwargs): + self.in_gather_dep = asyncio.Event() + self.block_gather_dep = asyncio.Event() + super().__init__(*args, **kwargs) + + async def gather_dep(self, *args, **kwargs): + self.in_gather_dep.set() + await self.block_gather_dep.wait() + return await super().gather_dep(*args, **kwargs) + + +class BlockedGetData(Worker): + """A Worker that sets event `in_get_data` the first time it enters the get_data + method and then does not answer the comms, thus leaving the task(s) in flight + indefinitely, until the test sets `block_get_data` + + See also + -------- + BlockedGatherDep + """ + + def __init__(self, *args, **kwargs): + self.in_get_data = asyncio.Event() + self.block_get_data = asyncio.Event() + super().__init__(*args, **kwargs) + + async def get_data(self, comm, *args, **kwargs): + self.in_get_data.set() + await self.block_get_data.wait() + return await super().get_data(comm, *args, **kwargs) + + +@contextmanager +def freeze_data_fetching(w: Worker, *, jump_start: bool = False): + """Prevent any task from transitioning from fetch to flight on the worker while + inside the context, simulating a situation where the worker's network comms are + saturated. + + This is not the same as setting the worker to Status=paused, which would also + inform the Scheduler and prevent further tasks to be enqueued on the worker. 
+ + Parameters + ---------- + w: Worker + The Worker on which tasks will not transition from fetch to flight + jump_start: bool + If False, tasks will remain in fetch state after exiting the context, until + something else triggers ensure_communicating. + If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated + worker moving out of in_flight_workers. + """ + old_out_connections = w.total_out_connections + old_comm_threshold = w.comm_threshold_bytes + w.total_out_connections = 0 + w.comm_threshold_bytes = 0 + yield + w.total_out_connections = old_out_connections + w.comm_threshold_bytes = old_comm_threshold + if jump_start: + w.status = Status.paused + w.status = Status.running From 436f4b6f1d43479295774b51b8e912a6f1785774 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 12:35:22 +0100 Subject: [PATCH 5/9] test update_who_has kicks off ensure_communicating on new replicas --- .../tests/test_worker_state_machine.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 200cb3d1316..7fa7f963ec1 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -4,10 +4,11 @@ import pytest -from distributed import Worker +from distributed import Worker, wait from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict from distributed.utils_test import ( + BlockedGetData, _LockedCommPool, assert_story, captured_logger, @@ -414,6 +415,49 @@ async def test_fetch_to_missing(c, s, a, b): ) +@gen_cluster(client=True) +async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): + """A task is stuck in 'fetch' state because all workers that hold a replica are in + flight. While in this state, a new replica appears on a different worker and the + scheduler informs the waiting worker through a new acquire-replicas or + compute-task op. + + In real life, this will typically happen when the Active Memory Manager replicates a + key to multiple workers and some workers are much faster than others to acquire it, + due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the + request, passing a larger who_has. + + Test that, when this happens, the task is immediately acquired from the new worker, + without waiting for the original replica holders to get out of flight. + """ + async with BlockedGetData(s.address) as w3: + x = c.submit(inc, 1, key="x", workers=[w3.address]) + y = c.submit(inc, 2, key="y", workers=[w3.address]) + await wait([x, y]) + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + await w3.in_get_data.wait() + assert w1.tasks["x"].state == "flight" + s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test") + # This cannot progress beyond fetch because w3 is already in flight + await wait_for_state("y", "fetch", w1) + + # Simulate that the AMM also requires that w2 acquires a replica of x. + # The replica lands on w2 soon afterwards, while w3->w1 comms remain blocked by + # unrelated transfers (x in our case). 
+ w2.update_data({"y": 3}, report=True) + ws2 = s.workers[w2.address] + while ws2 not in s.tasks["y"].who_has: + await asyncio.sleep(0.01) + + # 2 seconds later, the AMM reiterates that w1 should acquire a replica of y + s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test") + await wait_for_state("y", "memory", w1) + + # Finally let the other worker to get out of flight + w3.block_get_data.set() + await wait_for_state("x", "memory", w1) + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_self_denounce_missing_data(c, s, a): x = c.submit(inc, 1, key="x") From 1d35b03efacacef8fd206f4313ccaa228af7fe36 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 14:56:38 +0100 Subject: [PATCH 6/9] Remove transitions from update_who_has --- distributed/tests/test_stories.py | 1 - distributed/tests/test_worker.py | 1 - .../tests/test_worker_state_machine.py | 15 ++- distributed/worker.py | 91 ++++--------------- 4 files changed, 26 insertions(+), 82 deletions(-) diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py index 0500bc44a48..ec81ddaff8b 100644 --- a/distributed/tests/test_stories.py +++ b/distributed/tests/test_stories.py @@ -156,7 +156,6 @@ async def test_worker_story_with_deps(c, s, a, b): assert stimulus_ids == {"compute-task"} expected = [ ("dep", "ensure-task-exists", "released"), - ("dep", "update-who-has", [a.address], []), ("dep", "released", "fetch", "fetch", {}), ("gather-dependencies", a.address, {"dep"}), ("dep", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8f385d3105a..05f0c98d5d6 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2767,7 +2767,6 @@ def __getstate__(self): story[b], [ ("x", "ensure-task-exists", "released"), - ("x", "update-who-has", [a], []), ("x", "released", "fetch", "fetch", {}), ("gather-dependencies", a, {"x"}), ("x", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 7fa7f963ec1..9902935ff71 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -318,7 +318,7 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): 2. x transitions released -> fetch 3. the network stack is busy, so x does not transition to flight yet. 4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1 - 5. when x finally reaches the top of the data_needed heap, the w1 will not try + 5. when x finally reaches the top of the data_needed heap, w1 will not try contacting w2 as_deps=False @@ -326,12 +326,16 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): 2. x transitions released -> fetch 3. the network stack is busy, so x does not transition to flight yet. 4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1 - 5. when x finally reaches the top of the data_needed heap, the w1 will not try + 5. 
when x finally reaches the top of the data_needed heap, w1 will not try contacting w2 """ x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[ "x" ] + + # Make sure find_missing is not involved + w1.periodic_callbacks["find-missing"].stop() + with freeze_data_fetching(w1, jump_start=True): if as_deps: y1 = c.submit(inc, x, key="y1", workers=[w1.address]) @@ -355,7 +359,6 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): await asyncio.sleep(0.01) await wait_for_state("x", "memory", w1) - assert_story( w1.story("request-dep"), [("request-dep", w3.address, {"x"})], @@ -399,14 +402,12 @@ async def test_fetch_to_missing(c, s, a, b): a.story("x"), [ ("x", "ensure-task-exists", "released"), - ("x", "update-who-has", [b.address], []), ("x", "released", "fetch", "fetch", {}), ("gather-dependencies", b.address, {"x"}), ("x", "fetch", "flight", "flight", {}), ("request-dep", b.address, {"x"}), ("busy-gather", b.address, {"x"}), ("x", "flight", "fetch", "fetch", {}), - ("x", "update-who-has", [], [b.address]), # Called Scheduler.who_has ("x", "fetch", "missing", "missing", {}), ], # There may be a round of find_missing() after this. @@ -415,6 +416,7 @@ async def test_fetch_to_missing(c, s, a, b): ) +@pytest.mark.skip(reason="TODO link issue") @gen_cluster(client=True) async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): """A task is stuck in 'fetch' state because all workers that hold a replica are in @@ -430,6 +432,9 @@ async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): Test that, when this happens, the task is immediately acquired from the new worker, without waiting for the original replica holders to get out of flight. """ + # Make sure find_missing is not involved + w1.periodic_callbacks["find-missing"].stop() + async with BlockedGetData(s.address) as w3: x = c.submit(inc, 1, key="x", workers=[w3.address]) y = c.submit(inc, 2, key="y", workers=[w3.address]) diff --git a/distributed/worker.py b/distributed/worker.py index bc5cac2dc25..44e97e21c2e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1887,12 +1887,8 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - recommendations, instructions = merge_recs_instructions( - (recommendations, []), - self._update_who_has(who_has, stimulus_id=stimulus_id), - ) + self._update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) if self.validate: for key in keys: @@ -1995,10 +1991,7 @@ def handle_compute_task( for dep_key, value in nbytes.items(): self.tasks[dep_key].nbytes = value - recommendations, instructions = merge_recs_instructions( - (recommendations, instructions), - self._update_who_has(who_has, stimulus_id=stimulus_id), - ) + self._update_who_has(who_has) else: # pragma: nocover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") @@ -2105,8 +2098,7 @@ def transition_released_waiting( if dep_ts.state != "memory": ts.waiting_for_data.add(dep_ts) dep_ts.waiters.add(ts) - if dep_ts.state not in {"fetch", "flight"}: - recommendations[dep_ts] = "fetch" + recommendations[dep_ts] = "fetch" if ts.waiting_for_data: self.waiting_for_data_count += 1 @@ -2693,7 +2685,7 @@ def _transition( assert not args finish, *args = finish # type: ignore - if ts is None or ts.state == finish: + if ts.state == finish: return {}, [] start = ts.state @@ -2996,8 +2988,11 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: if 
ts.state != "fetch" or ts.key in all_keys_to_gather: continue + if not ts.who_has: + recommendations[ts] = "missing" + continue + if self.validate: - assert ts.who_has assert self.address not in ts.who_has workers = [ @@ -3383,12 +3378,7 @@ def done_event(): who_has = await retry_operation( self.scheduler.who_has, keys=refresh_who_has ) - refresh_stimulus_id = f"gather-dep-busy-{time()}" - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=refresh_stimulus_id - ) - self.transitions(recommendations, stimulus_id=refresh_stimulus_id) - self._handle_instructions(instructions) + self._update_who_has(who_has) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3412,15 +3402,12 @@ async def find_missing(self) -> None: self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], ) - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=stimulus_id - ) + self._update_who_has(who_has) + recommendations: Recs = {} for ts in self._missing_dep_flight: if ts.who_has: - assert ts not in recommendations recommendations[ts] = "fetch" self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) finally: self._find_missing_running = False @@ -3429,12 +3416,7 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - def _update_who_has( - self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str - ) -> RecsInstrs: - recs: Recs = {} - instructions: Instructions = [] - + def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: for key, workers in who_has.items(): ts = self.tasks.get(key) if not ts: @@ -3461,58 +3443,17 @@ def _update_who_has( if ts.who_has == workers: continue - new_workers = workers - ts.who_has - del_workers = ts.who_has - workers - - def max2(workers: set[str]) -> list[str]: - if len(workers) < 3: - return list(workers) - it = iter(workers) - return [next(it), next(it), f"({len(workers) - 2} more)"] - - self.log.append( - ( - key, - "update-who-has", - max2(new_workers), - max2(del_workers), - stimulus_id, - time(), - ) - ) - - for worker in del_workers: + for worker in ts.who_has - workers: self.has_what[worker].discard(key) # Can't remove from self.data_needed_per_worker; there is logic # in _select_keys_for_gather to deal with this - for worker in new_workers: + for worker in workers - ts.who_has: self.has_what[worker].add(key) if ts.state == "fetch": self.data_needed_per_worker[worker].push(ts) - # All workers which previously held a replica of the key may either - # be in flight or busy. Kick off ensure_communicating to try - # fetching the data from the new worker. - # There are other reasons why a task may be sitting in 'fetch' state - # - e.g. if we're over the total_out_connections or the - # comm_threshold_bytes limit, or if we're paused. We're deliberately - # NOT testing the full gamut of use cases for the sake of simplicity - # and robustness and just stating that ensure_communicating *may* - # return new GatherDep events now. - instructions.append( - EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id) - ) ts.who_has = workers - # currently fetching -> can no longer be fetched -> transition to missing - # any other state -> eventually, possibly, the task may transition to fetch - # or missing, at which point the relevant transitions will test who_has that - # we just updated. e.g. 
see the various transitions to fetch, which - # instead recommend transitioning to missing if who_has is empty. - if not workers and ts.state == "fetch": - recs[ts] = "missing" - - return recs, instructions def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task. @@ -4244,8 +4185,8 @@ def validate_task_fetch(self, ts): assert self.address not in ts.who_has assert not ts.done assert ts in self.data_needed - assert ts.who_has - + # Note: ts.who_has may be have been emptied by _update_who_has, but the task + # won't transition to missing until it reaches the top of the data_needed heap. for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed_per_worker[w] From ce93c76cd475a7de012d39a6b16676b5ebac842b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 15:43:27 +0100 Subject: [PATCH 7/9] link follow-up --- distributed/tests/test_worker_state_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 9902935ff71..be25835a81a 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -416,7 +416,7 @@ async def test_fetch_to_missing(c, s, a, b): ) -@pytest.mark.skip(reason="TODO link issue") +@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446") @gen_cluster(client=True) async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): """A task is stuck in 'fetch' state because all workers that hold a replica are in From c3b02afed791ecd47e0747ffffeb8ec8c5846b19 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 26 May 2022 16:01:28 +0100 Subject: [PATCH 8/9] typo --- distributed/tests/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 65402739161..84195bbfce3 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2620,7 +2620,7 @@ def sink(a, b, *args): if peer_addr == a.address and msg["op"] == "get_data": break - # Provoke an "impossible transision exception" + # Provoke an "impossible transition exception" # By choosing a state which doesn't exist we're not running into validation # errors and the state machine should raise if we want to transition from # fetch to memory From 261a3e464a4d26a63c805120fecdcda15ba7c281 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 26 May 2022 16:02:22 +0100 Subject: [PATCH 9/9] Remove test_self_denounce_missing_data --- .../tests/test_worker_state_machine.py | 53 ------------------- distributed/worker.py | 14 ++--- 2 files changed, 8 insertions(+), 59 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index be25835a81a..d852a42499a 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,5 +1,4 @@ import asyncio -import logging from itertools import chain import pytest @@ -11,7 +10,6 @@ BlockedGetData, _LockedCommPool, assert_story, - captured_logger, freeze_data_fetching, gen_cluster, inc, @@ -463,57 +461,6 @@ async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): await wait_for_state("x", "memory", w1) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_self_denounce_missing_data(c, s, a): - x = c.submit(inc, 1, key="x") - await x - - y = 
c.submit(inc, x, key="y") - # Wait until the scheduler is forwarding the compute-task to the worker, but before - # the worker has received it - while "y" not in s.tasks or s.tasks["y"].state != "processing": - await asyncio.sleep(0) - - # Wipe x from the worker. This simulates the scheduler calling - # delete_worker_data(a.address, ["x"]), but the RPC call has not returned yet. - a.handle_free_keys(keys=["x"], stimulus_id="test") - - # At the same time, - # a->s: {"op": "release-worker-data", keys=["x"]} - # s->a: {"op": "compute-task", key="y", who_has={"x": [a.address]}} - - with captured_logger("distributed.worker", level=logging.DEBUG) as logger: - # The compute-task request for y gets stuck in waiting state, because x is not - # in memory. However, moments later the scheduler is informed that a is missing - # x; so it releases y and then recomputes both x and y. - assert await y == 3 - - assert ( - f"Scheduler claims worker {a.address} holds data for task " - ", which is not true." - ) in logger.getvalue() - - assert_story( - a.story("x", "y"), - # Note: omitted uninteresting events - [ - # {"op": "compute-task", key="y", who_has={"x": [a.address]}} - # reaches the worker - ("y", "compute-task", "released"), - ("x", "ensure-task-exists", "released"), - ("x", "released", "missing", "missing", {}), - # {"op": "release-worker-data", keys=["x"]} reaches the scheduler, which - # reacts by releasing both keys and then recomputing them - ("y", "release-key"), - ("x", "release-key"), - ("x", "compute-task", "released"), - ("x", "executing", "memory", "memory", {}), - ("y", "compute-task", "released"), - ("y", "executing", "memory", "memory", {}), - ], - ) - - @gen_cluster(client=True) async def test_cancelled_while_in_flight(c, s, a, b): event = asyncio.Event() diff --git a/distributed/worker.py b/distributed/worker.py index 44e97e21c2e..d9c64e8e3be 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3427,13 +3427,15 @@ def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: if self.address in workers: workers.remove(self.address) + # This can only happen if rebalance() recently asked to release a key, + # but the RPC call hasn't returned yet. rebalance() is flagged as not + # being safe to run while the cluster is not at rest and has already + # been penned in to be redesigned on top of the AMM. + # It is not necessary to send a message back to the + # scheduler here, because it is guaranteed that there's already a + # release-worker-data message in transit to it. if ts.state != "memory": - # This can happen if the worker recently released a key, but the - # scheduler hasn't processed the release-worker-data message yet. - # It is not necessary to send a missing-data message back to the - # scheduler here, because it is guaranteed to reach the scheduler - # after the release-worker-data from the same BatchedSend channel. - logger.debug( + logger.debug( # pragma: nocover "Scheduler claims worker %s holds data for task %s, " "which is not true.", self.address,