diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 8743d9e65f4..480e5b9148d 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3023,7 +3023,6 @@ def __init__(
             "task-erred": self.handle_task_erred,
             "release-worker-data": self.release_worker_data,
             "add-keys": self.add_keys,
-            "missing-data": self.handle_missing_data,
             "long-running": self.handle_long_running,
             "reschedule": self.reschedule,
             "keep-alive": lambda *args, **kwargs: None,
@@ -4667,40 +4666,6 @@ def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
         self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
         self.send_all(client_msgs, worker_msgs)
 
-    def handle_missing_data(
-        self, key: str, worker: str, errant_worker: str, stimulus_id: str
-    ) -> None:
-        """Signal that `errant_worker` does not hold `key`.
-
-        This may either indicate that `errant_worker` is dead or that we may be working
-        with stale data and need to remove `key` from the workers `has_what`. If no
-        replica of a task is available anymore, the task is transitioned back to
-        released and rescheduled, if possible.
-
-        Parameters
-        ----------
-        key : str
-            Task key that could not be found
-        worker : str
-            Address of the worker informing the scheduler
-        errant_worker : str
-            Address of the worker supposed to hold a replica
-        """
-        logger.debug(f"handle missing data {key=} {worker=} {errant_worker=}")
-        self.log_event(errant_worker, {"action": "missing-data", "key": key})
-
-        ts = self.tasks.get(key)
-        ws = self.workers.get(errant_worker)
-        if not ts or not ws or ws not in ts.who_has:
-            return
-
-        self.remove_replica(ts, ws)
-        if ts.state == "memory" and not ts.who_has:
-            if ts.run_spec:
-                self.transitions({key: "released"}, stimulus_id)
-            else:
-                self.transitions({key: "forgotten"}, stimulus_id)
-
     def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
         ts = self.tasks.get(key)
         ws = self.workers.get(worker)
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 67fac061031..94de10ac950 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -101,19 +101,6 @@ def f(ev):
     )
 
 
-@gen_cluster(client=True)
-async def test_worker_find_missing(c, s, a, b):
-    fut = c.submit(inc, 1, workers=[a.address])
-    await fut
-    # We do not want to use proper API since it would ensure that the cluster is
-    # informed properly
-    del a.data[fut.key]
-    del a.tasks[fut.key]
-
-    # Actually no worker has the data; the scheduler is supposed to reschedule
-    assert await c.submit(inc, fut, workers=[b.address]) == 3
-
-
 @gen_cluster(client=True)
 async def test_worker_stream_died_during_comm(c, s, a, b):
     write_queue = asyncio.Queue()
diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py
index ec81ddaff8b..9c82e1c9a21 100644
--- a/distributed/tests/test_stories.py
+++ b/distributed/tests/test_stories.py
@@ -138,8 +138,7 @@ async def test_worker_story_with_deps(c, s, a, b):
     # Story now includes randomized stimulus_ids and timestamps.
     story = b.story("res")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task", "task-finished"}
-
+    assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
     # This is a simple transition log
     expected = [
         ("res", "compute-task", "released"),
@@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):
 
     story = b.story("dep")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task"}
