diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 45d7c58bf79..81a1ff6636f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4140,7 +4140,9 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) @log_errors - async def remove_worker(self, address, stimulus_id, safe=False, close=True): + async def remove_worker( + self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True + ) -> Literal["OK", "already-removed"]: """ Remove worker from cluster @@ -4149,7 +4151,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True): state. """ if self.status == Status.closed: - return + return "already-removed" address = self.coerce_address(address) @@ -7060,23 +7062,20 @@ def adaptive_target(self, target_duration=None): to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str): + def request_acquire_replicas( + self, addr: str, keys: Iterable[str], *, stimulus_id: str + ): """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. """ - who_has = {} - for key in keys: - ts = self.tasks[key] - who_has[key] = {ws.address for ws in ts.who_has} - + who_has = {key: {ws.address for ws in self.tasks[key].who_has} for key in keys} if self.validate: assert all(who_has.values()) self.stream_comms[addr].send( { "op": "acquire-replicas", - "keys": keys, "who_has": who_has, "stimulus_id": stimulus_id, }, diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a4dca9e2872..409d530277c 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -300,7 +300,7 @@ async def get_data(self, comm, *args, **kwargs): s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. 
get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B - await s.remove_worker(a.address, "foo", close=False, safe=True) + await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True) while not b.tasks[fut1.key].state == "resumed": await asyncio.sleep(0.01) @@ -438,7 +438,7 @@ async def get_data(self, comm, *args, **kwargs): f3.key: {w2.address}, } ) - await s.remove_worker(w1.address, "stim-id") + await s.remove_worker(w1.address, stimulus_id="stim-id") await wait_for_state(f3.key, "resumed", w2) assert_story( diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py index ec81ddaff8b..232eca6c36c 100644 --- a/distributed/tests/test_stories.py +++ b/distributed/tests/test_stories.py @@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b): assert stimulus_ids == {"compute-task"} expected = [ ("dep", "ensure-task-exists", "released"), + ("dep", "update-who-has", [], [a.address]), ("dep", "released", "fetch", "fetch", {}), ("gather-dependencies", a.address, {"dep"}), ("dep", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 034e0bf2953..1582911e177 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1741,7 +1741,7 @@ async def close(self): ) as wlogger, captured_logger( "distributed.scheduler", level=logging.WARNING ) as slogger: - await s.remove_worker(a.address, "foo") + await s.remove_worker(a.address, stimulus_id="foo") assert not s.workers # Wait until the close signal reaches the worker and it starts shutting down. @@ -2747,6 +2747,7 @@ def __getstate__(self): story[b], [ ("x", "ensure-task-exists", "released"), + ("x", "update-who-has", [], [a]), ("x", "released", "fetch", "fetch", {}), ("gather-dependencies", a, {"x"}), ("x", "fetch", "flight", "flight", {}), diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1b6f2e3f80b..55de0e04dbe 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,8 +1,11 @@ import asyncio +from contextlib import contextmanager from itertools import chain import pytest +from distributed import Worker +from distributed.core import Status from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc @@ -16,12 +19,13 @@ SendMessageToScheduler, StateMachineEvent, TaskState, + TaskStateState, UniqueTaskHeap, merge_recs_instructions, ) -async def wait_for_state(key, state, dask_worker): +async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None: while key not in dask_worker.tasks or dask_worker.tasks[key].state != state: await asyncio.sleep(0.005) @@ -245,28 +249,39 @@ def test_executefailure_to_dict(): assert ev3.traceback_text == "tb text" -@gen_cluster(client=True) -async def test_fetch_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 +@contextmanager +def freeze_data_fetching(w: Worker): + """Prevent any task from transitioning from fetch to flight on the worker while + inside the context. 
- f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) - f2 = c.submit(inc, f1, workers=[b.address], key="f2") + This is not the same as setting the worker to Status=paused, which would also + inform the Scheduler and prevent further tasks to be enqueued on the worker. + """ + old_out_connections = w.total_out_connections + old_comm_threshold = w.comm_threshold_bytes + w.total_out_connections = 0 + w.comm_threshold_bytes = 0 + yield + w.total_out_connections = old_out_connections + w.comm_threshold_bytes = old_comm_threshold + # Jump-start ensure_communicating + w.status = Status.paused + w.status = Status.running - await wait_for_state(f1.key, "fetch", b) - await a.close() - b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold +@gen_cluster(client=True) +async def test_fetch_to_compute(c, s, a, b): + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + f2 = c.submit(inc, f1, workers=[b.address], key="f2") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f2 assert_story( b.log, - # FIXME: This log should be replaced with an - # StateMachineEvent/Instruction log + # FIXME: This log should be replaced with a StateMachineEvent log [ (f2.key, "compute-task", "released"), # This is a "please fetch" request. We don't have anything like @@ -285,23 +300,179 @@ async def test_fetch_to_compute(c, s, a, b): @gen_cluster(client=True) async def test_fetch_via_amm_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 - - f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + await f1 + s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f1 - s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") - await wait_for_state(f1.key, "fetch", b) - await a.close() + assert_story( + b.log, + # FIXME: This log should be replaced with a StateMachineEvent log + [ + (f1.key, "ensure-task-exists", "released"), + (f1.key, "released", "fetch", "fetch", {}), + (f1.key, "compute-task", "fetch"), + (f1.key, "put-in-memory"), + ], + ) - b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold - await f1 +@pytest.mark.parametrize("as_deps", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): + """ + as_deps=True + 0. task x is a dependency of y1 and y2 + 1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1 + 5. when x finally reaches the top of the data_needed heap, the w1 will not try + contacting w2 + + as_deps=False + 1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1 + 5. 
when x finally reaches the top of the data_needed heap, the w1 will not try + contacting w2 + """ + x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[ + "x" + ] + with freeze_data_fetching(w1): + if as_deps: + y1 = c.submit(inc, x, key="y1", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + + await wait_for_state("x", "fetch", w1) + assert w1.tasks["x"].who_has == {w2.address, w3.address} + + assert len(s.tasks["x"].who_has) == 2 + await w2.close() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + + if as_deps: + y2 = c.submit(inc, x, key="y2", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + + while w1.tasks["x"].who_has != {w3.address}: + await asyncio.sleep(0.01) + + await wait_for_state("x", "memory", w1) + + assert_story( + w1.story("request-dep"), + [("request-dep", w3.address, {"x"})], + # This tests that there has been no attempt to contact w2. + # If the assumption being tested breaks, this will fail 50% of the times. + strict=True, + ) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_fetch_to_missing(c, s, a, b): + """ + 1. task x is a dependency of y + 2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a + 3. x transitions released -> fetch -> flight; a connects to b + 4. b responds it's busy. x transitions flight -> fetch + 5. The busy state triggers an RPC call to Scheduler.who_has + 6. the scheduler responds {"x": []}, because w1 in the meantime has lost the key. + 7. x is transitioned fetch -> missing + """ + x = await c.scatter({"x": 1}, workers=[b.address]) + b.total_in_connections = 0 + # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler + # doesn't keep track of acquire-replicas requests, so it won't proactively inform a + # when we call remove_worker later on + s.request_acquire_replicas(a.address, ["x"], stimulus_id="test") + + # state will flip-flop between fetch and flight every 150ms, which is the retry + # period for busy workers. + await wait_for_state("x", "fetch", a) + assert b.address in a.busy_workers + + # Sever connection between b and s, but not between b and a. + # If a tries fetching from b after this, b will keep responding {status: busy}. + b.periodic_callbacks["heartbeat"].stop() + await s.remove_worker(b.address, close=False, stimulus_id="test") + + await wait_for_state("x", "missing", a) + + assert_story( + a.story("x"), + [ + ("x", "ensure-task-exists", "released"), + ("x", "update-who-has", [], [b.address]), + ("x", "released", "fetch", "fetch", {}), + ("gather-dependencies", b.address, {"x"}), + ("x", "fetch", "flight", "flight", {}), + ("request-dep", b.address, {"x"}), + ("busy-gather", b.address, {"x"}), + ("x", "flight", "fetch", "fetch", {}), + ("x", "update-who-has", [b.address], []), # Called Scheduler.who_has + ("x", "fetch", "missing", "missing", {}), + ], + # There may be a round of find_missing() after this. + # Due to timings, there also may be multiple attempts to connect from a to b. + strict=False, + ) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_self_denounce_missing_data(c, s, a): + x = c.submit(inc, 1, key="x") + await x + + # Wipe x from the worker. This simulates the following condition: + # 1. The scheduler thinks a and b hold a replica of x. + # 2. b is unresponsive, but the scheduler doesn't know yet as it didn't time out. + # 3. The AMM decides to ask a to drop its replica. 
+ a.handle_remove_replicas(keys=["x"], stimulus_id="test") + + # Lose the message that would inform the scheduler. + # In theory, this should not happen. + # In practice, in case of TPC connection fault, BatchedSend can drop messages. + assert a.batched_stream.buffer == [ + {"key": "x", "stimulus_id": "test", "op": "release-worker-data"} + ] + a.batched_stream.buffer.clear() + + y = c.submit(inc, x, key="y") + # The scheduler tries computing y, but a responds that x is not available. + # The scheduler kicks off the computation of x and then y from scratch. + assert await y == 3 + + assert_story( + a.story("compute-task"), + [ + ("x", "compute-task", "released"), + # The scheduler tries computing y a first time and fails. + # This line would not be here if we didn't lose the + # {"op": "release-worker-data"} message earlier. + ("y", "compute-task", "released"), + # The scheduler receives the {"op": "missing-data"} message from the + # worker. This makes the computation of y to fail. The scheduler reschedules + # x and then y. + ("x", "compute-task", "released"), + ("y", "compute-task", "released"), + ], + strict=True, + ) + + del x + while "x" in a.data: + await asyncio.sleep(0.01) + assert a.tasks["x"].state == "released" @gen_cluster(client=True) diff --git a/distributed/worker.py b/distributed/worker.py index 437a5b20773..f7746bf413b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1857,17 +1857,15 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None: def handle_acquire_replicas( self, - *, - keys: Collection[str], who_has: dict[str, Collection[str]], + *, stimulus_id: str, ) -> None: if self.validate: - assert set(keys) == who_has.keys() assert all(who_has.values()) recommendations: Recs = {} - for key in keys: + for key in who_has: ts = self.ensure_task_exists( key=key, # Transfer this data after all dependency tasks of computations with @@ -1880,11 +1878,15 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has) + recommendations, instructions = merge_recs_instructions( + (recommendations, []), + self._update_who_has(who_has, stimulus_id=stimulus_id), + ) self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) if self.validate: - for key in keys: + for key in who_has: assert self.tasks[key].state != "released", self.story(key) def ensure_task_exists( @@ -1984,7 +1986,10 @@ def handle_compute_task( for dep_key, value in nbytes.items(): self.tasks[dep_key].nbytes = value - self.update_who_has(who_has) + recommendations, instructions = merge_recs_instructions( + (recommendations, instructions), + self._update_who_has(who_has, stimulus_id=stimulus_id), + ) else: # pragma: nocover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") @@ -3136,8 +3141,14 @@ def _select_keys_for_gather( while tasks: ts = tasks.peek() - if ts.state != "fetch" or ts.key in all_keys_to_gather: + if ( + ts.state != "fetch" # Do not acquire the same key twice if multiple workers holds replicas + or ts.key in all_keys_to_gather + # A replica is still available (otherwise status would not be 'fetch' + # anymore), but not on this worker. See _update_who_has(). 
+ or worker not in ts.who_has + ): tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3363,7 +3374,12 @@ def done_event(): who_has = await retry_operation( self.scheduler.who_has, keys=refresh_who_has ) - self.update_who_has(who_has) + refresh_stimulus_id = f"refresh-who-has-{time()}" + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=refresh_stimulus_id + ) + self.transitions(recommendations, stimulus_id=refresh_stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3387,12 +3403,15 @@ async def find_missing(self) -> None: self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], ) - self.update_who_has(who_has) - recommendations: Recs = {} + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=stimulus_id + ) for ts in self._missing_dep_flight: if ts.who_has: + assert ts not in recommendations recommendations[ts] = "fetch" self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) finally: self._find_missing_running = False @@ -3401,34 +3420,81 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: - try: - for dep, workers in who_has.items(): - if not workers: - continue + def _update_who_has( + self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str + ) -> RecsInstrs: + recs: Recs = {} + instructions: Instructions = [] - if dep in self.tasks: - dep_ts = self.tasks[dep] - if self.address in workers and self.tasks[dep].state != "memory": - logger.debug( - "Scheduler claims worker %s holds data for task %s which is not true.", - self.name, - dep, + for key, workers in who_has.items(): + ts = self.tasks.get(key) + if not ts: + # The worker sent a refresh-who-has request to the scheduler but, by the + # time the answer comes back, some of the keys have been forgotten. + continue + workers = set(workers) + + if self.address in workers: + workers.remove(self.address) + if ts.state != "memory": + logger.debug( + "Scheduler claims worker %s holds data for task %s, " + "which is not true.", + self.name, + ts, + ) + instructions.append( + MissingDataMsg( + key=key, errant_worker=self.address, stimulus_id=stimulus_id ) - # Do not mutate the input dict. That's rude - workers = set(workers) - {self.address} - dep_ts.who_has.update(workers) + ) - for worker in workers: - self.has_what[worker].add(dep) - self.data_needed_per_worker[worker].push(dep_ts) - except Exception as e: # pragma: no cover - logger.exception(e) - if LOG_PDB: - import pdb + if ts.who_has == workers: + continue - pdb.set_trace() - raise + self.log.append( + ( + key, + "update-who-has", + list(ts.who_has), + list(workers), + stimulus_id, + time(), + ) + ) + + for worker in ts.who_has - workers: + self.has_what[worker].discard(key) + # Can't remove from self.data_needed_per_worker; there is logic + # in _select_keys_for_gather to deal with this + + for worker in workers - ts.who_has: + self.has_what[worker].add(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].push(ts) + # All workers which previously held a replica of the key may either + # be in flight or busy. Kick off ensure_communicating to try + # fetching the data from the new worker. + # There are other reasons why a task may be sitting in 'fetch' state + # - e.g. 
if we're over the total_out_connections or the + # comm_threshold_bytes limit, or if we're paused. We're deliberately + # NOT testing the full gamut of use cases for the sake of simplicity + # and robustness and just stating that ensure_communicating *may* + # return new GatherDep events now. + instructions.append( + EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id) + ) + + ts.who_has = workers + # currently fetching -> can no longer be fetched -> transition to missing + # any other state -> eventually, possibly, the task may transition to fetch + # or missing, at which point the relevant transitions will test who_has that + # we just updated. e.g. see the various transitions to fetch, which + # instead recommend transitioning to missing if who_has is empty. + if not workers and ts.state == "fetch": + recs[ts] = "missing" + + return recs, instructions def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task. @@ -4269,6 +4335,7 @@ def validate_state(self): assert ts.state is not None # check that worker has task for worker in ts.who_has: + assert worker != self.address assert ts.key in self.has_what[worker] # check that deps have a set state and that dependency<->dependent links # are there @@ -4293,6 +4360,7 @@ def validate_state(self): # FIXME https://github.com/dask/distributed/issues/6319 # assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): + assert worker != self.address for k in keys: assert worker in self.tasks[k].who_has
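
The sketches below are appended for illustration only: they are standalone Python, not part of the patch, and every class, helper, and address name in them is made up.

First, the new keyword-only signature of Scheduler.remove_worker shown at the top of the diff: stimulus_id, safe, and close must now be passed by keyword, and the method reports "OK" or "already-removed" instead of returning None when the scheduler is already closed. A minimal sketch of the calling convention, using a toy MiniScheduler rather than the real class:

import asyncio
from typing import Literal


class MiniScheduler:
    """Toy stand-in for the real Scheduler, for illustration only."""

    def __init__(self) -> None:
        self.closed = False
        self.workers = {"tcp://127.0.0.1:40000": "worker state"}

    async def remove_worker(
        self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True
    ) -> Literal["OK", "already-removed"]:
        # Mirrors the early return added in the diff: a closed scheduler now
        # reports "already-removed" instead of silently returning None.
        if self.closed:
            return "already-removed"
        self.workers.pop(address, None)
        return "OK"


async def main() -> None:
    s = MiniScheduler()
    # stimulus_id must be passed by keyword; passing it positionally is a TypeError.
    assert await s.remove_worker("tcp://127.0.0.1:40000", stimulus_id="example") == "OK"
    s.closed = True
    assert (
        await s.remove_worker("tcp://127.0.0.1:40000", stimulus_id="example")
        == "already-removed"
    )


asyncio.run(main())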
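
The "acquire-replicas" message no longer carries a separate "keys" field: the scheduler sends only who_has, and the worker derives the keys from who_has itself. A minimal sketch of the new message shape, using plain dicts and functions rather than the real Scheduler/Worker classes:

def build_acquire_replicas_msg(replicas, keys, stimulus_id):
    """Scheduler side: replicas maps key -> set of worker addresses holding it."""
    who_has = {key: set(replicas[key]) for key in keys}
    # Same invariant the diff asserts under self.validate: never request a key
    # that nobody currently holds.
    assert all(who_has.values())
    return {"op": "acquire-replicas", "who_has": who_has, "stimulus_id": stimulus_id}


def handle_acquire_replicas(msg):
    """Worker side: the keys to fetch are simply who_has.keys()."""
    for key, holders in msg["who_has"].items():
        print(f"schedule fetch of {key!r} from one of {sorted(holders)}")


msg = build_acquire_replicas_msg(
    {"x": {"tcp://w2:1234", "tcp://w3:1234"}}, ["x"], stimulus_id="example-stimulus"
)
handle_acquire_replicas(msg)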
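
The new freeze_data_fetching test helper keeps tasks pinned in the "fetch" state by zeroing the worker's transfer throttles, then restores them and flip-flops the status to kick ensure_communicating back into action. Below is a self-contained sketch of the same pattern against a dummy object; note the real helper in the diff restores the attributes unconditionally rather than in a finally block:

from contextlib import contextmanager


class DummyWorker:
    """Stand-in exposing only the attributes the helper touches."""

    total_out_connections = 50
    comm_threshold_bytes = 10_000_000
    status = "running"


@contextmanager
def freeze_data_fetching(w):
    old_out_connections = w.total_out_connections
    old_comm_threshold = w.comm_threshold_bytes
    w.total_out_connections = 0  # no new outgoing gather connections allowed
    w.comm_threshold_bytes = 0  # every pending transfer counts as "over budget"
    try:
        yield
    finally:
        w.total_out_connections = old_out_connections
        w.comm_threshold_bytes = old_comm_threshold
        # In the real helper, flipping paused -> running is what jump-starts
        # ensure_communicating so queued fetches resume immediately.
        w.status = "paused"
        w.status = "running"


w = DummyWorker()
with freeze_data_fetching(w):
    assert w.total_out_connections == 0 and w.comm_threshold_bytes == 0
assert w.total_out_connections == 50 and w.status == "running"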
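
The core behavioural change is that update_who_has becomes _update_who_has, which returns recommendations and instructions for the caller to apply instead of mutating state in place. A heavily simplified, standalone sketch of the bookkeeping it performs, with plain dicts and tuples standing in for TaskState, Recs, and Instructions:

from collections import defaultdict


def update_who_has(tasks, has_what, who_has, self_address):
    """Return (recommendations, instructions); never mutate the input who_has."""
    recommendations = {}
    instructions = []
    for key, workers in who_has.items():
        ts = tasks.get(key)
        if ts is None:
            # Key was forgotten while the who_has RPC round trip was in flight.
            continue
        workers = set(workers)
        if self_address in workers:
            workers.discard(self_address)
            if ts["state"] != "memory":
                # The scheduler wrongly believes this worker holds the key:
                # self-denounce with a missing-data style message.
                instructions.append(("missing-data", key, self_address))
        if workers == ts["who_has"]:
            continue
        for worker in ts["who_has"] - workers:
            has_what[worker].discard(key)
        for worker in workers - ts["who_has"]:
            has_what[worker].add(key)
            if ts["state"] == "fetch":
                # A new replica holder appeared: nudge the fetch machinery.
                instructions.append(("ensure-communicating",))
        ts["who_has"] = workers
        if not workers and ts["state"] == "fetch":
            # A task waiting to be fetched just lost its last known replica.
            recommendations[key] = "missing"
    return recommendations, instructions


tasks = {"x": {"state": "fetch", "who_has": {"tcp://w2:1234"}}}
has_what = defaultdict(set, {"tcp://w2:1234": {"x"}})
print(update_who_has(tasks, has_what, {"x": []}, self_address="tcp://w1:1234"))
# ({'x': 'missing'}, [])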
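
Because _update_who_has deliberately does not purge stale entries from data_needed_per_worker, _select_keys_for_gather grows an extra guard: a key is only gathered from a worker that still appears in its who_has. A tiny sketch of that filter over a plain list standing in for the per-worker heap (the byte budget and heap pops of the real method are omitted):

def select_keys_for_gather(queue, worker, tasks, already_scheduled):
    """Pick keys still worth fetching from `worker`; skip stale heap entries."""
    selected = []
    for key in queue:
        ts = tasks[key]
        if (
            ts["state"] != "fetch"  # e.g. already in flight or in memory
            or key in already_scheduled  # don't ask two workers for the same key
            or worker not in ts["who_has"]  # replica gone from this worker (new guard)
        ):
            continue
        selected.append(key)
    return selected


tasks = {
    "x": {"state": "fetch", "who_has": {"tcp://w3:1234"}},  # w2 lost its replica
    "y": {"state": "fetch", "who_has": {"tcp://w2:1234", "tcp://w3:1234"}},
}
print(select_keys_for_gather(["x", "y"], "tcp://w2:1234", tasks, set()))
# ['y']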