From c39288a35bc198317f07ec4fce08bccdff719cc1 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 13 May 2022 22:52:44 +0100
Subject: [PATCH] Overhaul update_who_has

---
 distributed/tests/test_stories.py |   1 +
 distributed/tests/test_worker.py  |   1 +
 distributed/worker.py             | 108 +++++++++++++++++++++---------
 3 files changed, 79 insertions(+), 31 deletions(-)

diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py
index ec81ddaff8b..232eca6c36c 100644
--- a/distributed/tests/test_stories.py
+++ b/distributed/tests/test_stories.py
@@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b):
     assert stimulus_ids == {"compute-task"}
     expected = [
         ("dep", "ensure-task-exists", "released"),
+        ("dep", "update-who-has", [], [a.address]),
         ("dep", "released", "fetch", "fetch", {}),
         ("gather-dependencies", a.address, {"dep"}),
         ("dep", "fetch", "flight", "flight", {}),
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index bb99438e9f4..5448328cc36 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2814,6 +2814,7 @@ def __getstate__(self):
         story[b],
         [
             ("x", "ensure-task-exists", "released"),
+            ("x", "update-who-has", [], [a]),
             ("x", "released", "fetch", "fetch", {}),
             ("gather-dependencies", a, {"x"}),
             ("x", "fetch", "flight", "flight", {}),
diff --git a/distributed/worker.py b/distributed/worker.py
index 9dc2e8e61ed..272e2dc77f1 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1873,8 +1873,12 @@ def handle_acquire_replicas(
             if ts.state != "memory":
                 recommendations[ts] = "fetch"
 
-        self.update_who_has(who_has)
+        recommendations, instructions = merge_recs_instructions(
+            (recommendations, []),
+            self._update_who_has(who_has, stimulus_id=stimulus_id),
+        )
         self.transitions(recommendations, stimulus_id=stimulus_id)
+        self._handle_instructions(instructions)
 
         if self.validate:
             for key in keys:
@@ -1977,7 +1981,10 @@ def handle_compute_task(
             for dep_key, value in nbytes.items():
                 self.tasks[dep_key].nbytes = value
 
-            self.update_who_has(who_has)
+            recommendations, instructions = merge_recs_instructions(
+                (recommendations, instructions),
+                self._update_who_has(who_has, stimulus_id=stimulus_id),
+            )
         else:  # pragma: nocover
             raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
 
@@ -3136,8 +3143,14 @@ def _select_keys_for_gather(
 
         while tasks:
             ts = tasks.peek()
-            if ts.state != "fetch" or ts.key in all_keys_to_gather:
+            if (
+                ts.state != "fetch"
+                or ts.key in all_keys_to_gather  # don't acquire the same key twice
+                # A replica is still available (otherwise the state would not be
+                # 'fetch' anymore), but not on this worker. See _update_who_has().
+ or worker not in ts.who_has + ): tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3422,7 +3435,11 @@ def done_event(): who_has = await retry_operation( self.scheduler.who_has, keys=refresh_who_has ) - self.update_who_has(who_has) + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=stimulus_id + ) + self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3445,12 +3462,15 @@ async def find_missing(self) -> None: self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], ) - self.update_who_has(who_has) - recommendations: Recs = {} + recommendations, instructions = self._update_who_has( + who_has, stimulus_id=stimulus_id + ) for ts in self._missing_dep_flight: if ts.who_has: + assert ts not in recommendations recommendations[ts] = "fetch" self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) finally: # This is quite arbitrary but the heartbeat has scaling implemented @@ -3458,34 +3478,60 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: - try: - for dep, workers in who_has.items(): - if not workers: - continue + def _update_who_has( + self, who_has: dict[str, Collection[str]], *, stimulus_id: str + ) -> RecsInstrs: + recs: Recs = {} + instructions: Instructions = [] - if dep in self.tasks: - dep_ts = self.tasks[dep] - if self.address in workers and self.tasks[dep].state != "memory": - logger.debug( - "Scheduler claims worker %s holds data for task %s which is not true.", - self.name, - dep, - ) - # Do not mutate the input dict. That's rude - workers = set(workers) - {self.address} - dep_ts.who_has.update(workers) + for key, workers in who_has.items(): + ts = self.tasks.get(key) + if not ts: + continue + workers = set(workers) - for worker in workers: - self.has_what[worker].add(dep) - self.data_needed_per_worker[worker].push(dep_ts) - except Exception as e: # pragma: no cover - logger.exception(e) - if LOG_PDB: - import pdb + if self.address in workers and ts.state != "memory": + logger.debug( + "Scheduler claims worker %s holds data for task %s which is not true.", + self.name, + ts, + ) + workers.remove(self.address) + instructions.append( + MissingDataMsg( + key=key, errant_worker=self.address, stimulus_id=stimulus_id + ) + ) - pdb.set_trace() - raise + if ts.who_has == workers: + continue + + self.log.append( + ( + key, + "update-who-has", + list(ts.who_has), + list(workers), + stimulus_id, + time(), + ) + ) + + for worker in ts.who_has - workers: + self.has_what[worker].discard(key) + # Can't remove from self.data_needed_per_worker; there is logic + # in _select_keys_for_gather to deal with this + + for worker in workers - ts.who_has: + self.has_what[worker].add(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].push(ts) + + ts.who_has = workers + if not workers and ts.state == "fetch": + recs[ts] = "missing" + + return recs, instructions def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task.
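
Illustration (not part of the patch): the heart of the new `_update_who_has` is a reconciliation of the locally known replica locations against a fresh `who_has` answer from the scheduler. The standalone sketch below models only that bookkeeping with plain dicts and sets; the function and parameter names are invented for the example. The real method additionally emits a `MissingDataMsg` instruction when the scheduler wrongly lists this worker, appends an "update-who-has" entry to the worker log, only pushes to `data_needed_per_worker` and recommends "missing" for tasks in the 'fetch' state, and returns `(recs, instructions)` for the caller to feed into `self.transitions(...)` and `self._handle_instructions(...)`.

from __future__ import annotations


def update_replica_locations(
    who_has_by_key: dict[str, set[str]],  # key -> known replica holders (like ts.who_has)
    has_what: dict[str, set[str]],        # worker address -> keys it is believed to hold
    fresh_who_has: dict[str, set[str]],   # latest answer from the scheduler
    self_address: str,
) -> dict[str, str]:
    """Return {key: "missing"} recommendations; mutate the two maps in place."""
    recommendations: dict[str, str] = {}
    for key, workers in fresh_who_has.items():
        if key not in who_has_by_key:
            continue  # not a key this worker is tracking
        workers = set(workers)
        # The scheduler may wrongly claim that this very worker holds a replica;
        # the real method also queues a MissingDataMsg for the scheduler here.
        workers.discard(self_address)

        old = who_has_by_key[key]
        if old == workers:
            continue  # nothing changed; the real method skips logging in this case

        for worker in old - workers:  # replicas that disappeared
            has_what.setdefault(worker, set()).discard(key)
        for worker in workers - old:  # newly discovered replicas
            has_what.setdefault(worker, set()).add(key)

        who_has_by_key[key] = workers
        if not workers:
            # The real method only recommends "missing" for tasks in 'fetch' state
            recommendations[key] = "missing"
    return recommendations


if __name__ == "__main__":
    tracked = {"x": {"tcp://a:1234"}}
    has_what = {"tcp://a:1234": {"x"}}
    recs = update_replica_locations(tracked, has_what, {"x": set()}, "tcp://me:5678")
    assert recs == {"x": "missing"}
    assert has_what["tcp://a:1234"] == set()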