diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 51606398cc9..c9e4aadd636 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1805,10 +1805,23 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None: if not (ws and tg.last_worker_tasks_left and ws.address in self.workers): # Last-used worker is full or unknown; pick a new worker for the next few tasks + + # We just pick the worker with the shortest queue (or if queuing is disabled, + # the fewest processing tasks). We've already decided dependencies are unimportant, + # so we don't care to schedule near them. + backlog = operator.attrgetter( + "processing" if math.isinf(self.WORKER_OVERSATURATION) else "queued" + ) ws = min( - (self.idle or self.workers).values(), - key=partial(self.worker_objective, ts), + self.workers.values(), key=lambda ws: len(backlog(ws)) / ws.nthreads ) + if self.validate: + assert ws is not tg.last_worker, ( + f"Colocation reused worker {ws} for {tg}, " + f"idle: {list(self.idle.values())}, " + f"workers: {list(self.workers.values())}" + ) + tg.last_worker_tasks_left = math.floor( (len(tg) / self.total_nthreads) * ws.nthreads ) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4dad0b256c8..09af6984076 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -320,6 +320,35 @@ async def _test_oversaturation_factor(c, s, a, b): _test_oversaturation_factor() +@pytest.mark.parametrize( + "saturation_factor", + [ + 0.0, + 1.0, + pytest.param( + float("inf"), + marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"), + ), + ], +) +@gen_cluster( + client=True, + nthreads=[("", 2), ("", 1)], +) +async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor): + s.WORKER_OVERSATURATION = saturation_factor + xs = [delayed(i, name=f"x-{i}") for i in range(9)] + ys = [delayed(i, name=f"y-{i}") for i in range(9)] + zs = [x + y for x, y in zip(xs, ys)] + + await c.gather(c.compute(zs)) + + assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log] + assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log] + assert len(a.tasks) == 18 + assert len(b.tasks) == 9 + + @gen_cluster( client=True, nthreads=[("", 2)] * 2,