diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index b6c4961ff6d..64f4eab12fa 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -57,7 +57,14 @@
     slowinc,
     slowsum,
 )
-from distributed.worker import Worker, error_message, logger, parse_memory_limit
+from distributed.worker import (
+    TaskState,
+    UniqueTaskHeap,
+    Worker,
+    error_message,
+    logger,
+    parse_memory_limit,
+)
 
 pytestmark = pytest.mark.ci1
 
@@ -1390,7 +1397,6 @@ async def test_prefer_gather_from_local_address(c, s, w1, w2, w3):
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1)] * 20,
-    timeout=30,
     config={"distributed.worker.connections.incoming": 1},
 )
 async def test_avoid_oversubscription(c, s, *workers):
@@ -2186,12 +2192,26 @@ async def test_gpu_executor(c, s, w):
     assert "gpu" not in w.executors
 
 
-def assert_task_states_on_worker(expected, worker):
-    for dep_key, expected_state in expected.items():
-        assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks)
-        dep_ts = worker.tasks[dep_key]
-        assert dep_ts.state == expected_state, (worker.name, dep_ts, expected_state)
-    assert set(expected) == set(worker.tasks)
+async def assert_task_states_on_worker(expected, worker):
+    active_exc = None
+    for _ in range(10):
+        try:
+            for dep_key, expected_state in expected.items():
+                assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks)
+                dep_ts = worker.tasks[dep_key]
+                assert dep_ts.state == expected_state, (
+                    worker.name,
+                    dep_ts,
+                    expected_state,
+                )
+            assert set(expected) == set(worker.tasks)
+            return
+        except AssertionError as exc:
+            active_exc = exc
+            await asyncio.sleep(0.1)
+    # If after a second the workers are not in equilibrium, they are broken
+    assert active_exc
+    raise active_exc
 
 
 @gen_cluster(client=True)
@@ -2235,7 +2255,7 @@ def raise_exc(*args):
         g.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
     # Expected states after we release references to the futures
     f.release()
     g.release()
@@ -2251,7 +2271,7 @@ def raise_exc(*args):
         res.key: "error",
     }
 
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
 
     res.release()
 
@@ -2304,7 +2324,7 @@ def raise_exc(*args):
         g.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
 
     # Expected states after we release references to the futures
     res.release()
@@ -2318,7 +2338,7 @@ def raise_exc(*args):
         g.key: "memory",
     }
 
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
 
     f.release()
     g.release()
@@ -2369,7 +2389,7 @@ def raise_exc(*args):
         g.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, a)
 
     # Expected states after we release references to the futures
     f.release()
@@ -2383,8 +2403,8 @@ def raise_exc(*args):
         g.key: "memory",
     }
 
-    assert_task_states_on_worker(expected_states, a)
-    assert_task_states_on_worker(expected_states, b)
+    await assert_task_states_on_worker(expected_states, a)
+    await assert_task_states_on_worker(expected_states, b)
 
     g.release()
 
@@ -2418,8 +2438,7 @@ def raise_exc(*args):
         g.key: "memory",
         h.key: "memory",
     }
-    await asyncio.sleep(0.05)
-    assert_task_states_on_worker(expected_states_A, a)
+    await assert_task_states_on_worker(expected_states_A, a)
 
     expected_states_B = {
         f.key: "memory",
@@ -2427,8 +2446,7 @@ def raise_exc(*args):
         h.key: "memory",
         res.key: "error",
     }
-    await asyncio.sleep(0.05)
-    assert_task_states_on_worker(expected_states_B, b)
+    await assert_task_states_on_worker(expected_states_B, b)
 
     f.release()
 
@@ -2436,8 +2454,7 @@ def raise_exc(*args):
         g.key: "memory",
         h.key: "memory",
     }
-    await asyncio.sleep(0.05)
-    assert_task_states_on_worker(expected_states_A, a)
+    await assert_task_states_on_worker(expected_states_A, a)
 
     expected_states_B = {
         f.key: "released",
@@ -2445,8 +2462,7 @@ def raise_exc(*args):
         h.key: "memory",
         res.key: "error",
     }
-    await asyncio.sleep(0.05)
-    assert_task_states_on_worker(expected_states_B, b)
+    await assert_task_states_on_worker(expected_states_B, b)
 
     g.release()
 
