diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 98a4174775d..5342d693e91 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1850,7 +1850,12 @@ def _transition(
 
                 start = "released"
             else:
-                raise RuntimeError(f"Impossible transition from {start} to {finish}")
+                # FIXME downcast antipattern
+                scheduler = cast(Scheduler, self)
+                raise RuntimeError(
+                    f"Impossible transition from {start} to {finish} for {key!r}: "
+                    f"{stimulus_id=}, {args=}, {kwargs=}, story={scheduler.story(ts)}"
+                )
 
             if not stimulus_id:
                 stimulus_id = STIMULUS_ID_UNSET
@@ -2023,50 +2028,6 @@ def transition_no_worker_processing(self, key, stimulus_id):
                 pdb.set_trace()
             raise
 
-    def transition_no_worker_memory(
-        self,
-        key: str,
-        stimulus_id: str,
-        *,
-        nbytes: int | None = None,
-        type: bytes | None = None,
-        typename: str | None = None,
-        worker: str,
-        **kwargs: Any,
-    ):
-        try:
-            ws = self.workers[worker]
-            ts = self.tasks[key]
-            recommendations: dict = {}
-            client_msgs: dict = {}
-            worker_msgs: dict = {}
-
-            if self.validate:
-                assert not ts.processing_on
-                assert not ts.waiting_on
-                assert ts.state == "no-worker"
-
-            self.unrunnable.remove(ts)
-
-            if nbytes is not None:
-                ts.set_nbytes(nbytes)
-
-            self.check_idle_saturated(ws)
-
-            _add_to_memory(
-                self, ts, ws, recommendations, client_msgs, type=type, typename=typename
-            )
-            ts.state = "memory"
-
-            return recommendations, client_msgs, worker_msgs
-        except Exception as e:
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
-
-                pdb.set_trace()
-            raise
-
     def decide_worker_rootish_queuing_disabled(
         self, ts: TaskState
     ) -> WorkerState | None:
@@ -2292,35 +2253,23 @@ def transition_waiting_memory(
         worker: str,
         **kwargs: Any,
     ):
+        """This transition exclusively happens in a race condition where the scheduler
+        believes that the only copy of a dependency task has just been lost, so it
+        transitions all dependents back to waiting, but actually a replica has already
+        been acquired by a worker computing the dependency - the scheduler just doesn't
+        know yet - and the execution finishes before the cancellation message from the
+        scheduler has a chance to reach the worker. Shortly, the cancellation request
+        will reach the worker, thus deleting the data from memory.
+        """
         try:
-            ws: WorkerState = self.workers[worker]
-            ts: TaskState = self.tasks[key]
-            recommendations: dict = {}
-            client_msgs: dict = {}
-            worker_msgs: dict = {}
+            ts = self.tasks[key]
 
             if self.validate:
                 assert not ts.processing_on
                 assert ts.waiting_on
                 assert ts.state == "waiting"
 
-            ts.waiting_on.clear()
-
-            if nbytes is not None:
-                ts.set_nbytes(nbytes)
-
-            self.check_idle_saturated(ws)
-
-            _add_to_memory(
-                self, ts, ws, recommendations, client_msgs, type=type, typename=typename
-            )
-
-            if self.validate:
-                assert not ts.processing_on
-                assert not ts.waiting_on
-                assert ts.who_has
-
-            return recommendations, client_msgs, worker_msgs
+            return {}, {}, {}
         except Exception as e:
             logger.exception(e)
             if LOG_PDB:
@@ -2365,21 +2314,15 @@ def transition_processing_memory(
             if ws is None:
                 return {key: "released"}, {}, {}
 
-            if ws != ts.processing_on:  # someone else has this task
-                logger.info(
-                    "Unexpected worker completed task. Expected: %s, Got: %s, Key: %s",
-                    ts.processing_on,
-                    ws,
-                    key,
-                )
+            if ws != ts.processing_on:  # pragma: nocover
                 assert ts.processing_on
-                worker_msgs[ts.processing_on.address] = [
-                    {
-                        "op": "cancel-compute",
-                        "key": key,
-                        "stimulus_id": stimulus_id,
-                    }
-                ]
+                # FIXME downcast antipattern
+                scheduler = cast(Scheduler, self)
+                raise RuntimeError(
+                    f"Task {ts.key!r} transitioned from processing to memory on worker "
+                    f"{ws}, while it was expected from {ts.processing_on}. This should "
+                    f"be impossible. {stimulus_id=}, story={scheduler.story(ts)}"
+                )
 
             #############################
             # Update Timing Information #
@@ -2650,7 +2593,7 @@ def transition_processing_released(self, key: str, stimulus_id: str):
                     }
                 ]
 