+    assert stimulus_ids == {"compute-task", "gather-dep-success"}
     expected = [
         ("dep", "ensure-task-exists", "released"),
         ("dep", "released", "fetch", "fetch", {}),
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 1ace95dc605..4e61544214d 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2946,7 +2946,6 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
     coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
     await f2
 
-    assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
     assert a.tasks[f1.key].suspicious_count == 0
     assert s.tasks[f1.key].suspicious == 0
 
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 94408d1ef85..4ba182530e4 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -647,3 +647,39 @@ async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
     assert w3.tasks["x"].state == "missing"
     assert w3.tasks["y"].state == "flight"
     assert w3.tasks["y"].who_has == {w2.address}
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_fetch_to_missing_on_network_failure(c, s, a):
+    """
+    1. Two tasks, x and y, are respectively in flight and fetch state from the same
+       worker, which holds the only replica of both.
+    2. gather_dep for x returns GatherDepNetworkFailureEvent.
+    3. The event empties has_what, x.who_has, and y.who_has; it recommends a transition
+       to missing for both x and y.
+    4. Before the recommendation can be implemented, the same event invokes
+       _ensure_communicating, which pops y from data_needed - but y has an empty
+       who_has, which is an exceptional situation.
+    5. The fetch->missing transition is executed, but y is no longer in data_needed -
+       another exceptional situation.
+ """ + block_get_data = asyncio.Event() + + class BlockedBreakingWorker(Worker): + async def get_data(self, comm, *args, **kwargs): + await block_get_data.wait() + raise OSError("fake error") + + async with BlockedBreakingWorker(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, 2, key="y", workers=[b.address]) + await wait([x, y]) + s.request_acquire_replicas(a.address, ["x"], stimulus_id="test_x") + await wait_for_state("x", "flight", a) + s.request_acquire_replicas(a.address, ["y"], stimulus_id="test_y") + await wait_for_state("y", "fetch", a) + + block_get_data.set() + + await wait_for_state("x", "missing", a) + # await wait_for_state("y", "missing", a) diff --git a/distributed/worker.py b/distributed/worker.py index 880f7415022..4c199193247 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -21,6 +21,7 @@ Collection, Container, Iterable, + Iterator, Mapping, MutableMapping, ) @@ -122,12 +123,15 @@ FindMissingEvent, FreeKeysEvent, GatherDep, + GatherDepBusyEvent, GatherDepDoneEvent, + GatherDepFailureEvent, + GatherDepNetworkFailureEvent, + GatherDepSuccessEvent, Instructions, InvalidTaskState, InvalidTransition, LongRunningMsg, - MissingDataMsg, RecommendationsConflict, Recs, RecsInstrs, @@ -3018,6 +3022,9 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: # transition hasn't been enacted yet, so the task is still in fetch # state and in data_needed. # See matching code in transition_fetch_missing. + # Just to be sure, transition to missing again. If the above + # described assumption is correct, this is a no-op + recommendations[ts] = "missing" continue workers = [ @@ -3289,13 +3296,6 @@ async def gather_dep( if self.status not in WORKER_ANY_RUNNING: return None - recommendations: Recs = {} - instructions: Instructions = [] - response = {} - - def done_event(): - return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}") - try: self.log.append(("request-dep", worker, to_gather, stimulus_id, time())) logger.debug("Request %d keys from %s", len(to_gather), worker) @@ -3306,8 +3306,14 @@ def done_event(): ) stop = time() if response["status"] == "busy": - return done_event() + self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) + return GatherDepBusyEvent( + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-busy-{time()}", + ) + assert response["status"] == "OK" cause = self._get_cause(to_gather) self._update_metrics_received_data( start=start, @@ -3319,93 +3325,160 @@ def done_event(): self.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) - return done_event() + return GatherDepSuccessEvent( + worker=worker, + total_nbytes=total_nbytes, + data=response["data"], + stimulus_id=f"gather-dep-success-{time()}", + ) except OSError: logger.exception("Worker stream died during communication: %s", worker) - has_what = self.has_what.pop(worker) - self.data_needed_per_worker.pop(worker) self.log.append( - ("receive-dep-failed", worker, has_what, stimulus_id, time()) + ("receive-dep-failed", worker, to_gather, stimulus_id, time()) + ) + return GatherDepNetworkFailureEvent( + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-network-failure-{time()}", ) - for d in has_what: - ts = self.tasks[d] - ts.who_has.remove(worker) - if not ts.who_has and ts.state in ( - "fetch", - "flight", - "resumed", - "cancelled", - ): - recommendations[ts] = "missing" - self.log.append( - ("missing-who-has", worker, ts.key, stimulus_id, 
time()) - ) - return done_event() except Exception as e: + # e.g. data failed to deserialize logger.exception(e) if self.batched_stream and LOG_PDB: import pdb pdb.set_trace() - msg = error_message(e) - for k in self.in_flight_workers[worker]: - ts = self.tasks[k] - recommendations[ts] = tuple(msg.values()) - return done_event() - finally: - self.comm_nbytes -= total_nbytes - busy = response.get("status", "") == "busy" - data = response.get("data", {}) + return GatherDepFailureEvent.from_exception( + e, + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-failure-{time()}", + ) - if busy: - self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) - # Avoid hammering the worker. If there are multiple replicas - # available, immediately try fetching from a different worker. - self.busy_workers.add(worker) - instructions.append( - RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id) - ) + def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]: + """Common code for all subclasses of GatherDepDoneEvent. - refresh_who_has = [] - - for d in self.in_flight_workers.pop(worker): - ts = self.tasks[d] - ts.done = True - if d in data: - recommendations[ts] = ("memory", data[d]) - elif busy: - recommendations[ts] = "fetch" - if not ts.who_has - self.busy_workers: - refresh_who_has.append(d) - elif ts not in recommendations: - ts.who_has.discard(worker) - self.has_what[worker].discard(ts.key) - self.data_needed_per_worker[worker].discard(ts) - self.log.append((d, "missing-dep", stimulus_id, time())) - instructions.append( - MissingDataMsg( - key=d, - errant_worker=worker, - stimulus_id=stimulus_id, - ) - ) - recommendations[ts] = "fetch" - - if refresh_who_has: - # All workers that hold known replicas of our tasks are busy. - # Try querying the scheduler for unknown ones. - instructions.append( - RequestRefreshWhoHasMsg( - keys=refresh_who_has, - stimulus_id=f"gather-dep-busy-{time()}", - ) + Yields the tasks that need to transition out of flight. + """ + self.comm_nbytes -= ev.total_nbytes + keys = self.in_flight_workers.pop(ev.worker) + for key in keys: + ts = self.tasks[key] + ts.done = True + yield ts + + @_handle_event.register + def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: + """gather_dep terminated successfully. + The response may contain less keys than the request. + """ + recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): + if ts.key in ev.data: + recommendations[ts] = ("memory", ev.data[ts.key]) + else: + ts.who_has.discard(ev.worker) + self.has_what[ev.worker].discard(ts.key) + recommendations[ts] = "fetch" + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs: + """gather_dep terminated: remote worker is busy""" + # Avoid hammering the worker. If there are multiple replicas + # available, immediately try fetching from a different worker. + self.busy_workers.add(ev.worker) + + recommendations: Recs = {} + refresh_who_has = [] + for ts in self._gather_dep_done_common(ev): + recommendations[ts] = "fetch" + if not ts.who_has - self.busy_workers: + refresh_who_has.append(ts.key) + + instructions: Instructions = [ + RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id), + ] + + if refresh_who_has: + # All workers that hold known replicas of our tasks are busy. 
+            # Try querying the scheduler for unknown ones.
+            instructions.append(
+                RequestRefreshWhoHasMsg(
+                    keys=refresh_who_has, stimulus_id=ev.stimulus_id
                )
+            )
 
