From d3806043e7332963ea06ee8fc5407e2d5c135e31 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Thu, 27 Jan 2022 15:14:39 +0100
Subject: [PATCH] Ensure missing transitions are safe

---
 distributed/tests/test_worker.py | 263 +++++++++++++++++++++++++------
 distributed/worker.py            | 194 ++++++++++++++---------
 2 files changed, 334 insertions(+), 123 deletions(-)

diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index b6c4961ff6d..55aa5d8f9db 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -57,7 +57,14 @@
     slowinc,
     slowsum,
 )
-from distributed.worker import Worker, error_message, logger, parse_memory_limit
+from distributed.worker import (
+    TaskState,
+    Worker,
+    _UniqueTaskHeap,
+    error_message,
+    logger,
+    parse_memory_limit,
+)

 pytestmark = pytest.mark.ci1

@@ -2186,12 +2196,26 @@ async def test_gpu_executor(c, s, w):
         assert "gpu" not in w.executors


-def assert_task_states_on_worker(expected, worker):
-    for dep_key, expected_state in expected.items():
-        assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks)
-        dep_ts = worker.tasks[dep_key]
-        assert dep_ts.state == expected_state, (worker.name, dep_ts, expected_state)
-    assert set(expected) == set(worker.tasks)
+async def assert_task_states_on_worker(expected, worker):
+    active_exc = None
+    for _ in range(10):
+        try:
+            for dep_key, expected_state in expected.items():
+                assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks)
+                dep_ts = worker.tasks[dep_key]
+                assert dep_ts.state == expected_state, (
+                    worker.name,
+                    dep_ts,
+                    expected_state,
+                )
+            assert set(expected) == set(worker.tasks)
+            return
+        except AssertionError as exc:
+            active_exc = exc
+            await asyncio.sleep(0.1)
+    # If after a second the workers are not in equilibrium, they are broken
+    assert active_exc
+    raise active_exc


 @gen_cluster(client=True)
@@ -2235,7 +2259,7 @@ def raise_exc(*args):
         g.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
     # Expected states after we release references to the futures
     f.release()
     g.release()
@@ -2251,7 +2275,7 @@ def raise_exc(*args):
         res.key: "error",
     }

-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)

     res.release()

@@ -2304,7 +2328,7 @@ def raise_exc(*args):
         g.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
     # Expected states after we release references to the futures
     res.release()

@@ -2318,7 +2342,7 @@ def raise_exc(*args):
         g.key: "memory",
     }

-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)

     f.release()
     g.release()
@@ -2369,7 +2393,7 @@ def raise_exc(*args):
         g.key: "memory",
res.key: "error", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) # Expected states after we release references to the futures f.release() @@ -2383,8 +2407,8 @@ def raise_exc(*args): g.key: "memory", } - assert_task_states_on_worker(expected_states, a) - assert_task_states_on_worker(expected_states, b) + await assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, b) g.release() @@ -2418,8 +2442,7 @@ def raise_exc(*args): g.key: "memory", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "memory", @@ -2427,8 +2450,7 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) f.release() @@ -2436,8 +2458,7 @@ def raise_exc(*args): g.key: "memory", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "released", @@ -2445,8 +2466,7 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) g.release() @@ -2454,8 +2474,7 @@ def raise_exc(*args): g.key: "released", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) # B must not forget a task since all have a still valid dependent expected_states_B = { @@ -2463,19 +2482,18 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) h.release() - await asyncio.sleep(0.05) expected_states_A = {} - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "released", h.key: "released", res.key: "error", } - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) res.release() # We no longer hold any refs. Cluster should reset completely @@ -3130,33 +3148,38 @@ async def test_missing_released_zombie_tasks(c, s, a, b): @gen_cluster(client=True) async def test_missing_released_zombie_tasks_2(c, s, a, b): - a.total_in_connections = 0 - f1 = c.submit(inc, 1, key="f1", workers=[a.address]) - f2 = c.submit(inc, f1, key="f2", workers=[b.address]) + # If get_data_from_worker raises this will suggest a dead worker to B and it + # will transition the task to missing. We want to make sure that a missing + # task is properly released and not left as a zombie + with mock.patch.object( + distributed.worker, + "get_data_from_worker", + side_effect=CommClosedError, + ): + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[b.address]) - while f1.key not in b.tasks: - await asyncio.sleep(0) + while f1.key not in b.tasks: + await asyncio.sleep(0) - ts = b.tasks[f1.key] - assert ts.state == "fetch" + ts = b.tasks[f1.key] + assert ts.state == "fetch" - # A few things can happen to clear who_has. The dominant process is upon - # connection failure to a worker. 
Regardless of how the set was cleared, the
-    # task will be transitioned to missing where the worker is trying to
-    # reaquire this information from the scheduler. While this is happening on
-    # worker side, the tasks are released and we want to ensure that no dangling
-    # zombie tasks are left on the worker
-    ts.who_has.clear()
+        while not ts.state == "missing":
+            # If we sleep for a longer time, the worker will spin into an
+            # endless loop of asking the scheduler who_has and trying to connect
+            # to A
+            await asyncio.sleep(0)

-    del f1, f2
+        del f1, f2

-    while b.tasks:
-        await asyncio.sleep(0.01)
+        while b.tasks:
+            await asyncio.sleep(0.01)

-    assert_worker_story(
-        b.story(ts),
-        [("f1", "missing", "released", "released", {"f1": "forgotten"})],
-    )
+        assert_worker_story(
+            b.story(ts),
+            [("f1", "missing", "released", "released", {"f1": "forgotten"})],
+        )


 @pytest.mark.slow
@@ -3441,6 +3464,8 @@ async def test_Worker__to_dict(c, s, a):
         "config",
         "incoming_transfer_log",
         "outgoing_transfer_log",
+        "data_needed",
+        "pending_data_per_worker",
     }
     assert d["tasks"]["x"]["key"] == "x"

@@ -3462,3 +3487,139 @@ async def test_TaskState__to_dict(c, s, a):
     assert isinstance(tasks["z"], dict)
     assert tasks["x"]["dependents"] == ["<TaskState 'y' memory>"]
     assert tasks["y"]["dependencies"] == ["<TaskState 'x' memory>"]
+
+
+@gen_cluster(client=True)
+async def test_dups_in_pending_data_per_worker(c, s, a, b):
+    # There has been a race condition leading to a deadlock (caught by an
+    # AssertionError if validate is enabled) that was caused by not identifying
+    # a missing key properly
+
+    # We need to fetch a key which is repeatedly selected by the batched
+    # Worker.select_keys_for_gather optimization, since if it goes through the
+    # ordinary channels of ensure_communicating it is flagged immediately as missing
+
+    # This is a batch of futures we will fetch from A. We will use these as
+    # seeds, i.e. primary keys to fetch for the batched fetch optimization
+    futs = c.map(inc, range(100), workers=[a.address])
+    # This will be the culprit/missing key we selectively insert into the fetch
+    # queue. We will manipulate the state machine such that this would raise the
+    # AssertionError
+    missing_fut = c.submit(inc, -1, workers=[a.address], key="culprit")
+
+    # Ensure the data is available and the scheduler is aware
+    await c.gather(futs)
+    await missing_fut
+
+    # MOCKs:
+    # We will mock ensure_communicating and ensure_computing to disable the
+    # every_cycle callback of our handle_scheduler.
+    # Effectively this allows us to intercept the moment directly after a
+    # handle_compute_task handler was executed
+    with mock.patch.object(
+        Worker, "ensure_communicating", return_value=None
+    ) as comm_mock:
+        with mock.patch.object(
+            Worker, "ensure_computing", return_value=None
+        ) as comp_mock:
+            # This new worker will be the one where the exception is provoked
+            x = await Worker(s.address, name=2, validate=True)
+            # Fill up the data_needed heap with tasks that are fine to be scheduled
+
+            f1 = c.submit(sum, [*futs[:20]], workers=[x.address], key="f1", priority=10)
+            # Put the bad one in between
+            f2 = c.submit(inc, missing_fut, workers=[x.address], key="f2", priority=20)
+            # Ensure the heap is full with all the tasks such that we have many
+            # batched fetches
+            f3 = c.submit(
+                sum, [*futs[20:40]], workers=[x.address], key="f3", priority=30
+            )
+
+            # Wait for all the tasks to be registered. Without the mocks we
+            # could not cleanly assert this but this is an important test
+            # assumption
+            while len(x.data_needed) != 41:
+                await asyncio.sleep(0.01)
+            assert missing_fut.key in x.tasks
+            ts = x.tasks[missing_fut.key]
+            assert ts in x.pending_data_per_worker[a.address]
+            # The culprit is not at the top of the heap
+            top = x.pending_data_per_worker[a.address].peak()
+            assert top.key != missing_fut.key
+
+            # This has been introducing duplicates to pending_data_per_worker
+            for _ in range(3):
+                await x.query_who_has(missing_fut.key)
+
+    # We will now remove culprit from A such that X will handle the missing
+    # response.
+    # To not have the scheduler reschedule this task immediately, we will create
+    # another replica on another worker. We don't want X to be made aware of
+    # this, which is why we're disabling / mocking query_who_has to ensure
+    # that there is no background update. In a more realistic environment, this
+    # can be caused by certain delays in communication, particularly if AMM is
+    # running.
+    # We could also use AMM to create a replica if the API supports this.
+    with mock.patch.object(Worker, "query_who_has", return_value=None) as who_has_mock:
+        f_copy = c.submit(
+            inc, missing_fut, key="copy-intention-culprit", workers=[b.address]
+        )
+        await f_copy
+
+        # Now remove the replica from A *before* X requests the data
+        a.handle_remove_replicas([missing_fut.key], stimulus_id="test")
+
+        # We want to ensure that the first batch includes all data for f1 plus the culprit
+        x.target_message_size = sum(x.tasks[f.key].get_nbytes() for f in futs[:22])
+        x.ensure_communicating()
+
+        with mock.patch.object(
+            Worker, "ensure_communicating", return_value=None
+        ) as comm_mock:
+            with mock.patch.object(
+                Worker, "ensure_computing", return_value=None
+            ) as comp_mock:
+                while not x.data:
+                    await asyncio.sleep(0.01)
+
+        x.target_message_size = 1000000000
+        x.ensure_communicating()
+
+    await f1
+    await f2
+    await f3
+
+
+def test_unique_task_heap():
+    heap = _UniqueTaskHeap()
+
+    for x in range(10):
+        ts = TaskState(f"f{x}")
+        ts.priority = (0, 0, 1, x % 3)
+        heap.push(ts)
+        del ts
+
+    heap_list = list(heap)
+    # iteration does not empty heap
+    assert heap
+    assert heap_list == sorted(heap_list, key=lambda ts: ts.priority)
+
+    seen = set()
+    last_prio = (0, 0, 0, 0)
+    while heap:
+        peaked = heap.peak()
+        ts = heap.pop()
+        assert peaked == ts
+        seen.add(ts.key)
+        assert ts.priority
+        assert last_prio <= ts.priority
+        last_prio = ts.priority
+
+    ts = TaskState("foo")
+    heap.push(ts)
+    heap.push(ts)
+    assert len(heap) == 1
+    assert heap.pop() == ts
+    assert not heap
+
+    assert isinstance(repr(heap), str)
diff --git a/distributed/worker.py b/distributed/worker.py
index 72d3c93f5e3..b5424716d6d 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -13,7 +13,14 @@
 import warnings
 import weakref
 from collections import defaultdict, deque, namedtuple
-from collections.abc import Callable, Collection, Iterable, Mapping, MutableMapping
+from collections.abc import (
+    Callable,
+    Collection,
+    Iterable,
+    Iterator,
+    Mapping,
+    MutableMapping,
+)
 from concurrent.futures import Executor
 from contextlib import suppress
 from datetime import timedelta
@@ -200,7 +207,7 @@ def __init__(self, key, runspec=None):
         self.dependencies = set()
         self.dependents = set()
         self.duration = None
-        self.priority = None
+        self.priority: tuple[int, ...]
| None = None self.state = "released" self.who_has = set() self.coming_from = None @@ -267,6 +274,56 @@ def is_protected(self) -> bool: ) +class _UniqueTaskHeap: + def __init__(self, collection: Collection[TaskState] | None = None) -> None: + """A heap of TaskState objects ordered by TaskState.priority + Ties are broken by string comparison of the key. + Keys are guaranteed to be unique. + Iterating over this object returns the elements in priority order. + """ + if collection is None: + collection = [] + self._known = {ts.key for ts in collection} + self._heap = [(ts.priority, ts.key, ts) for ts in collection] + heapq.heapify(self._heap) + + def push(self, ts: TaskState) -> None: + """Add a new TaskState instance to the heap. If the key is already + known, no object is added. + + Note: This does not update the priority / heap order in case priority + changes. + """ + assert isinstance(ts, TaskState) + if ts.key not in self._known: + heapq.heappush(self._heap, (ts.priority, ts.key, ts)) + self._known.add(ts.key) + + def pop(self) -> TaskState: + """Pop the task with highest priority from the heap.""" + _, _, ts = heapq.heappop(self._heap) + self._known.remove(ts.key) + return ts + + def peak(self) -> TaskState: + """Get the highest priority TaskState without removing it from the heap""" + return self._heap[0][2] + + def __contains__(self, x: object) -> bool: + if isinstance(x, TaskState): + x = x.key + return x in self._known + + def __iter__(self) -> Iterator[TaskState]: + return iter([ts for _, _, ts in sorted(self._heap)]) + + def __len__(self) -> int: + return len(self._known) + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {len(self)} items>" + + class Worker(ServerNode): """Worker node in a Dask distributed cluster @@ -342,7 +399,7 @@ class Worker(ServerNode): * **data.disk:** ``{key: object}``: Dictionary mapping keys to actual values stored on disk. Only available if condition for **data** being a zict.Buffer is met. - * **data_needed**: deque(keys) + * **data_needed**: heap(TaskState) The keys which still require data in order to execute, arranged in a deque * **ready**: [keys] Keys that are ready to run. Stored in a LIFO stack @@ -358,8 +415,8 @@ class Worker(ServerNode): long-running clients. 
* **has_what**: ``{worker: {deps}}`` The data that we care about that we think a worker has - * **pending_data_per_worker**: ``{worker: [dep]}`` - The data on each worker that we still want, prioritized as a deque + * **pending_data_per_worker**: ``{worker: heap(TaskState)}`` + The data on each worker that we still want, prioritized as a heap * **in_flight_tasks**: ``int`` A count of the number of tasks that are coming to us in current peer-to-peer connections @@ -457,10 +514,10 @@ class Worker(ServerNode): tasks: dict[str, TaskState] waiting_for_data_count: int has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} - pending_data_per_worker: defaultdict[str, deque[str]] + pending_data_per_worker: defaultdict[str, _UniqueTaskHeap] nanny: Nanny | None _lock: threading.Lock - data_needed: list[tuple[int, str]] # heap[(ts.priority, ts.key)] + data_needed: _UniqueTaskHeap in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} total_out_connections: int total_in_connections: int @@ -609,11 +666,11 @@ def __init__( self.tasks = {} self.waiting_for_data_count = 0 self.has_what = defaultdict(set) - self.pending_data_per_worker = defaultdict(deque) + self.pending_data_per_worker = defaultdict(_UniqueTaskHeap) self.nanny = nanny self._lock = threading.Lock() - self.data_needed = [] + self.data_needed = _UniqueTaskHeap() self.in_flight_workers = {} self.total_out_connections = dask.config.get( @@ -624,7 +681,7 @@ def __init__( ) self.comm_threshold_bytes = int(10e6) self.comm_nbytes = 0 - self._missing_dep_flight = set() + self._missing_dep_flight: set[TaskState] = set() self.threads = {} @@ -673,11 +730,11 @@ def __init__( ("executing", "released"): self.transition_executing_released, ("executing", "rescheduled"): self.transition_executing_rescheduled, ("fetch", "flight"): self.transition_fetch_flight, - ("fetch", "missing"): self.transition_fetch_missing, ("fetch", "released"): self.transition_generic_released, ("flight", "error"): self.transition_flight_error, ("flight", "fetch"): self.transition_flight_fetch, ("flight", "memory"): self.transition_flight_memory, + ("flight", "missing"): self.transition_flight_missing, ("flight", "released"): self.transition_flight_released, ("long-running", "error"): self.transition_generic_error, ("long-running", "memory"): self.transition_long_running_memory, @@ -1156,6 +1213,10 @@ def _to_dict( "status": self.status, "ready": self.ready, "constrained": self.constrained, + "data_needed": list(self.data_needed), + "pending_data_per_worker": { + w: list(v) for w, v in self.pending_data_per_worker.items() + }, "long_running": self.long_running, "executing_count": self.executing_count, "in_flight_tasks": self.in_flight_tasks, @@ -1913,7 +1974,6 @@ def handle_cancel_compute(self, key, reason): ts = self.tasks.get(key) if ts and ts.state in READY | {"waiting"}: self.log.append((key, "cancel-compute", reason, time())) - ts.scheduler_holds_ref = False # All possible dependents of TS should not be in state Processing on # scheduler side and therefore should not be assigned to a worker, # yet. 
@@ -1942,7 +2002,7 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) def ensure_task_exists( @@ -2038,11 +2098,10 @@ def handle_compute_task( for msg in scheduler_msgs: self.batched_stream.send(msg) + + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) - # We received new info, that's great but not related to the compute-task - # instruction - self.update_who_has(who_has, stimulus_id=stimulus_id) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value @@ -2055,7 +2114,7 @@ def transition_missing_fetch(self, ts, *, stimulus_id): self._missing_dep_flight.discard(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_missing_released(self, ts, *, stimulus_id): @@ -2066,10 +2125,11 @@ def transition_missing_released(self, ts, *, stimulus_id): assert ts.key in self.tasks return recommendations, smsgs - def transition_fetch_missing(self, ts, *, stimulus_id): - # handle_missing will append to self.data_needed if new workers are found + def transition_flight_missing(self, ts, *, stimulus_id): + assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) + ts.done = False return {}, [] def transition_released_fetch(self, ts, *, stimulus_id): @@ -2077,10 +2137,10 @@ def transition_released_fetch(self, ts, *, stimulus_id): assert ts.state == "released" assert ts.priority is not None for w in ts.who_has: - self.pending_data_per_worker[w].append(ts.key) + self.pending_data_per_worker[w].push(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_generic_released(self, ts, *, stimulus_id): @@ -2126,7 +2186,6 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id): if self.validate: assert ts.state == "fetch" assert ts.who_has - assert ts.key not in self.data_needed ts.done = False ts.state = "flight" @@ -2425,11 +2484,17 @@ def transition_flight_fetch(self, ts, *, stimulus_id): # we can reset the task and transition to fetch again. 
If it is not yet # finished, this should be a no-op if ts.done: - recommendations, smsgs = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) - recommendations[ts] = "fetch" - return recommendations, smsgs + recommendations = {} + ts.state = "fetch" + ts.coming_from = None + ts.done = False + if not ts.who_has: + recommendations[ts] = "missing" + else: + self.data_needed.push(ts) + for w in ts.who_has: + self.pending_data_per_worker[w].push(ts) + return recommendations, [] else: return {}, [] @@ -2692,24 +2757,15 @@ def ensure_communicating(self): self.total_out_connections, ) - _, key = heapq.heappop(self.data_needed) - - try: - ts = self.tasks[key] - except KeyError: - continue + ts = self.data_needed.pop() if ts.state != "fetch": continue - if not ts.who_has: - self.transition(ts, "missing", stimulus_id=stimulus_id) - continue - workers = [w for w in ts.who_has if w not in self.in_flight_workers] if not workers: assert ts.priority is not None - skipped_worker_in_flight.append((ts.priority, ts.key)) + skipped_worker_in_flight.append(ts) continue host = get_address_host(self.address) @@ -2740,7 +2796,7 @@ def ensure_communicating(self): ) for el in skipped_worker_in_flight: - heapq.heappush(self.data_needed, el) + self.data_needed.push(el) def _get_task_finished_msg(self, ts): if ts.key not in self.data and ts.key not in self.actors: @@ -2834,13 +2890,12 @@ def select_keys_for_gather(self, worker, dep): L = self.pending_data_per_worker[worker] while L: - d = L.popleft() - ts = self.tasks.get(d) - if ts is None or ts.state != "fetch": + ts = L.pop() + if ts.state != "fetch": continue if total_bytes + ts.get_nbytes() > self.target_message_size: break - deps.add(d) + deps.add(ts.key) total_bytes += ts.get_nbytes() return deps, total_bytes @@ -3077,7 +3132,10 @@ async def gather_dep( self.batched_stream.send( {"op": "missing-data", "errant_worker": worker, "key": d} ) - recommendations[ts] = "fetch" + if not ts.who_has: + recommendations[ts] = "missing" + else: + recommendations[ts] = "fetch" del data, response self.transitions(recommendations, stimulus_id=stimulus_id) self.ensure_computing() @@ -3089,7 +3147,7 @@ async def gather_dep( self.repetitively_busy += 1 await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy) - await self.query_who_has(*to_gather_keys, stimulus_id=stimulus_id) + await self.query_who_has(*to_gather_keys) self.ensure_communicating() @@ -3108,7 +3166,12 @@ async def find_missing(self): keys=[ts.key for ts in self._missing_dep_flight], ) who_has = {k: v for k, v in who_has.items() if v} - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) + recommendations = {} + for ts in self._missing_dep_flight: + if ts.who_has: + recommendations[ts] = "fetch" + self.transitions(recommendations, stimulus_id=stimulus_id) finally: # This is quite arbitrary but the heartbeat has scaling implemented @@ -3118,24 +3181,20 @@ async def find_missing(self): self.ensure_communicating() self.ensure_computing() - async def query_who_has( - self, *deps: str, stimulus_id: str - ) -> dict[str, Collection[str]]: + async def query_who_has(self, *deps: str) -> dict[str, Collection[str]]: with log_errors(): who_has = await retry_operation(self.scheduler.who_has, keys=deps) - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) return who_has - def update_who_has( - self, who_has: dict[str, Collection[str]], *, stimulus_id: str - ) -> None: + def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: try: - 
recommendations = {} for dep, workers in who_has.items(): if not workers: continue if dep in self.tasks: + dep_ts = self.tasks[dep] if self.address in workers and self.tasks[dep].state != "memory": logger.debug( "Scheduler claims worker %s holds data for task %s which is not true.", @@ -3144,18 +3203,11 @@ def update_who_has( ) # Do not mutate the input dict. That's rude workers = set(workers) - {self.address} - dep_ts = self.tasks[dep] - if dep_ts.state in FETCH_INTENDED: - dep_ts.who_has.update(workers) - - if dep_ts.state == "missing": - recommendations[dep_ts] = "fetch" - - for worker in workers: - self.has_what[worker].add(dep) - self.pending_data_per_worker[worker].append(dep_ts.key) + dep_ts.who_has.update(workers) - self.transitions(recommendations, stimulus_id=stimulus_id) + for worker in workers: + self.has_what[worker].add(dep) + self.pending_data_per_worker[worker].push(dep_ts) except Exception as e: logger.exception(e) if LOG_PDB: @@ -3874,16 +3926,19 @@ def validate_task_fetch(self, ts): assert ts.key not in self.data assert self.address not in ts.who_has assert not ts.done + assert ts in self.data_needed + assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] + assert ts in self.pending_data_per_worker[w] def validate_task_missing(self, ts): assert ts.key not in self.data assert not ts.who_has assert not ts.done assert not any(ts.key in has_what for has_what in self.has_what.values()) - assert ts.key in self._missing_dep_flight + assert ts in self._missing_dep_flight def validate_task_cancelled(self, ts): assert ts.key not in self.data @@ -3903,7 +3958,6 @@ def validate_task_released(self, ts): assert ts not in self._in_flight_tasks assert ts not in self._missing_dep_flight assert ts not in self._missing_dep_flight - assert not ts.who_has assert not any(ts.key in has_what for has_what in self.has_what.values()) assert not ts.waiting_for_data assert not ts.done @@ -3973,13 +4027,9 @@ def validate_state(self): assert ( ts_wait.state in READY | {"executing", "flight", "fetch", "missing"} - or ts_wait.key in self._missing_dep_flight + or ts_wait in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) - if ts.state == "memory": - assert isinstance(ts.nbytes, int) - assert not ts.waiting_for_data - assert ts.key in self.data or ts.key in self.actors assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): for k in keys:
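
For reviewers, a minimal standalone sketch (not part of the patch) of the de-duplicating heap semantics that _UniqueTaskHeap introduces and test_unique_task_heap exercises: push() is a no-op for keys that are already tracked, pop() and peak() follow (priority, key) order, and iteration yields a sorted view without draining the heap. SimpleTask is a hypothetical stand-in for distributed's TaskState that models only the attributes the heap relies on; the "peak" spelling mirrors the patch.

# Minimal sketch, assuming only the behaviour shown in the patch above.
from __future__ import annotations

import heapq
from collections.abc import Iterator


class SimpleTask:
    # Hypothetical stand-in for TaskState: only key and priority matter here.
    def __init__(self, key: str, priority: tuple[int, ...]) -> None:
        self.key = key
        self.priority = priority


class UniqueTaskHeap:
    """Heap ordered by (priority, key); pushing an already-known key is a no-op."""

    def __init__(self) -> None:
        self._known: set[str] = set()
        self._heap: list[tuple[tuple[int, ...], str, SimpleTask]] = []

    def push(self, ts: SimpleTask) -> None:
        # Duplicate keys are ignored; the stored priority is not updated.
        if ts.key not in self._known:
            heapq.heappush(self._heap, (ts.priority, ts.key, ts))
            self._known.add(ts.key)

    def pop(self) -> SimpleTask:
        _, _, ts = heapq.heappop(self._heap)
        self._known.remove(ts.key)
        return ts

    def peak(self) -> SimpleTask:  # spelling follows the patch
        return self._heap[0][2]

    def __contains__(self, key: str) -> bool:
        return key in self._known

    def __len__(self) -> int:
        return len(self._known)

    def __iter__(self) -> Iterator[SimpleTask]:
        # Sorted view; does not drain the heap.
        return (ts for _, _, ts in sorted(self._heap))


if __name__ == "__main__":
    heap = UniqueTaskHeap()
    a = SimpleTask("a", (0, 1))
    b = SimpleTask("b", (0, 0))
    heap.push(a)
    heap.push(b)
    heap.push(SimpleTask("a", (9, 9)))  # ignored: key "a" is already tracked
    assert len(heap) == 2 and "a" in heap
    assert heap.peak() is b  # lowest (priority, key) tuple comes out first
    assert [ts.key for ts in heap] == ["b", "a"]
    assert heap.pop() is b and heap.pop() is a and len(heap) == 0

Keeping the separate set of known keys is what turns duplicate pushes into cheap no-ops instead of growing the heap, which is exactly the pending_data_per_worker duplication that test_dups_in_pending_data_per_worker guards against.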