From 1ddce73426ac85aac1dd54c800f6236d92526246 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Mon, 11 Jul 2022 16:58:21 +0100
Subject: [PATCH 1/2] Store ready and constrained tasks in heapsets

---
 distributed/tests/test_resources.py | 136 +++++++++++++++++++++---
 distributed/tests/test_steal.py     |  10 +-
 distributed/worker.py               |   2 +-
 distributed/worker_state_machine.py | 156 ++++++++++++++--------------
 4 files changed, 207 insertions(+), 97 deletions(-)

diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py
index 3ee222a0027..cc11fce0448 100644
--- a/distributed/tests/test_resources.py
+++ b/distributed/tests/test_resources.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import asyncio
-from time import time
 
 import pytest
 
@@ -12,6 +11,13 @@
 from distributed import Lock, Worker
 from distributed.client import wait
 from distributed.utils_test import gen_cluster, inc, lock_inc, slowadd, slowinc
+from distributed.worker_state_machine import (
+    ComputeTaskEvent,
+    Execute,
+    ExecuteSuccessEvent,
+    FreeKeysEvent,
+    TaskFinishedMsg,
+)
 
 
 @gen_cluster(
@@ -235,20 +241,124 @@ async def test_minimum_resource(c, s, a):
     assert a.state.total_resources == a.state.available_resources
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})])
-async def test_prefer_constrained(c, s, a):
-    futures = c.map(slowinc, range(1000), delay=0.1)
-    constrained = c.map(inc, range(10), resources={"A": 1})
+@pytest.mark.parametrize("swap", [False, True])
+@pytest.mark.parametrize("p1,p2,expect_key", [(1, 0, "y"), (0, 1, "x")])
+def test_constrained_vs_ready_priority_1(ws, p1, p2, expect_key, swap):
+    """If there are both ready and constrained tasks, those with the highest priority
+    win (note: on the Worker, priorities have their sign inverted)
+    """
+    ws.available_resources = {"R": 1}
+    ws.total_resources = {"R": 1}
+    RR = {"resource_restrictions": {"R": 1}}
 