-        self.transitions(recommendations, stimulus_id=stimulus_id)
-        self._handle_instructions(instructions)
+        return merge_recs_instructions(
+            (recommendations, instructions),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+
+    @_handle_event.register
+    def _handle_gather_dep_network_failure(
+        self, ev: GatherDepNetworkFailureEvent
+    ) -> RecsInstrs:
+        """gather_dep terminated: network failure while trying to
+        communicate with remote worker
+
+        Though the network failure could be transient, we assume it is not, and
+        preemptively act as though the other worker has died (including removing all
+        keys from it, even ones we did not fetch).
+
+        This optimization leads to faster completion of the fetch, since we immediately
+        either retry a different worker, or ask the scheduler to inform us of a new
+        worker if no other worker is available.
+        """
+        recommendations: Recs = {}
+        instructions: Instructions = []
+        worker = ev.worker
+        for key in self.has_what[worker]:
+            ts = self.tasks[key]
+            ts.who_has.discard(worker)
+
+        del ts
+
+        for ts in self._gather_dep_done_common(ev):
+            ts.who_has.discard(worker)
+            self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
+            recommendations[ts] = "fetch"
+
+        # This cleanup must happen after _refetch_missing_data
+        del self.has_what[ev.worker]
+        del self.data_needed_per_worker[ev.worker]
+        recs, instrs = merge_recs_instructions(
+            (recommendations, instructions),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+        return recs, instrs
+
+    @_handle_event.register
+    def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
+        """gather_dep terminated: generic error raised (not a network failure);
+        e.g. data failed to deserialize.
+ """ + recommendations: Recs = { + ts: ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + for ts in self._gather_dep_done_common(ev) + } + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: await asyncio.sleep(0.15) @@ -3844,11 +3917,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: self._ensure_communicating(stimulus_id=ev.stimulus_id), ) - @_handle_event.register - def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs: - """Temporary hack - to be removed""" - return self._ensure_communicating(stimulus_id=ev.stimulus_id) - @_handle_event.register def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: self.busy_workers.discard(ev.worker) @@ -4183,8 +4251,6 @@ def validate_task_fetch(self, ts): assert ts.key not in self.data assert self.address not in ts.who_has assert not ts.done - assert ts in self.data_needed - assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index de8fcde0cb7..b6fa61f580c 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -389,15 +389,6 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): key: str -@dataclass -class MissingDataMsg(SendMessageToScheduler): - op = "missing-data" - - __slots__ = ("key", "errant_worker") - key: str - errant_worker: str - - # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @dataclass class RescheduleMsg(SendMessageToScheduler): @@ -532,11 +523,88 @@ class RetryBusyWorkerEvent(StateMachineEvent): @dataclass class GatherDepDoneEvent(StateMachineEvent): - """Temporary hack - to be removed""" + """:class:`GatherDep` instruction terminated (abstract base class)""" + + __slots__ = ("worker", "total_nbytes") + worker: str + total_nbytes: int # Must be the same as in GatherDep instruction + + +@dataclass +class GatherDepSuccessEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + remote worker fetched successfully + """ + + __slots__ = ("data",) + + data: dict[str, object] # There may be less keys than in GatherDep + + def to_loggable(self, *, handled: float) -> StateMachineEvent: + out = copy(self) + out.handled = handled + out.data = {k: None for k in self.data} + return out + + def _after_from_dict(self) -> None: + self.data = {k: None for k in self.data} + + +@dataclass +class GatherDepBusyEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + remote worker is busy + """ __slots__ = () +@dataclass +class GatherDepNetworkFailureEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + network failure while trying to communicate with remote worker + """ + + __slots__ = () + + +@dataclass +class GatherDepFailureEvent(GatherDepDoneEvent): + """class:`GatherDep` instruction terminated: + generic error raised (not a network failure); e.g. data failed to deserialize. 
+ """ + + exception: Serialize + traceback: Serialize | None + exception_text: str + traceback_text: str + __slots__ = tuple(__annotations__) # type: ignore + + def _after_from_dict(self) -> None: + self.exception = Serialize(Exception()) + self.traceback = None + + @classmethod + def from_exception( + cls, + err: BaseException, + *, + worker: str, + total_nbytes: int, + stimulus_id: str, + ) -> GatherDepFailureEvent: + msg = error_message(err) + return cls( + worker=worker, + total_nbytes=total_nbytes, + exception=msg["exception"], + traceback=msg["traceback"], + exception_text=msg["exception_text"], + traceback_text=msg["traceback_text"], + stimulus_id=stimulus_id, + ) + + @dataclass class ComputeTaskEvent(StateMachineEvent): key: str