@@ -2454,8 +2470,7 @@ def raise_exc(*args):
         g.key: "released",
         h.key: "memory",
     }
-    await asyncio.sleep(0.05)
-    assert_task_states_on_worker(expected_states_A, a)
+    await assert_task_states_on_worker(expected_states_A, a)
 
     # B must not forget a task since all have a still valid dependent
     expected_states_B = {
@@ -2463,19 +2478,18 @@ def raise_exc(*args):
         h.key: "memory",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states_B, b)
+    await assert_task_states_on_worker(expected_states_B, b)
 
     h.release()
-    await asyncio.sleep(0.05)
 
     expected_states_A = {}
-    assert_task_states_on_worker(expected_states_A, a)
+    await assert_task_states_on_worker(expected_states_A, a)
     expected_states_B = {
         f.key: "released",
         h.key: "released",
         res.key: "error",
     }
-    assert_task_states_on_worker(expected_states_B, b)
+    await assert_task_states_on_worker(expected_states_B, b)
 
     res.release()
     # We no longer hold any refs. Cluster should reset completely
@@ -3130,33 +3144,38 @@ async def test_missing_released_zombie_tasks(c, s, a, b):
 
 @gen_cluster(client=True)
 async def test_missing_released_zombie_tasks_2(c, s, a, b):
-    a.total_in_connections = 0
-    f1 = c.submit(inc, 1, key="f1", workers=[a.address])
-    f2 = c.submit(inc, f1, key="f2", workers=[b.address])
+    # If get_data_from_worker raises this will suggest a dead worker to B and it
+    # will transition the task to missing. We want to make sure that a missing
+    # task is properly released and not left as a zombie
+    with mock.patch.object(
+        distributed.worker,
+        "get_data_from_worker",
+        side_effect=CommClosedError,
+    ):
+        f1 = c.submit(inc, 1, key="f1", workers=[a.address])
+        f2 = c.submit(inc, f1, key="f2", workers=[b.address])
 
-    while f1.key not in b.tasks:
-        await asyncio.sleep(0)
+        while f1.key not in b.tasks:
+            await asyncio.sleep(0)
 
-    ts = b.tasks[f1.key]
-    assert ts.state == "fetch"
+        ts = b.tasks[f1.key]
+        assert ts.state == "fetch"
 
-    # A few things can happen to clear who_has. The dominant process is upon
-    # connection failure to a worker. Regardless of how the set was cleared, the
-    # task will be transitioned to missing where the worker is trying to
-    # reaquire this information from the scheduler. While this is happening on
-    # worker side, the tasks are released and we want to ensure that no dangling
-    # zombie tasks are left on the worker
-    ts.who_has.clear()
+        while ts.state != "missing":
+            # If we sleep for a longer time, the worker will spin into an
+            # endless loop of asking the scheduler who_has and trying to connect
+            # to A
+            await asyncio.sleep(0)
 
-    del f1, f2
+        del f1, f2
 
-    while b.tasks:
-        await asyncio.sleep(0.01)
+        while b.tasks:
+            await asyncio.sleep(0.01)
 
-    assert_worker_story(
-        b.story(ts),
-        [("f1", "missing", "released", "released", {"f1": "forgotten"})],
-    )
+        assert_worker_story(
+            b.story(ts),
+            [("f1", "missing", "released", "released", {"f1": "forgotten"})],
+        )
 
 
 @pytest.mark.slow
@@ -3441,6 +3460,8 @@ async def test_Worker__to_dict(c, s, a):
         "config",
         "incoming_transfer_log",
         "outgoing_transfer_log",
+        "data_needed",
+        "pending_data_per_worker",
     }
     assert d["tasks"]["x"]["key"] == "x"
 
@@ -3462,3 +3483,45 @@ async def test_TaskState__to_dict(c, s, a):
     assert isinstance(tasks["z"], dict)
     assert tasks["x"]["dependents"] == ["<TaskState 'y' memory>"]
     assert tasks["y"]["dependencies"] == ["<TaskState 'x' memory>"]