-    start = time()
-    await wait(constrained)
-    end = time()
-    assert end - start < 4
-    assert (
-        len([ts for ts in s.tasks.values() if ts.state == "memory"])
-        <= len(constrained) + 2
+    ws.handle_stimulus(ComputeTaskEvent.dummy(key="clog", stimulus_id="clog"))
+
+    stimuli = [
+        ComputeTaskEvent.dummy("x", priority=(p1,), stimulus_id="s1"),
+        ComputeTaskEvent.dummy("y", priority=(p2,), **RR, stimulus_id="s2"),
+    ]
+    if swap:
+        stimuli = stimuli[::-1]  # This must be inconsequential
+
+    instructions = ws.handle_stimulus(
+        *stimuli,
+        ExecuteSuccessEvent.dummy("clog", stimulus_id="s3"),
+    )
+    assert instructions == [
+        TaskFinishedMsg.match(key="clog", stimulus_id="s3"),
+        Execute(key=expect_key, stimulus_id="s3"),
+    ]
+
+
+@pytest.mark.parametrize("swap", [False, True])
+@pytest.mark.parametrize("p1,p2,expect_key", [(1, 0, "y"), (0, 1, "x")])
+def test_constrained_vs_ready_priority_2(ws, p1, p2, expect_key, swap):
+    """If there are both ready and constrained tasks, but not enough available
+    resources, priority is inconsequential - the tasks in the ready queue are picked up.
+    """
+    ws.nthreads = 2
+    ws.available_resources = {"R": 1}
+    ws.total_resources = {"R": 1}
+    RR = {"resource_restrictions": {"R": 1}}
+
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy(key="clog1", stimulus_id="clog1"),
+        ComputeTaskEvent.dummy(key="clog2", **RR, stimulus_id="clog2"),
+    )
+
+    # Test that both priorities and order are inconsequential
+    stimuli = [
+        ComputeTaskEvent.dummy("x", priority=(p1,), stimulus_id="s1"),
+        ComputeTaskEvent.dummy("y", priority=(p2,), **RR, stimulus_id="s2"),
+    ]
+    if swap:
+        stimuli = stimuli[::-1]
+
+    instructions = ws.handle_stimulus(
+        *stimuli,
+        ExecuteSuccessEvent.dummy("clog1", stimulus_id="s3"),
+    )
+    assert instructions == [
+        TaskFinishedMsg.match(key="clog1", stimulus_id="s3"),
+        Execute(key="x", stimulus_id="s3"),
+    ]
+
+
+def test_constrained_tasks_respect_priority(ws):
+    ws.available_resources = {"R": 1}
+    ws.total_resources = {"R": 1}
+    RR = {"resource_restrictions": {"R": 1}}
+
+    instructions = ws.handle_stimulus(
+        ComputeTaskEvent.dummy(key="clog", **RR, stimulus_id="clog"),
+        ComputeTaskEvent.dummy(key="x1", priority=(1,), **RR, stimulus_id="s1"),
+        ComputeTaskEvent.dummy(key="x2", priority=(2,), **RR, stimulus_id="s2"),
+        ComputeTaskEvent.dummy(key="x3", priority=(0,), **RR, stimulus_id="s3"),
+        ExecuteSuccessEvent.dummy(key="clog", stimulus_id="s4"),  # start x3
+        ExecuteSuccessEvent.dummy(key="x3", stimulus_id="s5"),  # start x1
+        ExecuteSuccessEvent.dummy(key="x1", stimulus_id="s6"),  # start x2
+    )
+    assert instructions == [
+        Execute(key="clog", stimulus_id="clog"),
+        TaskFinishedMsg.match(key="clog", stimulus_id="s4"),
+        Execute(key="x3", stimulus_id="s4"),
+        TaskFinishedMsg.match(key="x3", stimulus_id="s5"),
+        Execute(key="x1", stimulus_id="s5"),
+        TaskFinishedMsg.match(key="x1", stimulus_id="s6"),
+        Execute(key="x2", stimulus_id="s6"),
+    ]
+
+
+def test_task_cancelled_and_readded_with_resources(ws):
+    """See https://github.com/dask/distributed/issues/6710
+
+    A task is enqueued without resources, then cancelled by the client, then re-added
+    with the same key, this time with resources.
+    Test that resources are respected.
+    """
+    ws.available_resources = {"R": 1}
+    ws.total_resources = {"R": 1}
+    RR = {"resource_restrictions": {"R": 1}}
+
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy(key="clog", **RR, stimulus_id="s1"),
+        ComputeTaskEvent.dummy(key="x", stimulus_id="s2"),
+    )
+    ts = ws.tasks["x"]
+    assert ts.state == "ready"
+    assert ts in ws.ready
+    assert ts not in ws.constrained
+    assert ts.resource_restrictions == {}
+
+    ws.handle_stimulus(
+        FreeKeysEvent(keys=["x"], stimulus_id="clog"),
+        ComputeTaskEvent.dummy(key="x", **RR, stimulus_id="s2"),
     )
-    assert s.workers[a.address].processing
+    ts = ws.tasks["x"]
+    assert ts.state == "constrained"
+    assert ts not in ws.ready
+    assert ts in ws.constrained
+    assert ts.resource_restrictions == {"R": 1}
 
 
 @pytest.mark.skip(reason="")
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index a6df08d4c2e..59c6f678859 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1063,12 +1063,12 @@ async def test_steal_concurrent_simple(c, s, *workers):
         await asyncio.sleep(0.1)
 
     # ready is a heap but we don't need last, just not the next
-    _, victim_key = w0.state.ready[-1]
+    victim_key = w0.state.ready.peek().key
+    victim_ts = s.tasks[victim_key]
 
     ws0 = s.workers[w0.address]
     ws1 = s.workers[w1.address]
     ws2 = s.workers[w2.address]
 
-    victim_ts = s.tasks[victim_key]
     steal.move_task_request(victim_ts, ws0, ws1)
     steal.move_task_request(victim_ts, ws0, ws2)
 
@@ -1098,8 +1098,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers):
         await asyncio.sleep(0.01)
 
     # ready is a heap but we don't need last, just not the next
