Skip to content

Commit

Permalink
Paused workers shouldn't steal tasks (dask#5665)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 21, 2022
1 parent eadb35f commit 5054c19
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
11 changes: 6 additions & 5 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dask.utils import parse_timedelta

from .comm.addressing import get_address_host
from .core import CommClosedError
from .core import CommClosedError, Status
from .diagnostics.plugin import SchedulerPlugin
from .utils import log_errors, recursive_to_dict

Expand Down Expand Up @@ -393,22 +393,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):

with log_errors():
i = 0
idle = s.idle.values()
saturated = s.saturated
# Paused and closing workers must never become thieves
idle = [ws for ws in s.idle.values() if ws.status == Status.running]
if not idle or len(idle) == len(s.workers):
return

log = []
start = time()

if not s.saturated:
saturated = s.saturated
if not saturated:
saturated = topk(10, s.workers.values(), key=combined_occupancy)
saturated = [
ws
for ws in saturated
if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
]
elif len(s.saturated) < 20:
elif len(saturated) < 20:
saturated = sorted(saturated, key=combined_occupancy, reverse=True)
if len(idle) < 20:
idle = sorted(idle, key=combined_occupancy)
Expand Down
7 changes: 5 additions & 2 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3239,8 +3239,11 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3):
while s.workers[w2.address].status != Status.paused:
await asyncio.sleep(0.01)
futures = c.map(slowinc, range(8), delay=0.1)
while (len(w1.tasks), len(w2.tasks), len(w3.tasks)) != (4, 0, 4):
await asyncio.sleep(0.01)
await wait(futures)
assert w1.data
assert not w2.data
assert w3.data
assert len(w1.data) + len(w3.data) == 8


@gen_cluster(client=True, nthreads=[("", 1)])
Expand Down
40 changes: 30 additions & 10 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from distributed import Lock, Nanny, Worker, wait, worker_client
from distributed.compatibility import LINUX, WINDOWS
from distributed.config import config
from distributed.core import Status
from distributed.metrics import time
from distributed.scheduler import key_split
from distributed.system import MEMORY_LIMIT
Expand Down Expand Up @@ -816,28 +817,47 @@ async def test_steal_twice(c, s, a, b):

while len(s.tasks) < 100: # tasks are all allocated
await asyncio.sleep(0.01)
# Wait for b to start stealing tasks
while len(b.tasks) < 30:
await asyncio.sleep(0.01)

# Army of new workers arrives to help
workers = await asyncio.gather(*(Worker(s.address, loop=s.loop) for _ in range(20)))
workers = await asyncio.gather(*(Worker(s.address) for _ in range(20)))

await wait(futures)

has_what = dict(s.has_what) # take snapshot
empty_workers = [w for w, keys in has_what.items() if not len(keys)]
if len(empty_workers) > 2:
pytest.fail(
"Too many workers without keys (%d out of %d)"
% (len(empty_workers), len(has_what))
)
assert max(map(len, has_what.values())) < 30
# Note: this includes a and b
empty_workers = [w for w, keys in s.has_what.items() if not keys]
assert (
len(empty_workers) < 3
), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
# This also tests that some tasks were stolen from b
# (see `while len(b.tasks) < 30` above)
assert max(map(len, s.has_what.values())) < 30

assert a.in_flight_tasks == 0
assert b.in_flight_tasks == 0

await c._close()
await asyncio.gather(*(w.close() for w in workers))


@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_paused_workers_must_not_steal(c, s, w1, w2, w3):
    """A paused worker must never be selected as a thief by work stealing."""
    # Force w2 into the paused state by making any memory usage exceed the
    # pause threshold, then block until the scheduler observes the transition.
    w2.memory_pause_fraction = 1e-15
    while s.workers[w2.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # Pin a seed task to w1 so all dependent tasks are initially scheduled
    # there, creating the imbalance that triggers stealing.
    seed = c.submit(inc, 1, workers=w1.address)
    await wait(seed)

    dependents = [c.submit(slowadd, seed, i, delay=0.1) for i in range(10)]
    await wait(dependents)

    # Stealing should rebalance work onto w3, but the paused w2 must stay idle.
    assert w1.data
    assert not w2.data
    assert w3.data


@gen_cluster(client=True)
async def test_dont_steal_already_released(c, s, a, b):
future = c.submit(slowinc, 1, delay=0.05, workers=a.address)
Expand Down

0 comments on commit 5054c19

Please sign in to comment.