+
+
+def test_unique_task_heap():
+    heap = UniqueTaskHeap()
+
+    for x in range(10):
+        ts = TaskState(f"f{x}")
+        ts.priority = (0, 0, 1, x % 3)
+        heap.push(ts)
+
+    heap_list = list(heap)
+    # iteration does not empty heap
+    assert len(heap) == 10
+    assert heap_list == sorted(heap_list, key=lambda ts: ts.priority)
+
+    seen = set()
+    last_prio = (0, 0, 0, 0)
+    while heap:
+        peeked = heap.peek()
+        ts = heap.pop()
+        assert peeked == ts
+        seen.add(ts.key)
+        assert ts.priority
+        assert last_prio <= ts.priority
+        last_prio = ts.priority
+
+    ts = TaskState("foo")
+    heap.push(ts)
+    heap.push(ts)
+    assert len(heap) == 1
+
+    assert repr(heap) == "<UniqueTaskHeap: 1 items>"
+
+    assert heap.pop() == ts
+    assert not heap
+
+    # Test that we're cleaning the seen set on pop
+    heap.push(ts)
+    assert len(heap) == 1
+    assert heap.pop() == ts
+
+    assert repr(heap) == "<UniqueTaskHeap: 0 items>"
diff --git a/distributed/worker.py b/distributed/worker.py
index 72d3c93f5e3..02528770d28 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -13,7 +13,14 @@
 import warnings
 import weakref
 from collections import defaultdict, deque, namedtuple
-from collections.abc import Callable, Collection, Iterable, Mapping, MutableMapping
+from collections.abc import (
+    Callable,
+    Collection,
+    Iterable,
+    Iterator,
+    Mapping,
+    MutableMapping,
+)
 from concurrent.futures import Executor
 from contextlib import suppress
 from datetime import timedelta
@@ -110,7 +117,6 @@
     "resumed",
 }
 READY = {"ready", "constrained"}
-FETCH_INTENDED = {"missing", "fetch", "flight", "cancelled", "resumed"}
 
 # Worker.status subsets
 RUNNING = {Status.running, Status.paused, Status.closing_gracefully}
@@ -193,6 +199,8 @@ class TaskState:
 
     """
 
+    priority: tuple[int, ...] | None
+
     def __init__(self, key, runspec=None):
         assert key is not None
         self.key = key
@@ -267,6 +275,54 @@ def is_protected(self) -> bool:
         )
 
 
+class UniqueTaskHeap(Collection):
+    """A heap of TaskState objects ordered by TaskState.priority
+    Ties are broken by string comparison of the key. Keys are guaranteed to be
+    unique. Iterating over this object returns the elements in priority order.
+    """
+
+    def __init__(self, collection: Collection[TaskState] = ()):
+        self._known = {ts.key for ts in collection}
+        self._heap = [(ts.priority, ts.key, ts) for ts in collection]
+        heapq.heapify(self._heap)
+
+    def push(self, ts: TaskState) -> None:
+        """Add a new TaskState instance to the heap. If the key is already
+        known, no object is added.
+
+        Note: This does not update the priority / heap order in case priority
+        changes.
+        """
+        assert isinstance(ts, TaskState)
+        if ts.key not in self._known:
+            heapq.heappush(self._heap, (ts.priority, ts.key, ts))
+            self._known.add(ts.key)
+
+    def pop(self) -> TaskState:
+        """Pop the task with highest priority from the heap."""
+        _, key, ts = heapq.heappop(self._heap)
+        self._known.remove(key)
+        return ts
+
+    def peek(self) -> TaskState:
+        """Get the highest priority TaskState without removing it from the heap"""
+        return self._heap[0][2]
+
+    def __contains__(self, x: object) -> bool:
+        if isinstance(x, TaskState):
+            x = x.key
+        return x in self._known
+
+    def __iter__(self) -> Iterator[TaskState]:
+        return (ts for _, _, ts in sorted(self._heap))
+
+    def __len__(self) -> int:
+        return len(self._known)
+
+    def __repr__(self) -> str:
+        return f"<{type(self).__name__}: {len(self)} items>"
+
+
 class Worker(ServerNode):
     """Worker node in a Dask distributed cluster
 
@@ -342,8 +398,8 @@ class Worker(ServerNode):
     * **data.disk:** ``{key: object}``:
         Dictionary mapping keys to actual values stored on disk. Only
         available if condition for **data** being a zict.Buffer is met.
