From b86fe0f61c4f7c9c63151f6a3630dab3db3c9540 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Thu, 23 Jun 2022 15:47:20 -0600
Subject: [PATCH] Co-assignment when queuing is disabled

---
 distributed/scheduler.py            | 35 ++++++++++++++++++++++++++---
 distributed/tests/test_scheduler.py |  7 ++++--
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 9be831f8d16..adc3af81946 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -804,6 +804,14 @@ class TaskGroup:
     #: The result types of this TaskGroup
     types: set[str]
 
+    #: The worker most recently assigned a task from this group, or None when the group
+    #: is not identified as root-like by `SchedulerState.decide_worker`.
+    last_worker: WorkerState | None
+
+    #: If `last_worker` is not None, the number of times that worker should be assigned
+    #: subsequent tasks until a new worker is chosen.
+    last_worker_tasks_left: int
+
     prefix: TaskPrefix | None
     start: float
     stop: float
@@ -823,6 +831,8 @@ def __init__(self, name: str):
         self.start = 0.0
         self.stop = 0.0
         self.all_durations = defaultdict(float)
+        self.last_worker = None
+        self.last_worker_tasks_left = 0
 
     def add_duration(self, action: str, start: float, stop: float) -> None:
         duration = stop - start
@@ -1331,7 +1341,7 @@ def __init__(
         self.unrunnable = unrunnable
         self.validate = validate
         self.workers = workers
-        self.running = {
+        self.running: set[WorkerState] = {
             ws for ws in self.workers.values() if ws.status == Status.running
         }
         self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}
@@ -1791,13 +1801,32 @@ def decide_worker(
             and len(tg.dependencies) < 5
             and sum(map(len, tg.dependencies)) < 5
         ):
-            if math.isinf(self.WORKER_SATURATION):
+            if math.isinf(self.WORKER_SATURATION):  # no scheduler-side queuing
                 pool = self.idle.values() if self.idle else self.running
                 if not pool:
                     recommendations[ts.key] = "no-worker"
                     return None
 
-                return min(pool, key=lambda ws: len(ws.processing) / ws.nthreads)
+                lws = tg.last_worker
+                if not (
+                    lws
+                    and tg.last_worker_tasks_left
+                    and self.workers.get(lws.address) is lws
+                ):
+                    # Last-used worker is full or unknown; pick a new worker for the next few tasks
+                    ws = min(pool, key=partial(self.worker_objective, ts))
+                    tg.last_worker_tasks_left = math.floor(
+                        (len(tg) / self.total_nthreads) * ws.nthreads
+                    )
+                else:
+                    ws = lws
+
+                # Record `last_worker`, or clear it on the final task
+                tg.last_worker = (
+                    ws if tg.states["released"] + tg.states["waiting"] > 1 else None
+                )
+                tg.last_worker_tasks_left -= 1
+                return ws
 
             if not self.idle:
                 # All workers busy? Task gets/stays queued.
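For readers skimming the diff: the new `decide_worker` branch keeps handing tasks from a root-like TaskGroup to the same worker until that worker has received its proportional share of the group, floor(len(tg) / total_nthreads * ws.nthreads), then re-selects. Below is a minimal standalone sketch of that batching logic, using a hypothetical `Worker` stand-in rather than the real scheduler classes:

    import math
    from dataclasses import dataclass, field


    @dataclass
    class Worker:
        # Hypothetical stand-in for WorkerState: a name, a thread count,
        # and the tasks assigned so far.
        name: str
        nthreads: int
        assigned: list[str] = field(default_factory=list)


    def assign_root_tasks(tasks: list[str], workers: list[Worker]) -> None:
        # Sketch of the co-assignment loop: stick with one worker until it
        # has received its proportional share of the group, then re-select.
        total_nthreads = sum(w.nthreads for w in workers)
        last_worker = None
        tasks_left = 0
        for task in tasks:
            if last_worker is None or tasks_left <= 0:
                # Re-select: least-loaded worker, normalized by thread count
                # (the real code uses Scheduler.worker_objective).
                last_worker = min(workers, key=lambda w: len(w.assigned) / w.nthreads)
                tasks_left = math.floor(len(tasks) / total_nthreads * last_worker.nthreads)
            last_worker.assigned.append(task)
            tasks_left -= 1


    workers = [Worker("a", 2), Worker("b", 2)]
    assign_root_tasks([f"x-{i}" for i in range(8)], workers)
    assert [w.assigned for w in workers] == [
        ["x-0", "x-1", "x-2", "x-3"],  # contiguous halves, not round-robin
        ["x-4", "x-5", "x-6", "x-7"],
    ]

Assigning contiguous runs of sibling tasks to one worker keeps neighboring outputs together, so downstream tasks that combine them need fewer cross-worker transfers.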
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index b53fa500d8b..509cad37124 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -134,7 +134,7 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
     assert x.key in a.data or x.key in b.data
 
 
-@pytest.mark.skip("Current queuing does not support co-assignment")
+# @pytest.mark.skip("Current queuing does not support co-assignment")
 @pytest.mark.parametrize("ndeps", [0, 1, 4])
 @pytest.mark.parametrize(
     "nthreads",
@@ -147,7 +147,10 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
     @gen_cluster(
         client=True,
         nthreads=nthreads,
-        config={"distributed.scheduler.work-stealing": False},
+        config={
+            "distributed.scheduler.work-stealing": False,
+            "distributed.scheduler.worker-saturation": float("inf"),
+        },
         scheduler_kwargs=dict(  # TODO remove
             dashboard=True,
             dashboard_address=":8787",
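Note the test now pins `distributed.scheduler.worker-saturation` to infinity: the co-assignment path only triggers when `math.isinf(self.WORKER_SATURATION)`, i.e. when scheduler-side queuing is disabled. A sketch of opting into that mode outside the test suite, assuming the standard `dask.config` API and that the setting is read at scheduler startup:

    import dask
    from distributed import Client, LocalCluster

    # Hypothetical usage: an infinite saturation factor disables scheduler-side
    # queuing, so root-like tasks go through the co-assignment branch above.
    with dask.config.set({"distributed.scheduler.worker-saturation": float("inf")}):
        cluster = LocalCluster(n_workers=2, threads_per_worker=2)
        client = Client(cluster)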