diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 859699f9b8e..0d40f92fad0 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -2813,7 +2813,7 @@ def transition_processing_memory(
                 if tts._processing_on:
                     self.set_duration_estimate(tts, tts._processing_on)
             if steal:
-                steal.put_key_in_stealable(tts)
+                steal.recalculate_cost(tts)
 
         ############################
         # Update State Information #
diff --git a/distributed/stealing.py b/distributed/stealing.py
index e5fa43d72ee..7911403bfbf 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from collections import defaultdict, deque
 from math import log2
@@ -63,14 +64,11 @@ def __init__(self, scheduler):
         for worker in scheduler.workers:
             self.add_worker(worker=worker)
 
-        callback_time = parse_timedelta(
+        self._callback_time = parse_timedelta(
             dask.config.get("distributed.scheduler.work-stealing-interval"),
             default="ms",
         )
         # `callback_time` is in milliseconds
-        pc = PeriodicCallback(callback=self.balance, callback_time=callback_time * 1000)
-        self._pc = pc
-        self.scheduler.periodic_callbacks["stealing"] = pc
         self.scheduler.add_plugin(self)
         self.scheduler.extensions["stealing"] = self
         self.scheduler.events["stealing"] = deque(maxlen=100000)
@@ -79,9 +77,36 @@ def __init__(self, scheduler):
         self.in_flight = dict()
         # { worker state: occupancy }
         self.in_flight_occupancy = defaultdict(lambda: 0)
+        self._in_flight_event = asyncio.Event()
 
         self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
 
+    async def start(self, scheduler=None):
+        """Start the background coroutine to balance the tasks on the cluster.
+        Idempotent.
+        The scheduler argument is ignored. It is merely required to satisfy the
+        plugin interface. Since this class is simultaneously an extension, the
+        scheduler instance is already registered during initialization.
+        """
+        if "stealing" in self.scheduler.periodic_callbacks:
+            return
+        pc = PeriodicCallback(
+            callback=self.balance, callback_time=self._callback_time * 1000
+        )
+        pc.start()
+        self.scheduler.periodic_callbacks["stealing"] = pc
+        self._in_flight_event.set()
+
+    async def stop(self):
+        """Stop the background task balancing tasks on the cluster.
+        This will block until all currently running stealing requests are
+        finished. Idempotent.
+        """
+        pc = self.scheduler.periodic_callbacks.pop("stealing", None)
+        if pc:
+            pc.stop()
+        await self._in_flight_event.wait()
+
     def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
         """
         A very verbose dictionary representation for debugging purposes.
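The new `start`/`stop` pair is a self-contained lifecycle pattern: `start()` owns the `PeriodicCallback`, and an `asyncio.Event` stays cleared while steal requests are in flight so that `stop()` can wait for them to drain. Below is a minimal standalone sketch of that pattern, assuming tornado's `PeriodicCallback`; the `Balancer` class and its attribute names are hypothetical, not part of this diff:

```python
import asyncio

from tornado.ioloop import PeriodicCallback


class Balancer:
    """Hypothetical sketch of the start/stop lifecycle introduced above."""

    def __init__(self, interval_ms: float):
        self._interval_ms = interval_ms
        self._pc = None
        self.in_flight = {}  # request -> metadata; emptied on confirmation
        self._in_flight_event = asyncio.Event()
        self._in_flight_event.set()  # nothing in flight initially

    async def start(self):
        if self._pc is not None:  # idempotent: already running
            return
        self._pc = PeriodicCallback(self.balance, self._interval_ms)
        self._pc.start()

    async def stop(self):
        if self._pc is not None:
            self._pc.stop()
            self._pc = None
        # Block until the last outstanding request has been confirmed
        await self._in_flight_event.wait()

    def balance(self):
        # Issuing a request would clear the event; confirming the last
        # outstanding one sets it again (mirroring move_task_confirm).
        pass
```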
@@ -118,7 +143,10 @@ def remove_worker(self, scheduler=None, worker=None):
         del self.stealable[worker]
 
     def teardown(self):
-        self._pc.stop()
+        pcs = self.scheduler.periodic_callbacks
+        if "stealing" in pcs:
+            pcs["stealing"].stop()
+            del pcs["stealing"]
 
     def transition(
         self, key, start, finish, compute_start=None, compute_stop=None, *args, **kwargs
@@ -137,6 +165,7 @@ def transition(
             self.in_flight_occupancy[victim] += d["victim_duration"]
             if not self.in_flight:
                 self.in_flight_occupancy.clear()
+                self._in_flight_event.set()
 
     def recalculate_cost(self, ts):
         if ts not in self.in_flight:
@@ -177,7 +206,7 @@ def steal_time_ratio(self, ts):
         level: The location within a stealable list to place this value
         """
         split = ts.prefix.name
-        if split in fast_tasks or split in self.scheduler.unknown_durations:
+        if split in fast_tasks:
             return None, None
 
         if not ts.dependencies:  # no dependencies fast path
@@ -233,6 +262,7 @@ def move_task_request(self, ts, victim, thief) -> str:
                 "thief_duration": thief_duration,
                 "stimulus_id": stimulus_id,
             }
+            self._in_flight_event.clear()
 
             self.in_flight_occupancy[victim] -= victim_duration
             self.in_flight_occupancy[thief] += thief_duration
@@ -274,6 +304,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
 
             if not self.in_flight:
                 self.in_flight_occupancy.clear()
+                self._in_flight_event.set()
 
             if self.scheduler.validate:
                 assert ts.processing_on == victim
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 4209b801446..88cf3c96b1c 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -4306,7 +4306,7 @@ async def test_retire_many_workers(c, s, *workers):
     config={"distributed.scheduler.default-task-durations": {"f": "10ms"}},
 )
 async def test_weight_occupancy_against_data_movement(c, s, a, b):
-    s.extensions["stealing"]._pc.callback_time = 1000000
+    await s.extensions["stealing"].stop()
 
     def f(x, y=0, z=0):
         sleep(0.01)
@@ -4329,7 +4329,7 @@ def f(x, y=0, z=0):
     config={"distributed.scheduler.default-task-durations": {"f": "10ms"}},
 )
 async def test_distribute_tasks_by_nthreads(c, s, a, b):
-    s.extensions["stealing"]._pc.callback_time = 1000000
+    await s.extensions["stealing"].stop()
 
     def f(x, y=0):
         sleep(0.01)
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 288e009f078..4442b5ca2ea 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1020,7 +1020,7 @@ async def test_balance_many_workers(c, s, *workers):
 @nodebug
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30)
 async def test_balance_many_workers_2(c, s, *workers):
-    s.extensions["stealing"]._pc.callback_time = 100000000
+    await s.extensions["stealing"].stop()
     futures = c.map(slowinc, range(90), delay=0.2)
     await wait(futures)
     assert {len(w.has_what) for w in s.workers.values()} == {3}
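All of the test changes above follow a single migration: instead of reaching into the private `_pc` and inflating its interval to effectively disable stealing, tests now go through the public lifecycle hook. A sketch of the new idiom; the test name and body are illustrative only:

```python
from distributed.utils_test import gen_cluster


@gen_cluster(client=True)
async def test_without_stealing(c, s, a, b):
    # stop() both cancels the periodic callback and waits for in-flight
    # steal requests to finish, so task placement below is deterministic.
    await s.extensions["stealing"].stop()
    ...
```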
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 70c935cb47d..290188ab298 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import itertools
 import logging
 import random
@@ -110,11 +111,123 @@ async def test_worksteal_many_thieves(c, s, *workers):
     assert sum(map(len, s.has_what.values())) < 150
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_dont_steal_unknown_functions(c, s, a, b):
-    futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True)
-    await wait(futures)
-    assert len(a.data) >= 95, [len(a.data), len(b.data)]
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 2,
+    config={"distributed.scheduler.work-stealing-interval": "10ms"},
+)
+async def test_stop_plugin(c, s, a, b):
+    steal = s.extensions["stealing"]
+
+    await steal.stop()
+    futs = c.map(slowinc, range(10), workers=[a.address], allow_other_workers=True)
+    await c.gather(futs)
+    assert len(a.data) == 10
+
+    # Nothing happens; stop() is idempotent
+    for _ in range(10):
+        await steal.stop()
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 2,
+    config={"distributed.scheduler.work-stealing-interval": "1ms"},
+)
+async def test_stop_in_flight(c, s, a, b):
+    steal = s.extensions["stealing"]
+    num_tasks = 10
+    futs = c.map(
+        slowinc, range(num_tasks), workers=[a.address], allow_other_workers=True
+    )
+    while not steal.in_flight:
+        await asyncio.sleep(0)
+
+    assert steal.in_flight
+    await steal.stop()
+    assert not steal.in_flight
+
+    assert len(a.data) != num_tasks
+    del futs
+    while s.tasks or a.tasks or b.tasks:
+        await asyncio.sleep(0.1)
+
+    futs = c.map(
+        slowinc, range(num_tasks), workers=[a.address], allow_other_workers=True
+    )
+    await c.gather(futs)
+    assert len(a.data) == num_tasks
+
+    del futs
+    while s.tasks or a.tasks or b.tasks:
+        await asyncio.sleep(0.1)
+    futs = c.map(
+        slowinc, range(num_tasks), workers=[a.address], allow_other_workers=True
+    )
+    while not len(a.tasks) == num_tasks:
+        await asyncio.sleep(0.01)
+    assert len(b.tasks) == 0
+    await steal.start()
+    await c.gather(futs)
+    assert len(a.tasks) != num_tasks
+    assert len(b.tasks) != 0
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 2,
+    config={"distributed.scheduler.work-stealing-interval": "10ms"},
+)
+async def test_allow_tasks_stolen_before_first_completes(c, s, a, b):
+    # https://github.com/dask/distributed/issues/5564
+    from distributed import Semaphore
+
+    steal = s.extensions["stealing"]
+    await steal.stop()
+    lock = await Semaphore(max_leases=1)
+
+    # We reuse the same function so that multiple dispatches share the
+    # same task prefix. This ensures that we have tasks queued up but all
+    # of them are still classified as unknown.
+    # The lock allows us to control the duration of the first task without
+    # inflating the test runtime or introducing flakiness.
+    def blocked_task(x, lock):
+        if x == 0:
+            with lock:
+                return x
+        return x
+
+    async with lock:
+        first = c.submit(blocked_task, 0, lock, workers=[a.address], key="f-0")
+        while first.key not in a.tasks:
+            await asyncio.sleep(0.001)
+        # Ensure the task is indeed blocked
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(first, 0.01)
+
+        more_tasks = c.map(
+            blocked_task,
+            # Zero is a sentinel for taking the lock;
+            # start counting at one for the non-blocking tasks
+            range(1, 11),
+            lock=lock,
+            workers=[a.address],
+            key=[f"f-{ix}" for ix in range(1, 11)],
+            allow_other_workers=True,
+        )
+        # All tasks are put on A since this is what we asked for. Only work
+        # stealing should rebalance the tasks once we allow for it.
+        while not len(a.tasks) == 11:
+            await asyncio.sleep(0.1)
+
+        assert len(b.tasks) == 0
+
+        await steal.start()
+        # A is still blocked executing task f-0, so this can only pass if
+        # work stealing moves the other tasks to B.
+        await c.gather(more_tasks)
+        assert len(b.data) == 10
+    await first
 
 
 @gen_cluster(
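The new tests busy-wait on scheduler and worker state with open-ended loops such as `while not steal.in_flight: await asyncio.sleep(0)`. If that idiom gets reused, a bounded variant avoids hangs on failure; `wait_until` below is a hypothetical helper, not an existing distributed utility:

```python
import asyncio
import time


async def wait_until(predicate, timeout=5.0, period=0.01):
    """Poll `predicate` until it returns true or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while not predicate():
        if time.monotonic() > deadline:
            raise asyncio.TimeoutError("condition was never met")
        await asyncio.sleep(period)
```

For example, `await wait_until(lambda: steal.in_flight)` would replace the bare loop while failing loudly after five seconds.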
@@ -256,7 +369,7 @@ async def test_dont_steal_worker_restrictions(c, s, a, b):
     assert len(a.tasks) == 100
     assert len(b.tasks) == 0
 
-    result = s.extensions["stealing"].balance()
+    s.extensions["stealing"].balance()
 
     await asyncio.sleep(0.1)
 
@@ -376,10 +489,12 @@ async def test_steal_resource_restrictions(c, s, a):
     await b.close()
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5)
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 5,
+    config={"distributed.scheduler.work-stealing-interval": "20ms"},
+)
 async def test_balance_without_dependencies(c, s, *workers):
-    s.extensions["stealing"]._pc.callback_time = 20
-
     def slow(x):
         y = random.random() * 0.1
         sleep(y)
@@ -422,10 +537,12 @@ async def test_dont_steal_executing_tasks_2(c, s, a, b):
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1)] * 10,
-    config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}},
+    config={
+        "distributed.scheduler.default-task-durations": {"slowidentity": 0.2},
+        "distributed.scheduler.work-stealing-interval": "20ms",
+    },
 )
 async def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest):
-    s.extensions["stealing"]._pc.callback_time = 20
     x = c.submit(mul, b"0", 100000000, workers=a.address)  # 100 MB
     await wait(x)
 
@@ -441,10 +558,12 @@ async def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest):
     client=True,
     nthreads=[("127.0.0.1", 1)] * 10,
     worker_kwargs={"memory_limit": MEMORY_LIMIT},
-    config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}},
+    config={
+        "distributed.scheduler.default-task-durations": {"slowidentity": 0.2},
+        "distributed.scheduler.work-stealing-interval": "20ms",
+    },
 )
 async def test_steal_when_more_tasks(c, s, a, *rest):
-    s.extensions["stealing"]._pc.callback_time = 20
     x = c.submit(mul, b"0", 50000000, workers=a.address)  # 50 MB
     await wait(x)
 
@@ -463,7 +582,8 @@ async def test_steal_when_more_tasks(c, s, a, *rest):
         "distributed.scheduler.default-task-durations": {
             "slowidentity": 0.2,
             "slow2": 1,
-        }
+        },
+        "distributed.scheduler.work-stealing-interval": "20ms",
     },
 )
 async def test_steal_more_attractive_tasks(c, s, a, *rest):
@@ -471,7 +591,6 @@ def slow2(x):
         sleep(1)
         return x
 
-    s.extensions["stealing"]._pc.callback_time = 20
     x = c.submit(mul, b"0", 100000000, workers=a.address)  # 100 MB
     await wait(x)
 
@@ -491,7 +610,7 @@ def func(x):
 
 async def assert_balanced(inp, expected, c, s, *workers):
     steal = s.extensions["stealing"]
-    steal._pc.stop()
+    await steal.stop()
 
     counter = itertools.count()
     tasks = list(concat(inp))
@@ -750,12 +869,15 @@ def long(delay):
         )
         <= 1
     )
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2)
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 5)] * 2,
+    config={"distributed.scheduler.work-stealing-interval": "20ms"},
+)
 async def test_cleanup_repeated_tasks(c, s, a, b):
     class Foo:
         pass
 
-    s.extensions["stealing"]._pc.callback_time = 20
     await c.submit(slowidentity, -1, delay=0.1)
     objects = [c.submit(Foo, pure=False, workers=a.address) for _ in range(50)]
@@ -802,21 +924,19 @@ async def test_lose_task(c, s, a, b):
     assert "Error" not in out
 
 
-@gen_cluster(client=True)
-async def test_worker_stealing_interval(c, s, a, b):
+@pytest.mark.parametrize("interval, expected", [(None, 100), ("500ms", 500), (2, 2)])
+@gen_cluster(nthreads=[])
+async def test_parse_stealing_interval(s, interval, expected):
     from distributed.scheduler import WorkStealing
 
-    ws = WorkStealing(s)
-    assert ws._pc.callback_time == 100
-
-    with dask.config.set({"distributed.scheduler.work-stealing-interval": "500ms"}):
-        ws = WorkStealing(s)
-        assert ws._pc.callback_time == 500
-
-    # Default unit is `ms`
-    with dask.config.set({"distributed.scheduler.work-stealing-interval": 2}):
+    if interval:
+        ctx = dask.config.set({"distributed.scheduler.work-stealing-interval": interval})
+    else:
+        ctx = contextlib.nullcontext()
+    with ctx:
         ws = WorkStealing(s)
-        assert ws._pc.callback_time == 2
+        await ws.start()
+        assert s.periodic_callbacks["stealing"].callback_time == expected
 
 
 @gen_cluster(client=True)
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 1a5a13dd419..7fe820a4da1 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1209,7 +1209,7 @@ def some_name():
 
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
 async def test_reschedule(c, s, a, b):
-    s.extensions["stealing"]._pc.stop()
+    await s.extensions["stealing"].stop()
     a_address = a.address
 
     def f(x):
@@ -2601,7 +2601,7 @@ async def test_forget_dependents_after_release(c, s, a):
 @gen_cluster(client=True)
 async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch):
     stealing_ext = s.extensions["stealing"]
-    stealing_ext._pc.stop()
+    await stealing_ext.stop()
     from distributed.utils import ThreadPoolExecutor
 
     class CountingThreadPool(ThreadPoolExecutor):
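The parametrized expectations line up with how `dask.utils.parse_timedelta` treats its `default` unit: it returns seconds, and the extension multiplies by 1000 because `PeriodicCallback` takes milliseconds. A quick standalone sanity check (not part of the diff):

```python
from dask.utils import parse_timedelta

# Strings carry their own unit; bare numbers fall back to `default`.
assert parse_timedelta("500ms", default="ms") == 0.5   # * 1000 -> 500
assert parse_timedelta(2, default="ms") == 0.002       # * 1000 -> 2
assert parse_timedelta("100ms", default="ms") == 0.1   # config default -> 100
```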