-    * **data_needed**: deque(keys)
-        The keys which still require data in order to execute, arranged in a deque
+    * **data_needed**: UniqueTaskHeap
+        The tasks which still require data in order to execute, prioritized as a heap
     * **ready**: [keys]
         Keys that are ready to run. Stored in a LIFO stack
     * **constrained**: [keys]
@@ -358,8 +414,8 @@ class Worker(ServerNode):
         long-running clients.
     * **has_what**: ``{worker: {deps}}``
         The data that we care about that we think a worker has
-    * **pending_data_per_worker**: ``{worker: [dep]}``
-        The data on each worker that we still want, prioritized as a deque
+    * **pending_data_per_worker**: ``{worker: UniqueTaskHeap}``
+        The data on each worker that we still want, prioritized as a heap
     * **in_flight_tasks**: ``int``
         A count of the number of tasks that are coming to us in current
         peer-to-peer connections
@@ -457,10 +513,10 @@ class Worker(ServerNode):
     tasks: dict[str, TaskState]
    waiting_for_data_count: int
     has_what: defaultdict[str, set[str]]  # {worker address: {ts.key, ...}}
-    pending_data_per_worker: defaultdict[str, deque[str]]
+    pending_data_per_worker: defaultdict[str, UniqueTaskHeap]
     nanny: Nanny | None
     _lock: threading.Lock
-    data_needed: list[tuple[int, str]]  # heap[(ts.priority, ts.key)]
+    data_needed: UniqueTaskHeap
     in_flight_workers: dict[str, set[str]]  # {worker address: {ts.key, ...}}
     total_out_connections: int
     total_in_connections: int
