diff --git a/distributed/stealing.py b/distributed/stealing.py
index f1539c886e6..32a6274ed20 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -8,21 +8,19 @@
 from time import time
 from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
 
-import sortedcontainers
 from tlz import topk
 from tornado.ioloop import PeriodicCallback
 
 import dask
 from dask.utils import parse_timedelta
 
-from distributed.comm.addressing import get_address_host
-from distributed.core import CommClosedError, Status
+from distributed.core import CommClosedError
 from distributed.diagnostics.plugin import SchedulerPlugin
 from distributed.utils import log_errors, recursive_to_dict
 
 if TYPE_CHECKING:
     # Recursive imports
-    from distributed.scheduler import Scheduler, TaskState, WorkerState
+    from distributed.scheduler import Scheduler, SchedulerState, TaskState, WorkerState
 
 # Stealing requires multiple network bounces and if successful also task
 # submission which may include code serialization. Therefore, be very
@@ -64,8 +62,6 @@ class InFlightInfo(TypedDict):
 
 class WorkStealing(SchedulerPlugin):
     scheduler: Scheduler
-    # ({ task states for level 0}, ..., {task states for level 14})
-    stealable_all: tuple[set[TaskState], ...]
     # {worker: ({ task states for level 0}, ..., {task states for level 14})}
     stealable: dict[str, tuple[set[TaskState], ...]]
     # { task state: (worker, level) }
@@ -80,12 +76,12 @@ class WorkStealing(SchedulerPlugin):
     in_flight: dict[TaskState, InFlightInfo]
     # { worker state: occupancy }
     in_flight_occupancy: defaultdict[WorkerState, float]
+    in_flight_tasks: defaultdict[WorkerState, int]
     _in_flight_event: asyncio.Event
     _request_counter: int
 
     def __init__(self, scheduler: Scheduler):
         self.scheduler = scheduler
-        self.stealable_all = tuple(set() for _ in range(15))
         self.stealable = {}
         self.key_stealable = {}
 
@@ -105,6 +101,7 @@ def __init__(self, scheduler: Scheduler):
         self.count = 0
         self.in_flight = {}
         self.in_flight_occupancy = defaultdict(lambda: 0)
+        self.in_flight_tasks = defaultdict(lambda: 0)
         self._in_flight_event = asyncio.Event()
         self._request_counter = 0
         self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
@@ -183,6 +180,8 @@ def transition(
                 victim = d["victim"]
                 self.in_flight_occupancy[thief] -= d["thief_duration"]
                 self.in_flight_occupancy[victim] += d["victim_duration"]
+                self.in_flight_tasks[victim] += 1
+                self.in_flight_tasks[thief] -= 1
                 if not self.in_flight:
                     self.in_flight_occupancy.clear()
                     self._in_flight_event.set()
@@ -199,7 +198,6 @@ def put_key_in_stealable(self, ts: TaskState) -> None:
         assert ts.processing_on
         ws = ts.processing_on
         worker = ws.address
-        self.stealable_all[level].add(ts)
         self.stealable[worker][level].add(ts)
         self.key_stealable[ts] = (worker, level)
 
@@ -213,10 +211,6 @@ def remove_key_from_stealable(self, ts: TaskState) -> None:
             self.stealable[worker][level].remove(ts)
         except KeyError:
             pass
-        try:
-            self.stealable_all[level].remove(ts)
-        except KeyError:
-            pass
 
     def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
         """The compute to communication time ratio of a key
@@ -234,14 +228,13 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
         if not ts.dependencies:  # no dependencies fast path
             return 0, 0
 
-        assert ts.processing_on
-        ws = ts.processing_on
-        compute_time = ws.processing[ts]
+        compute_time = self.scheduler.get_task_duration(ts)
         if not compute_time:
             # occupancy/ws.proccessing[ts] is only allowed to be zero for
             # long running tasks which cannot be stolen
-            assert ts in ws.long_running
+            assert ts.processing_on
+            assert ts in ts.processing_on.long_running
             return None, None
 
         nbytes = ts.get_nbytes_deps()
@@ -298,6 +291,8 @@ def move_task_request(
 
             self.in_flight_occupancy[victim] -= victim_duration
             self.in_flight_occupancy[thief] += thief_duration
+            self.in_flight_tasks[victim] -= 1
+            self.in_flight_tasks[thief] += 1
             return stimulus_id
         except CommClosedError:
             logger.info("Worker comm %r closed while stealing: %r", victim, ts)
@@ -403,64 +398,55 @@ def balance(self) -> None:
         def combined_occupancy(ws: WorkerState) -> float:
             return ws.occupancy + self.in_flight_occupancy[ws]
 
-        def maybe_move_task(
-            level: int,
-            ts: TaskState,
-            victim: WorkerState,
-            thief: WorkerState,
-            duration: float,
-            cost_multiplier: float,
-        ) -> None:
-            occ_thief = combined_occupancy(thief)
-            occ_victim = combined_occupancy(victim)
-
-            if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2:
-                self.move_task_request(ts, victim, thief)
-                log.append(
-                    (
-                        start,
-                        level,
-                        ts.key,
-                        duration,
-                        victim.address,
-                        occ_victim,
-                        thief.address,
-                        occ_thief,
-                    )
-                )
-                s.check_idle_saturated(victim, occ=occ_victim)
-                s.check_idle_saturated(thief, occ=occ_thief)
+        def combined_nprocessing(ws: WorkerState) -> int:
+            return len(ws.processing) + self.in_flight_tasks[ws]
 
         with log_errors():
             i = 0
             # Paused and closing workers must never become thieves
-            idle = [ws for ws in s.idle.values() if ws.status == Status.running]
-            if not idle or len(idle) == len(s.workers):
+            potential_thieves = set(s.idle.values())
+            if not potential_thieves or len(potential_thieves) == len(s.workers):
                 return
 
-            victim: WorkerState | None
-            saturated: set[WorkerState] | list[WorkerState] = s.saturated
-            if not saturated:
-                saturated = topk(10, s.workers.values(), key=combined_occupancy)
-                saturated = [
+            potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
+            if not potential_victims:
+                potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
+                potential_victims = [
                     ws
-                    for ws in saturated
-                    if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
+                    for ws in potential_victims
+                    if combined_occupancy(ws) > 0.2
+                    and combined_nprocessing(ws) > ws.nthreads
+                    and ws not in potential_thieves
                 ]
-            elif len(saturated) < 20:
-                saturated = sorted(saturated, key=combined_occupancy, reverse=True)
-            if len(idle) < 20:
-                idle = sorted(idle, key=combined_occupancy)
-
-            for level, cost_multiplier in enumerate(self.cost_multipliers):
-                if not idle:
+            if not potential_victims:
+                # TODO: Unclear how to reach this and what the implications
+                # are. The return is only an optimization since the for-loop
+                # below would be a no-op, but we'd save ourselves a few loop
+                # cycles. Unless measurements of runtime, occupancy, etc.
+                # change, we'd not get out of this state and may end up with
+                # an unbalanced cluster.
+                return
+            if len(potential_victims) < 20:
+                potential_victims = sorted(
+                    potential_victims, key=combined_occupancy, reverse=True
+                )
+            assert potential_victims
+            assert potential_thieves
+            avg_occ_per_threads = (
+                self.scheduler.total_occupancy / self.scheduler.total_nthreads
+            )
+            for level, _ in enumerate(self.cost_multipliers):
+                if not potential_thieves:
                     break
-                for victim in list(saturated):
+                for victim in list(potential_victims):
+
                     stealable = self.stealable[victim.address][level]
-                    if not stealable or not idle:
+                    if not stealable or not potential_thieves:
                         continue
 
                     for ts in list(stealable):
+                        if not potential_thieves:
+                            break
                         if (
                             ts not in self.key_stealable
                             or ts.processing_on is not victim
@@ -468,51 +454,49 @@ def maybe_move_task(
                             stealable.discard(ts)
                             continue
                         i += 1
-                        if not idle:
-                            break
-
-                        thieves = _potential_thieves_for(ts, idle)
-                        if not thieves:
-                            break
-                        thief = thieves[i % len(thieves)]
-
-                        duration = victim.processing.get(ts)
-                        if duration is None:
-                            stealable.discard(ts)
+                        if not (thief := _get_thief(s, ts, potential_thieves)):
                             continue
-
-                        maybe_move_task(
-                            level, ts, victim, thief, duration, cost_multiplier
-                        )
-
-                if self.cost_multipliers[level] < 20:  # don't steal from public at cost
-                    stealable = self.stealable_all[level]
-                    for ts in list(stealable):
-                        if not idle:
-                            break
-                        if ts not in self.key_stealable:
+                        task_occ_on_victim = victim.processing.get(ts)
+                        if task_occ_on_victim is None:
                             stealable.discard(ts)
                             continue
-                        victim = ts.processing_on
-                        if victim is None:
-                            stealable.discard(ts)
-                            continue
-                        if combined_occupancy(victim) < 0.2:
-                            continue
-                        if len(victim.processing) <= victim.nthreads:
-                            continue
-
-                        i += 1
-                        thieves = _potential_thieves_for(ts, idle)
-                        if not thieves:
-                            continue
-                        thief = thieves[i % len(thieves)]
-                        duration = victim.processing[ts]
+                        occ_thief = combined_occupancy(thief)
+                        occ_victim = combined_occupancy(victim)
+                        comm_cost = self.scheduler.get_comm_cost(ts, thief)
+                        compute = self.scheduler.get_task_duration(ts)
 
-                        maybe_move_task(
-                            level, ts, victim, thief, duration, cost_multiplier
-                        )
+                        if (
+                            occ_thief + comm_cost + compute
+                            <= occ_victim - task_occ_on_victim / 2
+                        ):
+                            self.move_task_request(ts, victim, thief)
+                            log.append(
+                                (
+                                    start,
+                                    level,
+                                    ts.key,
+                                    task_occ_on_victim,
+                                    victim.address,
+                                    occ_victim,
+                                    thief.address,
+                                    occ_thief,
+                                )
+                            )
+
+                            occ_thief = combined_occupancy(thief)
+                            p = len(thief.processing) + self.in_flight_tasks[thief]
+
+                            nc = thief.nthreads
+                            # TODO: this is replicating some logic of
+                            # check_idle_saturated
+                            # pending: float = occ_thief * (p - nc) / (p * nc)
+                            if not (p < nc or occ_thief < nc * avg_occ_per_threads / 2):
+                                potential_thieves.discard(thief)
+                            stealable.discard(ts)
+                            self.scheduler.check_idle_saturated(
+                                victim, occ=combined_occupancy(victim)
+                            )
 
             if log:
                 self.log(log)
@@ -526,8 +510,6 @@ def restart(self, scheduler: Any) -> None:
             for s in stealable:
                 s.clear()
 
-        for s in self.stealable_all:
-            s.clear()
         self.key_stealable.clear()
 
     def story(self, *keys_or_ts: str | TaskState) -> list:
@@ -542,51 +524,17 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
         return out
 
 
-def _potential_thieves_for(
-    ts: TaskState,
-    idle: sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState],
-) -> sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState]:
-    """Return the list of workers from ``idle`` that could steal ``ts``."""
-    if _has_restrictions(ts):
-        return [ws for ws in idle if _can_steal(ws, ts)]
-    else:
-        return idle
-
-
-def _can_steal(thief: WorkerState, ts: TaskState) -> bool:
-    """Determine whether worker ``thief`` can steal task ``ts``.
-
-    Assumes that `ts` has some restrictions.
-    """
-    if (
-        ts.host_restrictions
-        and get_address_host(thief.address) not in ts.host_restrictions
-    ):
-        return False
-    elif ts.worker_restrictions and thief.address not in ts.worker_restrictions:
-        return False
-
-    if not ts.resource_restrictions:
-        return True
-
-    for resource, value in ts.resource_restrictions.items():
-        try:
-            supplied = thief.resources[resource]
-        except KeyError:
-            return False
-        else:
-            if supplied < value:
-                return False
-    return True
-
-
-def _has_restrictions(ts: TaskState) -> bool:
-    """Determine whether the given task has restrictions and whether these
-    restrictions are strict.
-    """
-    return not ts.loose_restrictions and bool(
-        ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
-    )
+def _get_thief(
+    scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
+) -> WorkerState | None:
+    valid_workers = scheduler.valid_workers(ts)
+    if valid_workers:
+        subset = potential_thieves & valid_workers
+        if subset:
+            return next(iter(subset))
+        elif not ts.loose_restrictions:
+            return None
+    return next(iter(potential_thieves))
 
 
 fast_tasks = {"split-shuffle"}
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 4a3751d6d91..af54324ef16 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -23,9 +23,11 @@
 from distributed.metrics import time
 from distributed.system import MEMORY_LIMIT
 from distributed.utils_test import (
+    SizeOf,
     captured_logger,
     freeze_batched_send,
     gen_cluster,
+    gen_nbytes,
     inc,
     nodebug_setup_module,
     nodebug_teardown_module,
@@ -169,13 +171,24 @@ async def test_stop_in_flight(c, s, a, b):
     del futs
     while s.tasks or a.state.tasks or b.state.tasks:
         await asyncio.sleep(0.1)
+    event = Event()
+
+    def block(x, event):
+        event.wait()
+        return x + 1
+
     futs = c.map(
-        slowinc, range(num_tasks), workers=[a.address], allow_other_workers=True
+        block,
+        range(num_tasks),
+        event=event,
+        workers=[a.address],
+        allow_other_workers=True,
     )
     while not len(a.state.tasks) == num_tasks:
         await asyncio.sleep(0.01)
     assert len(b.state.tasks) == 0
     await steal.start()
+    await event.set()
     await c.gather(futs)
     assert len(a.state.tasks) != num_tasks
     assert len(b.state.tasks) != 0
@@ -185,6 +198,7 @@
     client=True,
     nthreads=[("127.0.0.1", 1)] * 2,
     config={"distributed.scheduler.work-stealing-interval": "10ms"},
+    timeout=10,
 )
 async def test_allow_tasks_stolen_before_first_completes(c, s, a, b):
     # https://github.com/dask/distributed/issues/5564
@@ -233,6 +247,7 @@ def blocked_task(x, lock):
     await steal.start()
     # A is still blocked by executing task f-1 so this can only pass if
     # workstealing moves the tasks to B
+    await asyncio.sleep(5)
     await c.gather(more_tasks)
     assert len(b.data) == 10
     await first
@@ -634,19 +649,12 @@ def block(*args, event, **kwargs):
 
     counter = itertools.count()
 
-    class Sizeof:
-        def __init__(self, nbytes):
-            self._nbytes = nbytes - 16
-
-        def __sizeof__(self) -> int:
-            return self._nbytes
-
     futures = []
     for w, ts in zip(workers, inp):
         for t in sorted(ts, reverse=True):
             if t:
                 [dat] = await c.scatter(
-                    [Sizeof(int(t * s.bandwidth))], workers=w.address
+                    [SizeOf(int(t * s.bandwidth))], workers=w.address
                 )
             else:
                 dat = 123
@@ -721,7 +729,7 @@ def __sizeof__(self) -> int:
         # schedule a task on the threadpool
         (
             [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []],
-            [[4, 2, 2, 2, 2], [4, 2, 1], [1], [1], [1]],
+            [[4, 2, 2, 2], [4, 2, 1, 1], [2], [1], [1]],
         ),
     ],
 )
@@ -750,21 +758,55 @@ async def test_restart(c, s, a, b):
         await asyncio.sleep(0.01)
 
     steal = s.extensions["stealing"]
-    assert any(st for st in steal.stealable_all)
+    # assert any(st for st in steal.stealable_all)
     assert any(x for L in steal.stealable.values() for x in L)
 
     await c.restart()
 
-    assert not any(x for x in steal.stealable_all)
+    # assert not any(x for x in steal.stealable_all)
     assert not any(x for L in steal.stealable.values() for x in L)
 
 
+@gen_cluster(client=True)
+async def test_do_not_steal_communication_heavy_tasks(c, s, a, b):
+    # Never steal unreasonably large tasks
+    steal = s.extensions["stealing"]
+    x = c.submit(gen_nbytes, int(s.bandwidth) * 1000, workers=a.address, pure=False)
+    y = c.submit(gen_nbytes, int(s.bandwidth) * 1000, workers=a.address, pure=False)
+
+    def block_reduce(x, y, event):
+        event.wait()
+        return None
+
+    event = Event()
+    futures = [
+        c.submit(
+            block_reduce,
+            x,
+            y,
+            event=event,
+            pure=False,
+            workers=a.address,
+            allow_other_workers=True,
+        )
+        for i in range(10)
+    ]
+    while not a.state.tasks:
+        await asyncio.sleep(0.1)
+    steal.balance()
+    await steal.stop()
+    await event.set()
+    await c.gather(futures)
+    assert not b.data
+
+
 @gen_cluster(
     client=True,
     config={"distributed.scheduler.default-task-durations": {"slowadd": 0.001}},
 )
 async def test_steal_communication_heavy_tasks(c, s, a, b):
     steal = s.extensions["stealing"]
+    await steal.stop()
     x = c.submit(mul, b"0", int(s.bandwidth), workers=a.address)
     y = c.submit(mul, b"1", int(s.bandwidth), workers=b.address)
@@ -785,10 +827,8 @@ async def test_steal_communication_heavy_tasks(c, s, a, b):
         await asyncio.sleep(0.01)
 
     steal.balance()
-    while steal.in_flight:
-        await asyncio.sleep(0.001)
 
-    assert s.workers[b.address].processing
+    await steal.stop()
 
 
 @gen_cluster(client=True)
@@ -810,25 +850,26 @@ async def test_steal_twice(c, s, a, b):
         await asyncio.sleep(0.01)
 
     # Army of new workers arrives to help
-    workers = await asyncio.gather(*(Worker(s.address) for _ in range(20)))
+    async with contextlib.AsyncExitStack() as stack:
+        # This is pretty timing sensitive
+        workers = [stack.enter_async_context(Worker(s.address)) for _ in range(10)]
+        workers = await asyncio.gather(*workers)
 
-    await wait(futures)
+        await wait(futures)
 
-    # Note: this includes a and b
-    empty_workers = [ws for ws in s.workers.values() if not ws.has_what]
-    assert (
-        len(empty_workers) < 3
-    ), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
-    # This also tests that some tasks were stolen from b
-    # (see `while len(b.state.tasks) < 30` above)
-    # If queuing is enabled, then there was nothing to steal from b,
-    # so this just tests the queue was balanced not-terribly.
-    assert max(len(ws.has_what) for ws in s.workers.values()) < 30
-
-    assert a.state.in_flight_tasks_count == 0
-    assert b.state.in_flight_tasks_count == 0
-
-    await asyncio.gather(*(w.close() for w in workers))
+        # Note: this includes a and b
+        empty_workers = [ws for ws in s.workers.values() if not ws.has_what]
+        assert (
+            len(empty_workers) < 3
+        ), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
+        # This also tests that some tasks were stolen from b
+        # (see `while len(b.state.tasks) < 30` above)
+        # If queuing is enabled, then there was nothing to steal from b,
+        # so this just tests the queue was balanced not-terribly.
+        assert max(len(ws.has_what) for ws in s.workers.values()) < 30
+
+        assert a.state.in_flight_tasks_count == 0
+        assert b.state.in_flight_tasks_count == 0
 
 
 @gen_cluster(
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 2147982f750..7e3457cb829 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -19,6 +19,7 @@
 from tornado import gen
 
 import dask.config
+from dask.sizeof import sizeof
 
 from distributed import Client, Event, Nanny, Scheduler, Worker, config, default_client
 from distributed.batched import BatchedSend
@@ -29,6 +30,7 @@
 from distributed.tests.test_batched import EchoServer
 from distributed.utils import get_mp_context
 from distributed.utils_test import (
+    SizeOf,
     _LockedCommPool,
     _UnhashableCallable,
     assert_story,
@@ -39,6 +41,7 @@
     dump_cluster_state,
     freeze_batched_send,
     gen_cluster,
+    gen_nbytes,
     gen_test,
     inc,
     new_config,
@@ -1024,3 +1027,23 @@ def test_ws_with_running_task(ws_with_running_task):
     assert ws.available_resources == {"R": 0}
     assert ws.total_resources == {"R": 1}
     assert ts.state in ("executing", "long-running")
+
+
+def test_sizeof():
+    assert sizeof(SizeOf(100)) == 100
+    assert isinstance(gen_nbytes(100), SizeOf)
+    assert sizeof(gen_nbytes(100)) == 100
+
+
+@pytest.mark.parametrize(
+    "input, exc, msg",
+    [
+        (12345.0, TypeError, "Expected integer"),
+        (-1, ValueError, "larger than"),
+        (0, ValueError, "larger than"),
+        (10, ValueError, "larger than"),
+    ],
+)
+def test_sizeof_error(input, exc, msg):
+    with pytest.raises(exc, match=msg):
+        SizeOf(input)
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index e851bb0a9b8..295e0ae252e 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -37,6 +37,7 @@
 from tornado.ioloop import IOLoop
 
 import dask
+from dask.sizeof import sizeof
 
 from distributed import Scheduler, system
 from distributed import versions as version_module
@@ -2458,3 +2459,27 @@ async def fetch_metrics(port: int, prefix: str | None = None) -> dict[str, Any]:
         if prefix is None or family.name.startswith(prefix)
     }
     return families
+
+
+class SizeOf:
+    """
+    An object that returns exactly nbytes when inspected by dask.sizeof.sizeof
+    """
+
+    def __init__(self, nbytes: int) -> None:
+        if not isinstance(nbytes, int):
+            raise TypeError(f"Expected integer for nbytes but got {type(nbytes)}")
+        size_obj = sizeof(object())
+        if nbytes < size_obj:
+            raise ValueError(
+                f"Expected a value larger than {size_obj} but got {nbytes}."
+            )
+        self._nbytes = nbytes - size_obj
+
+    def __sizeof__(self) -> int:
+        return self._nbytes
+
+
+def gen_nbytes(nbytes: int) -> SizeOf:
+    """A function that emulates exactly nbytes on the worker data structure."""
+    return SizeOf(nbytes)
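The stealing decision in the rewritten balance() above boils down to a single inequality: a thief only takes a task if its combined occupancy plus the communication and compute cost of the move stays below the victim's combined occupancy minus half of the task's current occupancy on the victim. A minimal standalone sketch, using hypothetical numbers in place of real scheduler measurements:

    # Illustrative sketch of the balance() stealing criterion; all values are
    # made-up stand-ins for the quantities the scheduler would provide.
    occ_thief = 1.0           # combined occupancy of the idle worker, in seconds
    occ_victim = 10.0         # combined occupancy of the saturated worker, in seconds
    comm_cost = 0.5           # estimated time to transfer the task's dependencies
    compute = 2.0             # estimated task duration on the thief
    task_occ_on_victim = 2.0  # occupancy the task currently contributes on the victim

    if occ_thief + comm_cost + compute <= occ_victim - task_occ_on_victim / 2:
        print("steal: moving the task is expected to pay off")
    else:
        print("keep: moving the task would not help the cluster")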