From df1eaba1f565ab016a048c5bb2aed10fdcd51e09 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 10 Jun 2022 17:31:21 +0100
Subject: [PATCH] Refactor gather_dep (#6388)

---
 distributed/tests/test_stories.py | 5 +-
 distributed/tests/test_worker.py | 1 -
 .../tests/test_worker_state_machine.py | 35 +++
 distributed/worker.py | 238 +++++++++++-------
 distributed/worker_state_machine.py | 79 +++++-
 5 files changed, 263 insertions(+), 95 deletions(-)

diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py
index ec81ddaff8b..9c82e1c9a21 100644
--- a/distributed/tests/test_stories.py
+++ b/distributed/tests/test_stories.py
@@ -138,8 +138,7 @@ async def test_worker_story_with_deps(c, s, a, b):
     # Story now includes randomized stimulus_ids and timestamps.
     story = b.story("res")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task", "task-finished"}
-
+    assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
     # This is a simple transition log
     expected = [
         ("res", "compute-task", "released"),
@@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):
 
     story = b.story("dep")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task"}
+    assert stimulus_ids == {"compute-task", "gather-dep-success"}
     expected = [
         ("dep", "ensure-task-exists", "released"),
         ("dep", "released", "fetch", "fetch", {}),
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 8c1978cecf8..32cee6209b0 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2928,7 +2928,6 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
     coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
     await f2
 
-    assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
     assert a.tasks[f1.key].suspicious_count == 0
     assert s.tasks[f1.key].suspicious == 0
 
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 94408d1ef85..4fa937581cd 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -647,3 +647,38 @@ async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
     assert w3.tasks["x"].state == "missing"
     assert w3.tasks["y"].state == "flight"
     assert w3.tasks["y"].who_has == {w2.address}
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_fetch_to_missing_on_network_failure(c, s, a):
+    """
+    1. Two tasks, x and y, are respectively in flight and fetch state from the same
+       worker, which holds the only replica of both.
+    2. gather_dep for x returns GatherDepNetworkFailureEvent
+    3. The event empties has_what, x.who_has, and y.who_has.
+    4. The same event invokes _ensure_communicating, which pops y from data_needed
+       - but y has an empty who_has, which is an exceptional situation.
+       _ensure_communicating recommends a transition to missing for y.
+    5. The fetch->missing transition is executed, but y is no longer in data_needed -
+       another exceptional situation.
+ """ + block_get_data = asyncio.Event() + + class BlockedBreakingWorker(Worker): + async def get_data(self, comm, *args, **kwargs): + await block_get_data.wait() + raise OSError("fake error") + + async with BlockedBreakingWorker(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, 2, key="y", workers=[b.address]) + await wait([x, y]) + s.request_acquire_replicas(a.address, ["x"], stimulus_id="test_x") + await wait_for_state("x", "flight", a) + s.request_acquire_replicas(a.address, ["y"], stimulus_id="test_y") + await wait_for_state("y", "fetch", a) + + block_get_data.set() + + await wait_for_state("x", "missing", a) + await wait_for_state("y", "missing", a) diff --git a/distributed/worker.py b/distributed/worker.py index ca3dff93969..2c590befdc4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -21,6 +21,7 @@ Collection, Container, Iterable, + Iterator, Mapping, MutableMapping, ) @@ -122,7 +123,11 @@ FindMissingEvent, FreeKeysEvent, GatherDep, + GatherDepBusyEvent, GatherDepDoneEvent, + GatherDepFailureEvent, + GatherDepNetworkFailureEvent, + GatherDepSuccessEvent, Instructions, InvalidTaskState, InvalidTransition, @@ -2185,13 +2190,7 @@ def transition_fetch_flight( def transition_fetch_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - # There's a use case where ts won't be found in self.data_needed, so - # `self.data_needed.remove(ts)` would crash: - # 1. An event handler empties who_has and pushes a recommendation to missing - # 2. The same event handler calls _ensure_communicating, which pops the task - # from data_needed - # 3. The recommendation is enacted - # See matching code in _ensure_communicating. + # _ensure_communicating could have just popped this task out of data_needed self.data_needed.discard(ts) return self.transition_generic_missing(ts, stimulus_id=stimulus_id) @@ -3017,11 +3016,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: assert self.address not in ts.who_has if not ts.who_has: - # An event handler just emptied who_has and recommended a fetch->missing - # transition. Then, the same handler called _ensure_communicating. The - # transition hasn't been enacted yet, so the task is still in fetch - # state and in data_needed. - # See matching code in transition_fetch_missing. 
+ recommendations[ts] = "missing" continue workers = [ @@ -3293,13 +3288,6 @@ async def gather_dep( if self.status not in WORKER_ANY_RUNNING: return None - recommendations: Recs = {} - instructions: Instructions = [] - response = {} - - def done_event(): - return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}") - try: self.log.append(("request-dep", worker, to_gather, stimulus_id, time())) logger.debug("Request %d keys from %s", len(to_gather), worker) @@ -3310,8 +3298,14 @@ def done_event(): ) stop = time() if response["status"] == "busy": - return done_event() + self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) + return GatherDepBusyEvent( + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-busy-{time()}", + ) + assert response["status"] == "OK" cause = self._get_cause(to_gather) self._update_metrics_received_data( start=start, @@ -3323,86 +3317,156 @@ def done_event(): self.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) - return done_event() + return GatherDepSuccessEvent( + worker=worker, + total_nbytes=total_nbytes, + data=response["data"], + stimulus_id=f"gather-dep-success-{time()}", + ) except OSError: logger.exception("Worker stream died during communication: %s", worker) - has_what = self.has_what.pop(worker) - self.data_needed_per_worker.pop(worker) self.log.append( - ("receive-dep-failed", worker, has_what, stimulus_id, time()) + ("receive-dep-failed", worker, to_gather, stimulus_id, time()) + ) + return GatherDepNetworkFailureEvent( + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-network-failure-{time()}", ) - for d in has_what: - ts = self.tasks[d] - ts.who_has.remove(worker) - if not ts.who_has and ts.state in ( - "fetch", - "flight", - "resumed", - "cancelled", - ): - recommendations[ts] = "missing" - self.log.append( - ("missing-who-has", worker, ts.key, stimulus_id, time()) - ) - return done_event() except Exception as e: + # e.g. data failed to deserialize logger.exception(e) if self.batched_stream and LOG_PDB: import pdb pdb.set_trace() - msg = error_message(e) - for k in self.in_flight_workers[worker]: - ts = self.tasks[k] - recommendations[ts] = tuple(msg.values()) - return done_event() - finally: - self.comm_nbytes -= total_nbytes - busy = response.get("status", "") == "busy" - data = response.get("data", {}) + return GatherDepFailureEvent.from_exception( + e, + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"gather-dep-failure-{time()}", + ) - if busy: - self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) - # Avoid hammering the worker. If there are multiple replicas - # available, immediately try fetching from a different worker. - self.busy_workers.add(worker) - instructions.append( - RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id) - ) + def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]: + """Common code for all subclasses of GatherDepDoneEvent. 
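# An illustrative aside, not part of the patch: every _handle_gather_dep_* handler below
# drives this generator for the bookkeeping shared by all outcomes (decrement
# comm_nbytes, pop in_flight_workers, mark each task done) and only adds its own
# per-task decision on top, roughly:
#
#     recommendations: Recs = {}
#     for ts in self._gather_dep_done_common(ev):
#         recommendations[ts] = "fetch"  # or ("memory", value), or an error tuple
#     return merge_recs_instructions(
#         (recommendations, []),
#         self._ensure_communicating(stimulus_id=ev.stimulus_id),
#     )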
- refresh_who_has = [] - - for d in self.in_flight_workers.pop(worker): - ts = self.tasks[d] - ts.done = True - if d in data: - recommendations[ts] = ("memory", data[d]) - elif busy: - recommendations[ts] = "fetch" - if not ts.who_has - self.busy_workers: - refresh_who_has.append(d) - elif ts not in recommendations: - ts.who_has.discard(worker) - self.has_what[worker].discard(ts.key) - self.data_needed_per_worker[worker].discard(ts) - self.log.append((d, "missing-dep", stimulus_id, time())) - recommendations[ts] = "fetch" - - if refresh_who_has: - # All workers that hold known replicas of our tasks are busy. - # Try querying the scheduler for unknown ones. - instructions.append( - RequestRefreshWhoHasMsg( - keys=refresh_who_has, - stimulus_id=f"gather-dep-busy-{time()}", - ) + Yields the tasks that need to transition out of flight. + """ + self.comm_nbytes -= ev.total_nbytes + keys = self.in_flight_workers.pop(ev.worker) + for key in keys: + ts = self.tasks[key] + ts.done = True + yield ts + + @_handle_event.register + def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: + """gather_dep terminated successfully. + The response may contain less keys than the request. + """ + recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): + if ts.key in ev.data: + recommendations[ts] = ("memory", ev.data[ts.key]) + else: + self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) + if self.validate: + assert ts.state != "fetch" + assert ts not in self.data_needed_per_worker[ev.worker] + ts.who_has.discard(ev.worker) + self.has_what[ev.worker].discard(ts.key) + recommendations[ts] = "fetch" + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs: + """gather_dep terminated: remote worker is busy""" + # Avoid hammering the worker. If there are multiple replicas + # available, immediately try fetching from a different worker. + self.busy_workers.add(ev.worker) + + recommendations: Recs = {} + refresh_who_has = [] + for ts in self._gather_dep_done_common(ev): + recommendations[ts] = "fetch" + if not ts.who_has - self.busy_workers: + refresh_who_has.append(ts.key) + + instructions: Instructions = [ + RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id), + ] + + if refresh_who_has: + # All workers that hold known replicas of our tasks are busy. + # Try querying the scheduler for unknown ones. + instructions.append( + RequestRefreshWhoHasMsg( + keys=refresh_who_has, stimulus_id=ev.stimulus_id ) + ) - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) + return merge_recs_instructions( + (recommendations, instructions), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_network_failure( + self, ev: GatherDepNetworkFailureEvent + ) -> RecsInstrs: + """gather_dep terminated: network failure while trying to + communicate with remote worker + + Though the network failure could be transient, we assume it is not, and + preemptively act as though the other worker has died (including removing all + keys from it, even ones we did not fetch). + + This optimization leads to faster completion of the fetch, since we immediately + either retry a different worker, or ask the scheduler to inform us of a new + worker if no other worker is available. 
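# An illustrative trace, not part of the patch, matching test_fetch_to_missing_on_network_failure
# earlier in this diff: the handler below recommends "fetch" for the keys that were in
# flight, and the _ensure_communicating call that follows recommends "missing" for any
# task whose who_has is now empty; recovery from "missing" is then driven through the
# scheduler (e.g. via FindMissingEvent).
#
#     x (was in flight from the failed peer)   -> ends up "missing"
#     y (was queued in fetch for that peer)    -> ends up "missing"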
+ """ + self.data_needed_per_worker.pop(ev.worker) + for key in self.has_what.pop(ev.worker): + ts = self.tasks[key] + ts.who_has.discard(ev.worker) + + recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): + self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) + recommendations[ts] = "fetch" + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs: + """gather_dep terminated: generic error raised (not a network failure); + e.g. data failed to deserialize. + """ + recommendations: Recs = { + ts: ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + for ts in self._gather_dep_done_common(ev) + } + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: await asyncio.sleep(0.15) @@ -3841,11 +3905,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: self._ensure_communicating(stimulus_id=ev.stimulus_id), ) - @_handle_event.register - def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs: - """Temporary hack - to be removed""" - return self._ensure_communicating(stimulus_id=ev.stimulus_id) - @_handle_event.register def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: self.busy_workers.discard(ev.worker) @@ -4181,8 +4240,7 @@ def validate_task_fetch(self, ts): assert self.address not in ts.who_has assert not ts.done assert ts in self.data_needed - assert ts.who_has - + # Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed_per_worker[w] diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index b63f7318437..b6fa61f580c 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -523,11 +523,88 @@ class RetryBusyWorkerEvent(StateMachineEvent): @dataclass class GatherDepDoneEvent(StateMachineEvent): - """Temporary hack - to be removed""" + """:class:`GatherDep` instruction terminated (abstract base class)""" + + __slots__ = ("worker", "total_nbytes") + worker: str + total_nbytes: int # Must be the same as in GatherDep instruction + + +@dataclass +class GatherDepSuccessEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + remote worker fetched successfully + """ + + __slots__ = ("data",) + + data: dict[str, object] # There may be less keys than in GatherDep + + def to_loggable(self, *, handled: float) -> StateMachineEvent: + out = copy(self) + out.handled = handled + out.data = {k: None for k in self.data} + return out + + def _after_from_dict(self) -> None: + self.data = {k: None for k in self.data} + + +@dataclass +class GatherDepBusyEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + remote worker is busy + """ __slots__ = () +@dataclass +class GatherDepNetworkFailureEvent(GatherDepDoneEvent): + """:class:`GatherDep` instruction terminated: + network failure while trying to communicate with remote worker + """ + + __slots__ = () + + +@dataclass +class GatherDepFailureEvent(GatherDepDoneEvent): + """class:`GatherDep` instruction terminated: + generic error raised (not a network failure); e.g. data failed to deserialize. 
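# An illustrative note, not part of the patch: to_loggable() and _after_from_dict() let
# these events be stored in the stimulus log and round-tripped through dict
# serialization (e.g. in cluster dumps) without retaining the fetched payload, roughly:
#
#     ev = GatherDepSuccessEvent(
#         worker="tcp://127.0.0.1:1234", total_nbytes=8, data={"x": 123}, stimulus_id="s"
#     )
#     ev.to_loggable(handled=time()).data == {"x": None}  # payload stripped for logging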
+ """ + + exception: Serialize + traceback: Serialize | None + exception_text: str + traceback_text: str + __slots__ = tuple(__annotations__) # type: ignore + + def _after_from_dict(self) -> None: + self.exception = Serialize(Exception()) + self.traceback = None + + @classmethod + def from_exception( + cls, + err: BaseException, + *, + worker: str, + total_nbytes: int, + stimulus_id: str, + ) -> GatherDepFailureEvent: + msg = error_message(err) + return cls( + worker=worker, + total_nbytes=total_nbytes, + exception=msg["exception"], + traceback=msg["traceback"], + exception_text=msg["exception_text"], + traceback_text=msg["traceback_text"], + stimulus_id=stimulus_id, + ) + + @dataclass class ComputeTaskEvent(StateMachineEvent): key: str
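# A self-contained sketch of the pattern this patch applies to gather_dep: the I/O
# coroutine only returns a typed "done" event, and all state mutation happens in a
# handler registered per event type, which hands recommendations back to the state
# machine. Everything below is stdlib-only and every name (ToyGatherDone,
# ToyWorkerState, ...) is invented for illustration; it is not distributed's real API.
from __future__ import annotations

from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class ToyGatherDone:
    """Base event: a gather attempt from `worker` finished, one way or another."""

    worker: str
    stimulus_id: str


@dataclass
class ToyGatherSuccess(ToyGatherDone):
    data: dict[str, object]  # may hold fewer keys than were requested


@dataclass
class ToyGatherFailure(ToyGatherDone):
    exception_text: str

    @classmethod
    def from_exception(
        cls, err: BaseException, *, worker: str, stimulus_id: str
    ) -> ToyGatherFailure:
        # Same idea as GatherDepFailureEvent.from_exception: capture the error once,
        # at event-creation time, so the handler stays pure bookkeeping.
        return cls(worker=worker, stimulus_id=stimulus_id, exception_text=repr(err))


class ToyWorkerState:
    def __init__(self) -> None:
        self.in_flight: dict[str, set[str]] = {}  # worker address -> keys being fetched
        self.memory: dict[str, object] = {}
        self.log: list[tuple] = []

    def handle_stimulus(self, ev: ToyGatherDone) -> dict[str, str]:
        """Single entry point: dispatch on the event type, record it, return recs."""
        recs = self._handle_event(ev)
        self.log.append((type(ev).__name__, ev.stimulus_id, recs))
        return recs

    def _gather_done_common(self, ev: ToyGatherDone):
        # Analogous to _gather_dep_done_common: bookkeeping shared by every outcome.
        for key in self.in_flight.pop(ev.worker, set()):
            yield key

    @singledispatchmethod
    def _handle_event(self, ev: ToyGatherDone) -> dict[str, str]:
        raise TypeError(f"no handler registered for {ev!r}")

    @_handle_event.register
    def _handle_success(self, ev: ToyGatherSuccess) -> dict[str, str]:
        recs = {}
        for key in self._gather_done_common(ev):
            if key in ev.data:
                self.memory[key] = ev.data[key]
                recs[key] = "memory"
            else:
                recs[key] = "fetch"  # missing from the response: try another holder
        return recs

    @_handle_event.register
    def _handle_failure(self, ev: ToyGatherFailure) -> dict[str, str]:
        return {key: "error" for key in self._gather_done_common(ev)}


state = ToyWorkerState()
state.in_flight["tcp://10.0.0.1:1234"] = {"x", "y"}
recs = state.handle_stimulus(
    ToyGatherSuccess(worker="tcp://10.0.0.1:1234", stimulus_id="s1", data={"x": 1})
)
assert recs == {"x": "memory", "y": "fetch"}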