-    _, victim_key = w0.state.ready[-1]
-
+    victim_key = w0.state.ready.peek().key
     victim_ts = s.tasks[victim_key]
 
     wsA = victim_ts.processing_on
@@ -1157,8 +1156,7 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1):
     while not w0.active_keys:
         await asyncio.sleep(0.01)
 
-    victim_key = list(w0.state.ready)[-1][1]
-
+    victim_key = w0.state.ready.peek().key
     victim_ts = s.tasks[victim_key]
 
     wsA = victim_ts.processing_on
diff --git a/distributed/worker.py b/distributed/worker.py
index 60a82526bd4..fbd6444c618 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1877,7 +1877,7 @@ def stateof(self, key: str) -> dict[str, Any]:
         return {
             "executing": ts.state == "executing",
             "waiting_for_data": bool(ts.waiting_for_data),
-            "heap": key in pluck(1, self.state.ready),
+            "heap": ts in self.state.ready or ts in self.state.constrained,
             "data": key in self.data,
         }
 
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index c021593192e..4f481f2e1bc 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -24,7 +24,7 @@
 from itertools import chain
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast
 
-from tlz import peekn, pluck
+from tlz import peekn
 
 import dask
 from dask.utils import parse_bytes, typename
@@ -1077,13 +1077,14 @@ class WorkerState:
     #: method, is available.
     plugins: dict[str, WorkerPlugin]
 
-    # heapq ``[(priority, key), ...]``. Keys that are ready to run.
-    ready: list[tuple[tuple[int, ...], str]]
+    #: Priority heap of tasks that are ready to run and have no resource constraints.
+    #: Mutually exclusive with :attr:`constrained`.
+    ready: HeapSet[TaskState]
 
-    #: Keys for which we have the data to run, but are waiting on abstract resources
-    #: like GPUs. Stored in a FIFO deque.
+    #: Priority heap of tasks that are ready to run, but are waiting on abstract
+    #: resources like GPUs. Mutually exclusive with :attr:`ready`.
     #: See :attr:`available_resources` and :doc:`resources`.
-    constrained: deque[str]
+    constrained: HeapSet[TaskState]
 
     #: Number of tasks that can be executing in parallel.
     #: At any given time, :meth:`executing_count` <= nthreads.