@@ -609,11 +665,11 @@ def __init__(
         self.tasks = {}
         self.waiting_for_data_count = 0
         self.has_what = defaultdict(set)
-        self.pending_data_per_worker = defaultdict(deque)
+        self.pending_data_per_worker = defaultdict(UniqueTaskHeap)
         self.nanny = nanny
         self._lock = threading.Lock()
 
-        self.data_needed = []
+        self.data_needed = UniqueTaskHeap()
 
         self.in_flight_workers = {}
         self.total_out_connections = dask.config.get(
@@ -673,11 +729,11 @@ def __init__(
             ("executing", "released"): self.transition_executing_released,
             ("executing", "rescheduled"): self.transition_executing_rescheduled,
             ("fetch", "flight"): self.transition_fetch_flight,
-            ("fetch", "missing"): self.transition_fetch_missing,
             ("fetch", "released"): self.transition_generic_released,
             ("flight", "error"): self.transition_flight_error,
             ("flight", "fetch"): self.transition_flight_fetch,
             ("flight", "memory"): self.transition_flight_memory,
+            ("flight", "missing"): self.transition_flight_missing,
("flight", "released"): self.transition_flight_released, ("long-running", "error"): self.transition_generic_error, ("long-running", "memory"): self.transition_long_running_memory, @@ -1156,6 +1212,10 @@ def _to_dict( "status": self.status, "ready": self.ready, "constrained": self.constrained, + "data_needed": list(self.data_needed), + "pending_data_per_worker": { + w: list(v) for w, v in self.pending_data_per_worker.items() + }, "long_running": self.long_running, "executing_count": self.executing_count, "in_flight_tasks": self.in_flight_tasks, @@ -1913,7 +1973,6 @@ def handle_cancel_compute(self, key, reason): ts = self.tasks.get(key) if ts and ts.state in READY | {"waiting"}: self.log.append((key, "cancel-compute", reason, time())) - ts.scheduler_holds_ref = False # All possible dependents of TS should not be in state Processing on # scheduler side and therefore should not be assigned to a worker, # yet. @@ -1942,7 +2001,7 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) def ensure_task_exists( @@ -2038,11 +2097,10 @@ def handle_compute_task( for msg in scheduler_msgs: self.batched_stream.send(msg) + + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) - # We received new info, that's great but not related to the compute-task - # instruction - self.update_who_has(who_has, stimulus_id=stimulus_id) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value @@ -2055,7 +2113,7 @@ def transition_missing_fetch(self, ts, *, stimulus_id): self._missing_dep_flight.discard(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_missing_released(self, ts, *, stimulus_id): @@ -2066,10 +2124,11 @@ def transition_missing_released(self, ts, *, stimulus_id): assert ts.key in self.tasks return recommendations, smsgs - def transition_fetch_missing(self, ts, *, stimulus_id): - # handle_missing will append to self.data_needed if new workers are found + def transition_flight_missing(self, ts, *, stimulus_id): + assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) + ts.done = False return {}, [] def transition_released_fetch(self, ts, *, stimulus_id): @@ -2077,10 +2136,10 @@ def transition_released_fetch(self, ts, *, stimulus_id): assert ts.state == "released" assert ts.priority is not None for w in ts.who_has: - self.pending_data_per_worker[w].append(ts.key) + self.pending_data_per_worker[w].push(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_generic_released(self, ts, *, stimulus_id): @@ -2126,7 +2185,6 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id): if self.validate: assert ts.state == "fetch" assert ts.who_has - assert ts.key not in self.data_needed ts.done = False ts.state = "flight" @@ -2425,11 +2483,17 @@ def transition_flight_fetch(self, ts, *, stimulus_id): # we can reset the task and transition to fetch again. 
         # finished, this should be a no-op
         if ts.done:
-            recommendations, smsgs = self.transition_generic_released(
-                ts, stimulus_id=stimulus_id
-            )
-            recommendations[ts] = "fetch"
-            return recommendations, smsgs
+            recommendations = {}
+            ts.state = "fetch"
+            ts.coming_from = None
+            ts.done = False
+            if not ts.who_has:
+                recommendations[ts] = "missing"
+            else:
+                self.data_needed.push(ts)
+                for w in ts.who_has:
+                    self.pending_data_per_worker[w].push(ts)
+            return recommendations, []
         else:
             return {}, []
 
@@ -2692,24 +2756,15 @@ def ensure_communicating(self):
                 self.total_out_connections,
             )
 
-            _, key = heapq.heappop(self.data_needed)
-
-            try:
-                ts = self.tasks[key]
-            except KeyError:
-                continue
+            ts = self.data_needed.pop()
 
             if ts.state != "fetch":
                 continue
 
-            if not ts.who_has:
-                self.transition(ts, "missing", stimulus_id=stimulus_id)
-                continue
-
             workers = [w for w in ts.who_has if w not in self.in_flight_workers]
             if not workers:
                 assert ts.priority is not None
-                skipped_worker_in_flight.append((ts.priority, ts.key))
+                skipped_worker_in_flight.append(ts)
                 continue
 
             host = get_address_host(self.address)
@@ -2740,7 +2795,7 @@ def ensure_communicating(self):
            )
 
         for el in skipped_worker_in_flight:
-            heapq.heappush(self.data_needed, el)
+            self.data_needed.push(el)
 
     def _get_task_finished_msg(self, ts):
         if ts.key not in self.data and ts.key not in self.actors:
@@ -2834,13 +2889,12 @@ def select_keys_for_gather(self, worker, dep):
         L = self.pending_data_per_worker[worker]
 
         while L:
-            d = L.popleft()
-            ts = self.tasks.get(d)
-            if ts is None or ts.state != "fetch":
+            ts = L.pop()
+            if ts.state != "fetch":
                 continue
             if total_bytes + ts.get_nbytes() > self.target_message_size:
                 break
-            deps.add(d)
+            deps.add(ts.key)
             total_bytes += ts.get_nbytes()
 
         return deps, total_bytes
@@ -3077,7 +3131,7 @@ async def gather_dep(
                     self.batched_stream.send(
                         {"op": "missing-data", "errant_worker": worker, "key": d}
                     )
-                    recommendations[ts] = "fetch"
+                    recommendations[ts] = "fetch" if ts.who_has else "missing"
             del data, response
             self.transitions(recommendations, stimulus_id=stimulus_id)
             self.ensure_computing()
@@ -3089,7 +3143,7 @@ async def gather_dep(
                 self.repetitively_busy += 1
                 await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy)
 
-                await self.query_who_has(*to_gather_keys, stimulus_id=stimulus_id)
+                await self.query_who_has(*to_gather_keys)
 
             self.ensure_communicating()
 
@@ -3108,7 +3162,12 @@ async def find_missing(self):
                     keys=[ts.key for ts in self._missing_dep_flight],
                 )
                 who_has = {k: v for k, v in who_has.items() if v}
-                self.update_who_has(who_has, stimulus_id=stimulus_id)
+                self.update_who_has(who_has)
+                recommendations = {}
+                for ts in self._missing_dep_flight:
+                    if ts.who_has:
+                        recommendations[ts] = "fetch"
+                self.transitions(recommendations, stimulus_id=stimulus_id)
 
             finally:
                 # This is quite arbitrary but the heartbeat has scaling implemented
@@ -3118,24 +3177,20 @@ async def find_missing(self):
                 self.ensure_communicating()
                 self.ensure_computing()
 
-    async def query_who_has(
-        self, *deps: str, stimulus_id: str
-    ) -> dict[str, Collection[str]]:
+    async def query_who_has(self, *deps: str) -> dict[str, Collection[str]]:
         with log_errors():
             who_has = await retry_operation(self.scheduler.who_has, keys=deps)
-            self.update_who_has(who_has, stimulus_id=stimulus_id)
+            self.update_who_has(who_has)
             return who_has
 
-    def update_who_has(
-        self, who_has: dict[str, Collection[str]], *, stimulus_id: str
-    ) -> None:
+    def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
         try:
-            recommendations = {}
             for dep, workers in who_has.items():
                 if not workers:
                     continue
 
                 if dep in self.tasks:
+                    dep_ts = self.tasks[dep]
                     if self.address in workers and self.tasks[dep].state != "memory":
                         logger.debug(
                             "Scheduler claims worker %s holds data for task %s which is not true.",
@@ -3144,18 +3199,11 @@ def update_who_has(
                         )
                         # Do not mutate the input dict. That's rude
                         workers = set(workers) - {self.address}
-                    dep_ts = self.tasks[dep]
-                    if dep_ts.state in FETCH_INTENDED:
-                        dep_ts.who_has.update(workers)
-
-                        if dep_ts.state == "missing":
-                            recommendations[dep_ts] = "fetch"
-
-                        for worker in workers:
-                            self.has_what[worker].add(dep)
-                            self.pending_data_per_worker[worker].append(dep_ts.key)
+                    dep_ts.who_has.update(workers)
 
-            self.transitions(recommendations, stimulus_id=stimulus_id)
+                    for worker in workers:
+                        self.has_what[worker].add(dep)
+                        self.pending_data_per_worker[worker].push(dep_ts)
         except Exception as e:
             logger.exception(e)
             if LOG_PDB:
@@ -3874,16 +3922,19 @@ def validate_task_fetch(self, ts):
         assert ts.key not in self.data
         assert self.address not in ts.who_has
         assert not ts.done
+        assert ts in self.data_needed
+        assert ts.who_has
 
         for w in ts.who_has:
             assert ts.key in self.has_what[w]
+            assert ts in self.pending_data_per_worker[w]
 
     def validate_task_missing(self, ts):
         assert ts.key not in self.data
         assert not ts.who_has
         assert not ts.done
         assert not any(ts.key in has_what for has_what in self.has_what.values())
-        assert ts.key in self._missing_dep_flight
+        assert ts in self._missing_dep_flight
 
     def validate_task_cancelled(self, ts):
         assert ts.key not in self.data
@@ -3903,7 +3954,6 @@ def validate_task_released(self, ts):
         assert ts not in self._in_flight_tasks
         assert ts not in self._missing_dep_flight
         assert ts not in self._missing_dep_flight
-        assert not ts.who_has
         assert not any(ts.key in has_what for has_what in self.has_what.values())
         assert not ts.waiting_for_data
         assert not ts.done
@@ -3973,13 +4023,9 @@ def validate_state(self):
                     assert (
                         ts_wait.state
                        in READY | {"executing", "flight", "fetch", "missing"}
-                        or ts_wait.key in self._missing_dep_flight
+                        or ts_wait in self._missing_dep_flight
                        or ts_wait.who_has.issubset(self.in_flight_workers)
                     ), (ts, ts_wait, self.story(ts), self.story(ts_wait))
-                if ts.state == "memory":
-                    assert isinstance(ts.nbytes, int)
-                    assert not ts.waiting_for_data
-                    assert ts.key in self.data or ts.key in self.actors
             assert self.waiting_for_data_count == waiting_for_data_count
             for worker, keys in self.has_what.items():
                 for k in keys: