From 5054c19cab631b7d8d0f36ba221c3e0996cd4816 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 21 Jan 2022 12:38:38 +0000
Subject: [PATCH] Paused workers shouldn't steal tasks (#5665)

---
 distributed/stealing.py             | 11 ++++----
 distributed/tests/test_scheduler.py |  7 +++--
 distributed/tests/test_steal.py     | 40 +++++++++++++++++++++--------
 3 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/distributed/stealing.py b/distributed/stealing.py
index 101a228ce04..337ff24f756 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -15,7 +15,7 @@
 from dask.utils import parse_timedelta
 
 from .comm.addressing import get_address_host
-from .core import CommClosedError
+from .core import CommClosedError, Status
 from .diagnostics.plugin import SchedulerPlugin
 from .utils import log_errors, recursive_to_dict
 
@@ -393,22 +393,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
 
         with log_errors():
             i = 0
-            idle = s.idle.values()
-            saturated = s.saturated
+            # Paused and closing workers must never become thieves
+            idle = [ws for ws in s.idle.values() if ws.status == Status.running]
             if not idle or len(idle) == len(s.workers):
                 return
 
             log = []
             start = time()
 
-            if not s.saturated:
+            saturated = s.saturated
+            if not saturated:
                 saturated = topk(10, s.workers.values(), key=combined_occupancy)
                 saturated = [
                     ws
                     for ws in saturated
                     if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
                 ]
-            elif len(s.saturated) < 20:
+            elif len(saturated) < 20:
                 saturated = sorted(saturated, key=combined_occupancy, reverse=True)
             if len(idle) < 20:
                 idle = sorted(idle, key=combined_occupancy)
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 0f958120483..e2f32ce5f2b 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -3239,8 +3239,11 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3):
     while s.workers[w2.address].status != Status.paused:
         await asyncio.sleep(0.01)
     futures = c.map(slowinc, range(8), delay=0.1)
-    while (len(w1.tasks), len(w2.tasks), len(w3.tasks)) != (4, 0, 4):
-        await asyncio.sleep(0.01)
+    await wait(futures)
+    assert w1.data
+    assert not w2.data
+    assert w3.data
+    assert len(w1.data) + len(w3.data) == 8
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 72bf63c1afb..2205ba154dc 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -16,6 +16,7 @@
 from distributed import Lock, Nanny, Worker, wait, worker_client
 from distributed.compatibility import LINUX, WINDOWS
 from distributed.config import config
+from distributed.core import Status
 from distributed.metrics import time
 from distributed.scheduler import key_split
 from distributed.system import MEMORY_LIMIT
@@ -816,28 +817,47 @@ async def test_steal_twice(c, s, a, b):
 
     while len(s.tasks) < 100:  # tasks are all allocated
         await asyncio.sleep(0.01)
+    # Wait for b to start stealing tasks
+    while len(b.tasks) < 30:
+        await asyncio.sleep(0.01)
 
     # Army of new workers arrives to help
-    workers = await asyncio.gather(*(Worker(s.address, loop=s.loop) for _ in range(20)))
+    workers = await asyncio.gather(*(Worker(s.address) for _ in range(20)))
 
     await wait(futures)
 
-    has_what = dict(s.has_what)  # take snapshot
-    empty_workers = [w for w, keys in has_what.items() if not len(keys)]
-    if len(empty_workers) > 2:
-        pytest.fail(
-            "Too many workers without keys (%d out of %d)"
-            % (len(empty_workers), len(has_what))
-        )
-    assert max(map(len, has_what.values())) < 30
+    # Note: this includes a and b
+    empty_workers = [w for w, keys in s.has_what.items() if not keys]
+    assert (
+        len(empty_workers) < 3
+    ), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
+    # This also tests that some tasks were stolen from b
+    # (see `while len(b.tasks) < 30` above)
+    assert max(map(len, s.has_what.values())) < 30
 
     assert a.in_flight_tasks == 0
     assert b.in_flight_tasks == 0
 
-    await c._close()
     await asyncio.gather(*(w.close() for w in workers))
 
 
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_paused_workers_must_not_steal(c, s, w1, w2, w3):
+    w2.memory_pause_fraction = 1e-15
+    while s.workers[w2.address].status != Status.paused:
+        await asyncio.sleep(0.01)
+
+    x = c.submit(inc, 1, workers=w1.address)
+    await wait(x)
+
+    futures = [c.submit(slowadd, x, i, delay=0.1) for i in range(10)]
+    await wait(futures)
+
+    assert w1.data
+    assert not w2.data
+    assert w3.data
+
+
 @gen_cluster(client=True)
 async def test_dont_steal_already_released(c, s, a, b):
     future = c.submit(slowinc, 1, delay=0.05, workers=a.address)
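
Below is a minimal, standalone sketch (not part of the patch) of the thief-selection filter that this change introduces in WorkStealing.balance(). The Status, WS, and eligible_thieves names here are simplified stand-ins for illustration only, not the real distributed classes; the patched code itself filters s.idle.values() with ws.status == Status.running so that paused and closing workers are never chosen to steal tasks.

    from dataclasses import dataclass
    from enum import Enum, auto


    class Status(Enum):
        # Simplified stand-in for distributed.core.Status (only the members used here)
        running = auto()
        paused = auto()
        closing_gracefully = auto()


    @dataclass
    class WS:
        # Simplified stand-in for the scheduler's WorkerState; only what the filter reads
        address: str
        status: Status


    def eligible_thieves(idle_workers):
        """Keep only idle workers that may receive stolen tasks.

        Mirrors the line this patch adds to balance():
            idle = [ws for ws in s.idle.values() if ws.status == Status.running]
        """
        return [ws for ws in idle_workers if ws.status == Status.running]


    idle = [
        WS("tcp://10.0.0.1:40001", Status.running),
        WS("tcp://10.0.0.2:40002", Status.paused),  # paused: must never become a thief
        WS("tcp://10.0.0.3:40003", Status.closing_gracefully),
    ]
    assert [ws.address for ws in eligible_thieves(idle)] == ["tcp://10.0.0.1:40001"]

The new tests in this patch check the same property end to end: a worker paused via memory_pause_fraction ends up holding no data, while the running workers receive all of the results.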