@@ -1245,8 +1246,8 @@ def __init__(
         self.comm_nbytes = 0
         self.missing_dep_flight = set()
         self.generation = 0
-        self.ready = []
-        self.constrained = deque()
+        self.ready = HeapSet(key=operator.attrgetter("priority"))
+        self.constrained = HeapSet(key=operator.attrgetter("priority"))
         self.executing = set()
         self.in_flight_tasks = set()
         self.executed_count = 0
@@ -1403,6 +1404,9 @@ def _purge_state(self, ts: TaskState) -> None:
         ts.next = None
         ts.done = False
 
+        self.missing_dep_flight.discard(ts)
+        self.ready.discard(ts)
+        self.constrained.discard(ts)
         self.executing.discard(ts)
         self.long_running.discard(ts)
         self.in_flight_tasks.discard(ts)
@@ -1569,53 +1573,44 @@ def _ensure_computing(self) -> RecsInstrs:
             return {}, []
 
         recs: Recs = {}
-        while self.constrained and len(self.executing) < self.nthreads:
-            key = self.constrained[0]
-            ts = self.tasks.get(key, None)
-            if ts is None or ts.state != "constrained":
-                self.constrained.popleft()
-                continue
-
-            # There may be duplicates in the self.constrained and self.ready queues.
-            # This happens if a task:
-            # 1. is assigned to a Worker and transitioned to ready (heappush)
-            # 2. is stolen (no way to pop from heap, the task stays there)
-            # 3. is assigned to the worker again (heappush again)
-            if ts in recs:
-                continue
-
-            if not self._resource_restrictions_satisfied(ts):
+        while len(self.executing) < self.nthreads:
+            ts = self._next_ready_task()
+            if not ts:
                 break
-            self.constrained.popleft()
-            self._acquire_resources(ts)
+
+            if self.validate:
+                assert ts.state in READY
+                assert ts not in recs
+
             recs[ts] = "executing"
+            self._acquire_resources(ts)
             self.executing.add(ts)
 
-        while self.ready and len(self.executing) < self.nthreads:
-            _, key = heapq.heappop(self.ready)
-            ts = self.tasks.get(key)
-            if ts is None:
-                # It is possible for tasks to be released while still remaining on
-                # `ready`. The scheduler might have re-routed to a new worker and
-                # told this worker to release. If the task has "disappeared", just
-                # continue through the heap.
-                continue
+        return recs, []
 
-            if key in self.data:
-                # See comment above about duplicates
-                if self.validate:
-                    assert ts not in recs or recs[ts] == "memory"
-                recs[ts] = "memory"
-            elif ts.state in READY:
-                # See comment above about duplicates
-                if self.validate:
-                    assert ts not in recs or recs[ts] == "executing"
-                recs[ts] = "executing"
-                self.executing.add(ts)
+    def _next_ready_task(self) -> TaskState | None:
+        """Pop the top-priority task from self.ready or self.constrained"""
+        if self.ready and self.constrained:
+            tsr = self.ready.peek()
+            tsc = self.constrained.peek()
+            assert tsr.priority
+            assert tsc.priority
+            if tsc.priority < tsr.priority and self._resource_restrictions_satisfied(
+                tsc
+            ):
+                return self.constrained.pop()
+            else:
+                return self.ready.pop()
 
-        return recs, []
+        elif self.ready:
+            return self.ready.pop()
+
+        elif self.constrained:
+            tsc = self.constrained.peek()
+            if self._resource_restrictions_satisfied(tsc):
+                return self.constrained.pop()
+
+        return None
 
     def _get_task_finished_msg(
         self, ts: TaskState, stimulus_id: str
@@ -1850,11 +1845,10 @@ def _transition_waiting_constrained(
                 for dep in ts.dependencies
             )
             assert all(dep.state == "memory" for dep in ts.dependencies)
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in pluck(1, self.ready)
-            # assert ts.key not in self.constrained
+            assert ts not in self.ready
+            assert ts not in self.constrained
         ts.state = "constrained"
-        self.constrained.append(ts.key)
+        self.constrained.add(ts)
         return self._ensure_computing()
 
     def _transition_executing_rescheduled(
         self, ts: TaskState, *, stimulus_id: str
@@ -1887,9 +1881,8 @@ def _transition_waiting_ready(
     ) -> RecsInstrs:
         if self.validate:
             assert ts.state == "waiting"
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in pluck(1, self.ready)
-            # assert ts.key not in self.constrained
+            assert ts not in self.ready
+            assert ts not in self.constrained
             assert not ts.waiting_for_data
             for dep in ts.dependencies:
                 assert dep.key in self.data or dep.key in self.actors
@@ -1900,7 +1893,7 @@ def _transition_waiting_ready(
 
         ts.state = "ready"
         assert ts.priority is not None
-        heapq.heappush(self.ready, (ts.priority, ts.key))
+        self.ready.add(ts)
 
         return self._ensure_computing()
 
@@ -2182,12 +2175,11 @@ def _transition_constrained_executing(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
         if self.validate:
+            assert ts.state == "constrained"
             assert not ts.waiting_for_data
             assert ts.key not in self.data
-            assert ts.state in READY
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in pluck(1, self.ready)
-            # assert ts.key not in self.constrained
+            assert ts not in self.ready
+            assert ts not in self.constrained
             for dep in ts.dependencies:
                 assert dep.key in self.data or dep.key in self.actors
 
@@ -2199,12 +2191,11 @@ def _transition_ready_executing(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
         if self.validate:
+            assert ts.state == "ready"
             assert not ts.waiting_for_data
            assert ts.key not in self.data
-            assert ts.state in READY
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in pluck(1, self.ready)
-            # assert ts.key not in self.constrained
+            assert ts not in self.ready
+            assert ts not in self.constrained
             assert all(
                 dep.key in self.data or dep.key in self.actors
                 for dep in ts.dependencies
@@ -2378,7 +2369,6 @@ def _transition_released_forgotten(
         ("missing", "released"): _transition_missing_released,
         ("missing", "error"): _transition_generic_error,
         ("missing", "waiting"): _transition_missing_waiting,
-        ("ready", "error"): _transition_generic_error,
         ("ready", "executing"): _transition_ready_executing,
         ("ready", "released"): _transition_generic_released,
         ("released", "error"): _transition_generic_error,
@@ -3061,8 +3051,8 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
             "address": self.address,
             "nthreads": self.nthreads,
             "running": self.running,
-            "ready": self.ready,
-            "constrained": self.constrained,
+            "ready": [ts.key for ts in self.ready.sorted()],
+            "constrained": [ts.key for ts in self.constrained.sorted()],
             "data": dict.fromkeys(self.data),
             "data_needed": {
                 w: [ts.key for ts in tss.sorted()]
@@ -3109,15 +3099,13 @@ def _validate_task_executing(self, ts: TaskState) -> None:
     def _validate_task_ready(self, ts: TaskState) -> None:
         if ts.state == "ready":
             assert not ts.resource_restrictions
-            assert ts.key in pluck(1, self.ready)
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in self.constrained
+            assert ts in self.ready
+            assert ts not in self.constrained
         else:
             assert ts.resource_restrictions
             assert ts.state == "constrained"
-            # FIXME https://github.com/dask/distributed/issues/6710
-            # assert ts.key not in pluck(1, self.ready)
-            assert ts.key in self.constrained
+            assert ts not in self.ready
+            assert ts in self.constrained
 
         assert ts.key not in self.data
         assert not ts.done
@@ -3135,8 +3123,9 @@ def _validate_task_waiting(self, ts: TaskState) -> None:
     def _validate_task_flight(self, ts: TaskState) -> None:
         assert ts.key not in self.data
         assert ts in self.in_flight_tasks
-        # FIXME https://github.com/dask/distributed/issues/6710
-        # assert not any(dep.key in self.ready for dep in ts.dependents)
+        for dep in ts.dependents:
+            assert dep not in self.ready
+            assert dep not in self.constrained
         assert ts.coming_from
         assert ts.coming_from in self.in_flight_workers
         assert ts.key in self.in_flight_workers[ts.coming_from]
@@ -3252,26 +3241,39 @@ def validate_state(self) -> None:
                 assert k in self.tasks, self.story(k)
                 assert worker in self.tasks[k].who_has
 
+        # Test contents of the various sets of TaskState objects
         for worker, tss in self.data_needed.items():
             for ts in tss:
                 assert ts.state == "fetch"
                 assert worker in ts.who_has
-
+        for ts in self.missing_dep_flight:
+            assert ts.state == "missing"
+        for ts in self.ready:
+            assert ts.state == "ready"
+        for ts in self.constrained:
+            assert ts.state == "constrained"
+        # FIXME https://github.com/dask/distributed/issues/6708
+        # for ts in self.in_flight_tasks:
+        #     assert ts.state == "flight" or (
+        #         ts.state in ("cancelled", "resumed") and ts.previous == "flight"
+        #     )
         # FIXME https://github.com/dask/distributed/issues/6689
         # for ts in self.executing:
        #     assert ts.state == "executing" or (
         #         ts.state in ("cancelled", "resumed") and ts.previous == "executing"
-        #     ), self.story(ts)
+        #     )
         # for ts in self.long_running:
         #     assert ts.state == "long-running" or (
         #         ts.state in ("cancelled", "resumed") and ts.previous == "long-running"
-        #     ), self.story(ts)
+        #     )
 
         # Test that there aren't multiple TaskState objects with the same key in any
         # Set[TaskState]. See note in TaskState.__hash__.
         for ts in chain(
             *self.data_needed.values(),
             self.missing_dep_flight,
+            self.ready,
+            self.constrained,
             self.in_flight_tasks,
             self.executing,
             self.long_running,

From a2d6fc143102f086c021e5f22075bc337b2b790e Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Wed, 13 Jul 2022 00:48:02 +0100
Subject: [PATCH 2/2] test_steal.py to pick from the right of the heap again

---
 distributed/collections.py            | 27 ++++++++++++++++++-
 distributed/tests/test_collections.py | 38 +++++++++++++++++++++++++++
 distributed/tests/test_steal.py       |  7 ++---
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/distributed/collections.py b/distributed/collections.py
index b074c353ef1..ee4bf0e163d 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -90,7 +90,7 @@ def discard(self, value: T) -> None:
             self._heap.clear()
 
     def peek(self) -> T:
-        """Get the smallest element without removing it"""
+        """Return the smallest element without removing it"""
         if not self._data:
             raise KeyError("peek into empty set")
         while True:
@@ -109,6 +109,31 @@ def pop(self) -> T:
             self._data.discard(value)
             return value
 
+    def peekright(self) -> T:
+        """Return one of the largest elements (not necessarily the largest!) without
+        removing it. It's guaranteed that ``self.peekright() >= self.peek()``.
+        """
+        if not self._data:
+            raise KeyError("peek into empty set")
+        while True:
+            value = self._heap[-1][2]()
+            if value in self._data:
+                return value
+            del self._heap[-1]
+
+    def popright(self) -> T:
+        """Remove and return one of the largest elements (not necessarily the largest!)
+        It's guaranteed that ``self.popright() >= self.peek()``.
+        """
+        if not self._data:
+            raise KeyError("pop from an empty set")
+        while True:
+            _, _, vref = self._heap.pop()
+            value = vref()
+            if value in self._data:
+                self._data.discard(value)
+                return value
+
     def __iter__(self) -> Iterator[T]:
         """Iterate over all elements. This is a O(n) operation which returns the
         elements in pseudo-random order.
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
index 9b752b37285..6db5011a1ed 100644
--- a/distributed/tests/test_collections.py
+++ b/distributed/tests/test_collections.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import heapq
 import operator
 import pickle
 import random
@@ -157,6 +158,43 @@ def __init__(self, i):
     assert set(heap) == {cx}
 
 
+@pytest.mark.parametrize("peek", [False, True])
+def test_heapset_popright(peek):
+    heap = HeapSet(key=operator.attrgetter("i"))
+    with pytest.raises(KeyError):
+        heap.peekright()
+    with pytest.raises(KeyError):
+        heap.popright()
+
+    # The heap contains broken weakrefs
+    for i in range(200):
+        c = C(f"y{i}", random.random())
+        heap.add(c)
+        if random.random() > 0.7:
+            heap.remove(c)
+
+    c0 = heap.peek()
+    while len(heap) > 1:
+        # These two code paths determine which of the two methods deals with the
+        # removal of broken weakrefs
+        if peek:
+            c1 = heap.peekright()
+            assert c1.i >= c0.i
+            assert heap.popright() is c1
+        else:
+            c1 = heap.popright()
+            assert c1.i >= c0.i
+
+        # Test that the heap hasn't been corrupted
+        h2 = heap._heap[:]
+        heapq.heapify(h2)
+        assert h2 == heap._heap
+
+    assert heap.peekright() is c0
+    assert heap.popright() is c0
+    assert not heap
+
+
 def test_heapset_pickle():
     """Test pickle roundtrip for a HeapSet.
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 59c6f678859..29d6320f8c6 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1063,7 +1063,7 @@ async def test_steal_concurrent_simple(c, s, *workers): await asyncio.sleep(0.1) # ready is a heap but we don't need last, just not the next - victim_key = w0.state.ready.peek().key + victim_key = w0.state.ready.peekright().key victim_ts = s.tasks[victim_key] ws0 = s.workers[w0.address] @@ -1098,7 +1098,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers): await asyncio.sleep(0.01) # ready is a heap but we don't need last, just not the next - victim_key = w0.state.ready.peek().key + victim_key = w0.state.ready.peekright().key victim_ts = s.tasks[victim_key] wsA = victim_ts.processing_on @@ -1156,7 +1156,8 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1): while not w0.active_keys: await asyncio.sleep(0.01) - victim_key = w0.state.ready.peek().key + # ready is a heap but we don't need last, just not the next + victim_key = w0.state.ready.peekright().key victim_ts = s.tasks[victim_key] wsA = victim_ts.processing_on