diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 94ae05b8905..fd3453ad56a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -46,6 +46,8 @@ from distributed.protocol import pickle from distributed.scheduler import Scheduler from distributed.utils_test import ( + BlockedGatherDep, + BlockedGetData, TaskStateMetadataPlugin, _LockedCommPool, assert_story, @@ -2618,7 +2620,7 @@ def sink(a, b, *args): if peer_addr == a.address and msg["op"] == "get_data": break - # Provoke an "impossible transision exception" + # Provoke an "impossible transition exception" # By choosing a state which doesn't exist we're not running into validation # errors and the state machine should raise if we want to transition from # fetch to memory @@ -3137,30 +3139,6 @@ async def test_task_flight_compute_oserror(c, s, a, b): assert_story(sum_story, expected_sum_story, strict=True) -class BlockedGatherDep(Worker): - def __init__(self, *args, **kwargs): - self.in_gather_dep = asyncio.Event() - self.block_gather_dep = asyncio.Event() - super().__init__(*args, **kwargs) - - async def gather_dep(self, *args, **kwargs): - self.in_gather_dep.set() - await self.block_gather_dep.wait() - return await super().gather_dep(*args, **kwargs) - - -class BlockedGetData(Worker): - def __init__(self, *args, **kwargs): - self.in_get_data = asyncio.Event() - self.block_get_data = asyncio.Event() - super().__init__(*args, **kwargs) - - async def get_data(self, comm, *args, **kwargs): - self.in_get_data.set() - await self.block_get_data.wait() - return await super().get_data(comm, *args, **kwargs) - - @gen_cluster(client=True, nthreads=[]) async def test_gather_dep_cancelled_rescheduled(c, s): """At time of writing, the gather_dep implementation filtered tasks again diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1b6f2e3f80b..d852a42499a 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -3,9 +3,17 @@ import pytest +from distributed import Worker, wait from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict -from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc +from distributed.utils_test import ( + BlockedGetData, + _LockedCommPool, + assert_story, + freeze_data_fetching, + gen_cluster, + inc, +) from distributed.worker_state_machine import ( ExecuteFailureEvent, ExecuteSuccessEvent, @@ -16,12 +24,13 @@ SendMessageToScheduler, StateMachineEvent, TaskState, + TaskStateState, UniqueTaskHeap, merge_recs_instructions, ) -async def wait_for_state(key, state, dask_worker): +async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None: while key not in dask_worker.tasks or dask_worker.tasks[key].state != state: await asyncio.sleep(0.005) @@ -247,26 +256,17 @@ def test_executefailure_to_dict(): @gen_cluster(client=True) async def test_fetch_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 - - f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) - f2 = c.submit(inc, f1, workers=[b.address], key="f2") - - await wait_for_state(f1.key, "fetch", b) - await a.close() - - 
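For readers skimming this diff, a hedged sketch of how the newly annotated `wait_for_state` helper is meant to be used from a test; the test name and the key "x" are illustrative and not part of this patch:

@gen_cluster(client=True)
async def test_wait_for_state_example(c, s, a, b):
    # gen_cluster and inc come from distributed.utils_test; wait_for_state is the
    # module-level polling helper defined in test_worker_state_machine.py.
    fut = c.submit(inc, 1, key="x", workers=[a.address])
    # Block until worker `a`'s state machine reports the task in the given state.
    await wait_for_state("x", "memory", a)
    assert a.tasks["x"].state == "memory"
    assert await fut == 2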
b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + f2 = c.submit(inc, f1, workers=[b.address], key="f2") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f2 assert_story( b.log, - # FIXME: This log should be replaced with an - # StateMachineEvent/Instruction log + # FIXME: This log should be replaced with a StateMachineEvent log [ (f2.key, "compute-task", "released"), # This is a "please fetch" request. We don't have anything like @@ -285,23 +285,180 @@ async def test_fetch_to_compute(c, s, a, b): @gen_cluster(client=True) async def test_fetch_via_amm_to_compute(c, s, a, b): - # Block ensure_communicating to ensure we indeed know that the task is in - # fetch and doesn't leave it accidentally - old_out_connections, b.total_out_connections = b.total_out_connections, 0 - old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 - - f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + with freeze_data_fetching(b): + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + await f1 + s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") + await wait_for_state(f1.key, "fetch", b) + await a.close() await f1 - s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") - await wait_for_state(f1.key, "fetch", b) - await a.close() + assert_story( + b.log, + # FIXME: This log should be replaced with a StateMachineEvent log + [ + (f1.key, "ensure-task-exists", "released"), + (f1.key, "released", "fetch", "fetch", {}), + (f1.key, "compute-task", "fetch"), + (f1.key, "put-in-memory"), + ], + ) - b.total_out_connections = old_out_connections - b.comm_threshold_bytes = old_comm_threshold - await f1 +@pytest.mark.parametrize("as_deps", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): + """ + as_deps=True + 0. task x is a dependency of y1 and y2 + 1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1 + 5. when x finally reaches the top of the data_needed heap, w1 will not try + contacting w2 + + as_deps=False + 1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1 + 2. x transitions released -> fetch + 3. the network stack is busy, so x does not transition to flight yet. + 4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1 + 5. 
when x finally reaches the top of the data_needed heap, w1 will not try + contacting w2 + """ + x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[ + "x" + ] + + # Make sure find_missing is not involved + w1.periodic_callbacks["find-missing"].stop() + + with freeze_data_fetching(w1, jump_start=True): + if as_deps: + y1 = c.submit(inc, x, key="y1", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + + await wait_for_state("x", "fetch", w1) + assert w1.tasks["x"].who_has == {w2.address, w3.address} + + assert len(s.tasks["x"].who_has) == 2 + await w2.close() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + + if as_deps: + y2 = c.submit(inc, x, key="y2", workers=[w1.address]) + else: + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + + while w1.tasks["x"].who_has != {w3.address}: + await asyncio.sleep(0.01) + + await wait_for_state("x", "memory", w1) + assert_story( + w1.story("request-dep"), + [("request-dep", w3.address, {"x"})], + # This tests that there has been no attempt to contact w2. + # If the assumption being tested breaks, this will fail 50% of the times. + strict=True, + ) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_fetch_to_missing(c, s, a, b): + """ + 1. task x is a dependency of y + 2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a + 3. x transitions released -> fetch -> flight; a connects to b + 4. b responds it's busy. x transitions flight -> fetch + 5. The busy state triggers an RPC call to Scheduler.who_has + 6. the scheduler responds {"x": []}, because w1 in the meantime has lost the key. + 7. x is transitioned fetch -> missing + """ + x = await c.scatter({"x": 1}, workers=[b.address]) + b.total_in_connections = 0 + # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler + # doesn't keep track of acquire-replicas requests, so it won't proactively inform a + # when we call remove_worker later on + s.request_acquire_replicas(a.address, ["x"], stimulus_id="test") + + # state will flip-flop between fetch and flight every 150ms, which is the retry + # period for busy workers. + await wait_for_state("x", "fetch", a) + assert b.address in a.busy_workers + + # Sever connection between b and s, but not between b and a. + # If a tries fetching from b after this, b will keep responding {status: busy}. + b.periodic_callbacks["heartbeat"].stop() + await s.remove_worker(b.address, close=False, stimulus_id="test") + + await wait_for_state("x", "missing", a) + + assert_story( + a.story("x"), + [ + ("x", "ensure-task-exists", "released"), + ("x", "released", "fetch", "fetch", {}), + ("gather-dependencies", b.address, {"x"}), + ("x", "fetch", "flight", "flight", {}), + ("request-dep", b.address, {"x"}), + ("busy-gather", b.address, {"x"}), + ("x", "flight", "fetch", "fetch", {}), + ("x", "fetch", "missing", "missing", {}), + ], + # There may be a round of find_missing() after this. + # Due to timings, there also may be multiple attempts to connect from a to b. + strict=False, + ) + + +@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446") +@gen_cluster(client=True) +async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2): + """A task is stuck in 'fetch' state because all workers that hold a replica are in + flight. While in this state, a new replica appears on a different worker and the + scheduler informs the waiting worker through a new acquire-replicas or + compute-task op. 
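Because the assertions above lean heavily on `assert_story`, here is a brief hedged sketch of its matching rules; the test name and key are illustrative, not taken from this patch:

@gen_cluster(client=True)
async def test_assert_story_example(c, s, a, b):
    await c.submit(inc, 1, key="x", workers=[a.address])
    assert_story(
        a.story("x"),
        # Expected events must occur in this order in the worker's log; the
        # trailing stimulus_id/timestamp fields of each logged event are not
        # part of the comparison.
        [("x", "compute-task", "released")],
        # strict=False (the default) tolerates extra surrounding events.
        # strict=True, as used in test_lose_replica_during_fetch above, also
        # asserts that no events other than the expected ones were logged.
        strict=False,
    )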
+ + In real life, this will typically happen when the Active Memory Manager replicates a + key to multiple workers and some workers are much faster than others to acquire it, + due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the + request, passing a larger who_has. + + Test that, when this happens, the task is immediately acquired from the new worker, + without waiting for the original replica holders to get out of flight. + """ + # Make sure find_missing is not involved + w1.periodic_callbacks["find-missing"].stop() + + async with BlockedGetData(s.address) as w3: + x = c.submit(inc, 1, key="x", workers=[w3.address]) + y = c.submit(inc, 2, key="y", workers=[w3.address]) + await wait([x, y]) + s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test") + await w3.in_get_data.wait() + assert w1.tasks["x"].state == "flight" + s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test") + # This cannot progress beyond fetch because w3 is already in flight + await wait_for_state("y", "fetch", w1) + + # Simulate that the AMM also requires that w2 acquires a replica of x. + # The replica lands on w2 soon afterwards, while w3->w1 comms remain blocked by + # unrelated transfers (x in our case). + w2.update_data({"y": 3}, report=True) + ws2 = s.workers[w2.address] + while ws2 not in s.tasks["y"].who_has: + await asyncio.sleep(0.01) + + # 2 seconds later, the AMM reiterates that w1 should acquire a replica of y + s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test") + await wait_for_state("y", "memory", w1) + + # Finally let the other worker to get out of flight + w3.block_get_data.set() + await wait_for_state("x", "memory", w1) @gen_cluster(client=True) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index edd74bef17e..a2780f15f1f 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2261,3 +2261,89 @@ def wait_for_log_line( if match in line: return line i += 1 + + +class BlockedGatherDep(Worker): + """A Worker that sets event `in_gather_dep` the first time it enters the gather_dep + method and then does not initiate any comms, thus leaving the task(s) in flight + indefinitely, until the test sets `block_gather_dep` + + Example + ------- + .. 
code-block:: python
+
+        @gen_test()
+        async def test1(s, a, b):
+            async with BlockedGatherDep(s.address) as x:
+                # [do something to cause x to fetch data from a or b]
+                await x.in_gather_dep.wait()
+                # [do something that must happen while the tasks are in flight]
+                x.block_gather_dep.set()
+                # [from this moment on, x is a regular worker]
+
+    See also
+    --------
+    BlockedGetData
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_gather_dep = asyncio.Event()
+        self.block_gather_dep = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def gather_dep(self, *args, **kwargs):
+        self.in_gather_dep.set()
+        await self.block_gather_dep.wait()
+        return await super().gather_dep(*args, **kwargs)
+
+
+class BlockedGetData(Worker):
+    """A Worker that sets event `in_get_data` the first time it enters the get_data
+    method and then does not answer the comms, thus leaving the task(s) in flight
+    indefinitely, until the test sets `block_get_data`
+
+    See also
+    --------
+    BlockedGatherDep
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_get_data = asyncio.Event()
+        self.block_get_data = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def get_data(self, comm, *args, **kwargs):
+        self.in_get_data.set()
+        await self.block_get_data.wait()
+        return await super().get_data(comm, *args, **kwargs)
+
+
+@contextmanager
+def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
+    """Prevent any task from transitioning from fetch to flight on the worker while
+    inside the context, simulating a situation where the worker's network comms are
+    saturated.
+
+    This is not the same as setting the worker to Status=paused, which would also
+    inform the Scheduler and prevent further tasks from being enqueued on the worker.
+
+    Parameters
+    ----------
+    w: Worker
+        The Worker on which tasks will not transition from fetch to flight
+    jump_start: bool
+        If False, tasks will remain in fetch state after exiting the context, until
+        something else triggers ensure_communicating.
+        If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
+        worker moving out of in_flight_workers.
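A short usage sketch of `freeze_data_fetching`, mirroring the reworked `test_fetch_to_compute` above; the keys, the test name, and the final assertion are illustrative assumptions rather than part of this patch (gen_cluster, inc and asyncio are assumed to be imported as in the test modules):

@gen_cluster(client=True)
async def test_freeze_data_fetching_example(c, s, a, b):
    with freeze_data_fetching(b, jump_start=True):
        # f2 is pinned to b and depends on f1, which is computed on a, so b
        # must fetch f1; with comms frozen the task cannot leave the fetch state.
        f1 = c.submit(inc, 1, key="f1", workers=[a.address])
        f2 = c.submit(inc, f1, key="f2", workers=[b.address])
        while "f1" not in b.tasks or b.tasks["f1"].state != "fetch":
            await asyncio.sleep(0.01)
    # jump_start=True triggers ensure_communicating on exit, so the frozen
    # fetch resumes and the dependent compute can complete.
    assert await f2 == 3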
+ """ + old_out_connections = w.total_out_connections + old_comm_threshold = w.comm_threshold_bytes + w.total_out_connections = 0 + w.comm_threshold_bytes = 0 + yield + w.total_out_connections = old_out_connections + w.comm_threshold_bytes = old_comm_threshold + if jump_start: + w.status = Status.paused + w.status = Status.running diff --git a/distributed/worker.py b/distributed/worker.py index c146cb56fff..0ad15cd3fb0 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1890,7 +1890,7 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has) + self._update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) if self.validate: @@ -1994,7 +1994,7 @@ def handle_compute_task( for dep_key, value in nbytes.items(): self.tasks[dep_key].nbytes = value - self.update_who_has(who_has) + self._update_who_has(who_has) else: # pragma: nocover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") @@ -2101,8 +2101,7 @@ def transition_released_waiting( if dep_ts.state != "memory": ts.waiting_for_data.add(dep_ts) dep_ts.waiters.add(ts) - if dep_ts.state not in {"fetch", "flight"}: - recommendations[dep_ts] = "fetch" + recommendations[dep_ts] = "fetch" if ts.waiting_for_data: self.waiting_for_data_count += 1 @@ -2689,7 +2688,7 @@ def _transition( assert not args finish, *args = finish # type: ignore - if ts is None or ts.state == finish: + if ts.state == finish: return {}, [] start = ts.state @@ -2992,8 +2991,11 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: if ts.state != "fetch" or ts.key in all_keys_to_gather: continue + if not ts.who_has: + recommendations[ts] = "missing" + continue + if self.validate: - assert ts.who_has assert self.address not in ts.who_has workers = [ @@ -3146,8 +3148,14 @@ def _select_keys_for_gather( while tasks: ts = tasks.peek() - if ts.state != "fetch" or ts.key in all_keys_to_gather: + if ( + ts.state != "fetch" # Do not acquire the same key twice if multiple workers holds replicas + or ts.key in all_keys_to_gather + # A replica is still available (otherwise status would not be 'fetch' + # anymore), but not on this worker. See _update_who_has(). + or worker not in ts.who_has + ): tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3373,7 +3381,7 @@ def done_event(): who_has = await retry_operation( self.scheduler.who_has, keys=refresh_who_has ) - self.update_who_has(who_has) + self._update_who_has(who_has) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3397,7 +3405,7 @@ async def find_missing(self) -> None: self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], ) - self.update_who_has(who_has) + self._update_who_has(who_has) recommendations: Recs = {} for ts in self._missing_dep_flight: if ts.who_has: @@ -3411,34 +3419,46 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: - try: - for dep, workers in who_has.items(): - if not workers: - continue + def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: + for key, workers in who_has.items(): + ts = self.tasks.get(key) + if not ts: + # The worker sent a refresh-who-has request to the scheduler but, by the + # time the answer comes back, some of the keys have been forgotten. 
+ continue + workers = set(workers) + + if self.address in workers: + workers.remove(self.address) + # This can only happen if rebalance() recently asked to release a key, + # but the RPC call hasn't returned yet. rebalance() is flagged as not + # being safe to run while the cluster is not at rest and has already + # been penned in to be redesigned on top of the AMM. + # It is not necessary to send a message back to the + # scheduler here, because it is guaranteed that there's already a + # release-worker-data message in transit to it. + if ts.state != "memory": + logger.debug( # pragma: nocover + "Scheduler claims worker %s holds data for task %s, " + "which is not true.", + self.address, + ts, + ) - if dep in self.tasks: - dep_ts = self.tasks[dep] - if self.address in workers and self.tasks[dep].state != "memory": - logger.debug( - "Scheduler claims worker %s holds data for task %s which is not true.", - self.name, - dep, - ) - # Do not mutate the input dict. That's rude - workers = set(workers) - {self.address} - dep_ts.who_has.update(workers) + if ts.who_has == workers: + continue - for worker in workers: - self.has_what[worker].add(dep) - self.data_needed_per_worker[worker].push(dep_ts) - except Exception as e: # pragma: no cover - logger.exception(e) - if LOG_PDB: - import pdb + for worker in ts.who_has - workers: + self.has_what[worker].discard(key) + # Can't remove from self.data_needed_per_worker; there is logic + # in _select_keys_for_gather to deal with this - pdb.set_trace() - raise + for worker in workers - ts.who_has: + self.has_what[worker].add(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].push(ts) + + ts.who_has = workers def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task. @@ -4170,8 +4190,8 @@ def validate_task_fetch(self, ts): assert self.address not in ts.who_has assert not ts.done assert ts in self.data_needed - assert ts.who_has - + # Note: ts.who_has may be have been emptied by _update_who_has, but the task + # won't transition to missing until it reaches the top of the data_needed heap. for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed_per_worker[w] @@ -4262,6 +4282,7 @@ def validate_state(self): assert ts.state is not None # check that worker has task for worker in ts.who_has: + assert worker != self.address assert ts.key in self.has_what[worker] # check that deps have a set state and that dependency<->dependent links # are there @@ -4286,6 +4307,7 @@ def validate_state(self): # FIXME https://github.com/dask/distributed/issues/6319 # assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): + assert worker != self.address for k in keys: assert worker in self.tasks[k].who_has
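To make the set bookkeeping in the new `_update_who_has` easier to follow, here is a simplified, self-contained model of the same reconciliation, with plain dicts and sets standing in for Worker state; `TaskModel` and `update_who_has_model` are illustrative names, not part of distributed:

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class TaskModel:
    key: str
    state: str = "fetch"
    who_has: set[str] = field(default_factory=set)


def update_who_has_model(
    tasks: dict[str, TaskModel],
    has_what: defaultdict[str, set[str]],
    data_needed_per_worker: defaultdict[str, list[TaskModel]],  # list stands in for the real heap
    self_address: str,
    who_has: dict[str, list[str]],
) -> None:
    for key, workers_list in who_has.items():
        ts = tasks.get(key)
        if ts is None:
            # The key was forgotten while the who_has RPC was in flight.
            continue
        # Never record ourselves as a remote replica holder.
        workers = set(workers_list) - {self_address}

        if ts.who_has == workers:
            continue

        # Replicas that disappeared: drop the has_what link. Stale entries in
        # data_needed_per_worker are filtered out lazily, mirroring the
        # `worker not in ts.who_has` check added to _select_keys_for_gather.
        for worker in ts.who_has - workers:
            has_what[worker].discard(key)

        # Newly discovered replicas: record them and, while the task is still
        # in fetch state, make it eligible for gathering from the new worker.
        for worker in workers - ts.who_has:
            has_what[worker].add(key)
            if ts.state == "fetch":
                data_needed_per_worker[worker].append(ts)

        ts.who_has = workers


# Minimal walk-through: the refreshed who_has drops w1 and adds w3 for key "x".
tasks = {"x": TaskModel("x", who_has={"tcp://w1", "tcp://w2"})}
has_what = defaultdict(set, {"tcp://w1": {"x"}, "tcp://w2": {"x"}})
data_needed = defaultdict(list)
update_who_has_model(
    tasks, has_what, data_needed, "tcp://self", {"x": ["tcp://w2", "tcp://w3", "tcp://self"]}
)
assert tasks["x"].who_has == {"tcp://w2", "tcp://w3"}
assert has_what["tcp://w1"] == set()
assert data_needed["tcp://w3"] == [tasks["x"]]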