-            _propagage_released(self, ts, recommendations)
+            _propagate_released(self, ts, recommendations)
             return recommendations, {}, worker_msgs
         except Exception as e:
             logger.exception(e)
@@ -2874,7 +2817,7 @@ def transition_queued_released(self, key, stimulus_id):
 
             self.queued.remove(ts)
 
-            _propagage_released(self, ts, recommendations)
+            _propagate_released(self, ts, recommendations)
             return recommendations, client_msgs, worker_msgs
         except Exception as e:
             logger.exception(e)
@@ -3027,7 +2970,6 @@ def transition_released_forgotten(self, key, stimulus_id):
         ("processing", "erred"): transition_processing_erred,
         ("no-worker", "released"): transition_no_worker_released,
         ("no-worker", "processing"): transition_no_worker_processing,
-        ("no-worker", "memory"): transition_no_worker_memory,
         ("released", "forgotten"): transition_released_forgotten,
         ("memory", "forgotten"): transition_memory_forgotten,
         ("erred", "released"): transition_erred_released,
@@ -7965,7 +7907,7 @@ def _add_to_memory(
         )
 
 
-def _propagage_released(
+def _propagate_released(
     state: SchedulerState,
     ts: TaskState,
     recommendations: Recs,
@@ -8319,10 +8261,9 @@ def heartbeat_interval(n: int) -> float:
 
 
 def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
-    "Number of tasks that can be sent to this worker without oversaturating it"
+    """Number of tasks that can be sent to this worker without oversaturating it"""
     assert not math.isinf(saturation_factor)
-    nthreads = ws.nthreads
-    return max(math.ceil(saturation_factor * nthreads), 1) - (
+    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
         len(ws.processing) - len(ws.long_running)
     )
 
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index db231c317dd..ce66ada35eb 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -45,11 +45,13 @@
     NO_AMM,
     BlockedGatherDep,
     BrokenComm,
+    assert_story,
     async_wait_for,
     captured_logger,
     cluster,
     dec,
     div,
+    freeze_batched_send,
     freeze_data_fetching,
     gen_cluster,
     gen_test,
@@ -4099,3 +4101,43 @@ async def test_count_task_prefix(c, s, a, b):
 
     assert s.task_prefixes["inc"].state_counts["memory"] == 20
     assert s.task_prefixes["inc"].state_counts["erred"] == 0
+
+
+@gen_cluster(client=True)
+async def test_transition_waiting_memory(c, s, a, b):
+    """Test race condition where a task transitions to memory while its state on the
+    scheduler is waiting:
+
+    1. worker a finishes x
+    2. y transitions to processing and is assigned to worker b
+    3. b fetches x and sends an add_keys message to the scheduler
+    4. In the meantime, a dies and causes x to be scheduled back to released/waiting.
+    5. Scheduler queues up a free-keys intended for b to cancel both x and y
+    6. Before free-keys arrives to b, the worker runs and completes y, sending a
+       finished-task message to the scheduler
+    7. {op: add-keys, keys=[x]} from b finally arrives to the scheduler. This triggers
+       a {op: remove-replicas, keys=[x]} message from the scheduler to worker b, because
+       add-keys when the task state is not memory triggers a cleanup of redundant
+       replicas (see Scheduler.add_keys) - in this, add-keys differs from task-finished!
+    8. {op: task-finished, key=y} from b arrives to the scheduler and it is ignored.
+    """
+    x = c.submit(inc, 1, key="x", workers=[a.address])
+    y = c.submit(inc, x, key="y", workers=[b.address])
+    await wait_for_state("x", "memory", b, interval=0)
+    # Note interval=0 above. It means that x has just landed on b this instant and the
+    # scheduler doesn't know yet.
+    assert b.state.tasks["y"].state == "executing"
+    assert s.tasks["x"].who_has == {s.workers[a.address]}
+
+    with freeze_batched_send(b.batched_stream):
+        with freeze_batched_send(s.stream_comms[b.address]):
+            await s.remove_worker(a.address, stimulus_id="remove_a")
+            assert s.tasks["x"].state == "no-worker"
+            assert s.tasks["y"].state == "waiting"
+            await wait_for_state("y", "memory", b)
+
+        await async_wait_for(lambda: not b.state.tasks, timeout=5)
+
+    assert s.tasks["x"].state == "no-worker"
+    assert s.tasks["y"].state == "waiting"
+    assert_story(s.story("y"), [("y", "waiting", "waiting", {})])