-
-
Notifications
You must be signed in to change notification settings - Fork 719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No longer double count transfer cost in stealing #7026
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,21 +8,19 @@ | |
from time import time | ||
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast | ||
|
||
import sortedcontainers | ||
from tlz import topk | ||
from tornado.ioloop import PeriodicCallback | ||
|
||
import dask | ||
from dask.utils import parse_timedelta | ||
|
||
from distributed.comm.addressing import get_address_host | ||
from distributed.core import CommClosedError, Status | ||
from distributed.core import CommClosedError | ||
from distributed.diagnostics.plugin import SchedulerPlugin | ||
from distributed.utils import log_errors, recursive_to_dict | ||
|
||
if TYPE_CHECKING: | ||
# Recursive imports | ||
from distributed.scheduler import Scheduler, TaskState, WorkerState | ||
from distributed.scheduler import Scheduler, SchedulerState, TaskState, WorkerState | ||
|
||
# Stealing requires multiple network bounces and if successful also task | ||
# submission which may include code serialization. Therefore, be very | ||
|
@@ -64,8 +62,6 @@ class InFlightInfo(TypedDict): | |
|
||
class WorkStealing(SchedulerPlugin): | ||
scheduler: Scheduler | ||
# ({ task states for level 0}, ..., {task states for level 14}) | ||
stealable_all: tuple[set[TaskState], ...] | ||
# {worker: ({ task states for level 0}, ..., {task states for level 14})} | ||
stealable: dict[str, tuple[set[TaskState], ...]] | ||
# { task state: (worker, level) } | ||
|
@@ -80,12 +76,12 @@ class WorkStealing(SchedulerPlugin): | |
in_flight: dict[TaskState, InFlightInfo] | ||
# { worker state: occupancy } | ||
in_flight_occupancy: defaultdict[WorkerState, float] | ||
in_flight_tasks: defaultdict[WorkerState, int] | ||
_in_flight_event: asyncio.Event | ||
_request_counter: int | ||
|
||
def __init__(self, scheduler: Scheduler): | ||
self.scheduler = scheduler | ||
self.stealable_all = tuple(set() for _ in range(15)) | ||
self.stealable = {} | ||
self.key_stealable = {} | ||
|
||
|
@@ -105,6 +101,7 @@ def __init__(self, scheduler: Scheduler): | |
self.count = 0 | ||
self.in_flight = {} | ||
self.in_flight_occupancy = defaultdict(lambda: 0) | ||
self.in_flight_tasks = defaultdict(lambda: 0) | ||
self._in_flight_event = asyncio.Event() | ||
self._request_counter = 0 | ||
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm | ||
|
@@ -183,6 +180,8 @@ def transition( | |
victim = d["victim"] | ||
self.in_flight_occupancy[thief] -= d["thief_duration"] | ||
self.in_flight_occupancy[victim] += d["victim_duration"] | ||
self.in_flight_tasks[victim] += 1 | ||
self.in_flight_tasks[thief] -= 1 | ||
if not self.in_flight: | ||
self.in_flight_occupancy.clear() | ||
self._in_flight_event.set() | ||
|
@@ -199,7 +198,6 @@ def put_key_in_stealable(self, ts: TaskState) -> None: | |
assert ts.processing_on | ||
ws = ts.processing_on | ||
worker = ws.address | ||
self.stealable_all[level].add(ts) | ||
self.stealable[worker][level].add(ts) | ||
self.key_stealable[ts] = (worker, level) | ||
|
||
|
@@ -213,10 +211,6 @@ def remove_key_from_stealable(self, ts: TaskState) -> None: | |
self.stealable[worker][level].remove(ts) | ||
except KeyError: | ||
pass | ||
try: | ||
self.stealable_all[level].remove(ts) | ||
except KeyError: | ||
pass | ||
|
||
def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]: | ||
"""The compute to communication time ratio of a key | ||
|
@@ -234,14 +228,13 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non | |
if not ts.dependencies: # no dependencies fast path | ||
return 0, 0 | ||
|
||
assert ts.processing_on | ||
ws = ts.processing_on | ||
compute_time = ws.processing[ts] | ||
compute_time = self.scheduler.get_task_duration(ts) | ||
|
||
if not compute_time: | ||
# occupancy/ws.proccessing[ts] is only allowed to be zero for | ||
# long running tasks which cannot be stolen | ||
assert ts in ws.long_running | ||
assert ts.processing_on | ||
assert ts in ts.processing_on.long_running | ||
return None, None | ||
|
||
nbytes = ts.get_nbytes_deps() | ||
|
@@ -298,6 +291,8 @@ def move_task_request( | |
|
||
self.in_flight_occupancy[victim] -= victim_duration | ||
self.in_flight_occupancy[thief] += thief_duration | ||
self.in_flight_tasks[victim] -= 1 | ||
self.in_flight_tasks[thief] += 1 | ||
return stimulus_id | ||
except CommClosedError: | ||
logger.info("Worker comm %r closed while stealing: %r", victim, ts) | ||
|
@@ -403,116 +398,105 @@ def balance(self) -> None: | |
def combined_occupancy(ws: WorkerState) -> float: | ||
return ws.occupancy + self.in_flight_occupancy[ws] | ||
|
||
def maybe_move_task( | ||
level: int, | ||
ts: TaskState, | ||
victim: WorkerState, | ||
thief: WorkerState, | ||
duration: float, | ||
cost_multiplier: float, | ||
) -> None: | ||
occ_thief = combined_occupancy(thief) | ||
occ_victim = combined_occupancy(victim) | ||
|
||
if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2: | ||
self.move_task_request(ts, victim, thief) | ||
log.append( | ||
( | ||
start, | ||
level, | ||
ts.key, | ||
duration, | ||
victim.address, | ||
occ_victim, | ||
thief.address, | ||
occ_thief, | ||
) | ||
) | ||
s.check_idle_saturated(victim, occ=occ_victim) | ||
s.check_idle_saturated(thief, occ=occ_thief) | ||
def combined_nprocessing(ws: WorkerState) -> float: | ||
return ws.occupancy + self.in_flight_tasks[ws] | ||
|
||
with log_errors(): | ||
i = 0 | ||
# Paused and closing workers must never become thieves | ||
idle = [ws for ws in s.idle.values() if ws.status == Status.running] | ||
if not idle or len(idle) == len(s.workers): | ||
potential_thieves = set(s.idle.values()) | ||
if not potential_thieves or len(potential_thieves) == len(s.workers): | ||
return | ||
|
||
victim: WorkerState | None | ||
saturated: set[WorkerState] | list[WorkerState] = s.saturated | ||
if not saturated: | ||
saturated = topk(10, s.workers.values(), key=combined_occupancy) | ||
saturated = [ | ||
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated | ||
if not potential_victims: | ||
potential_victims = topk(10, s.workers.values(), key=combined_occupancy) | ||
potential_victims = [ | ||
ws | ||
for ws in saturated | ||
if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads | ||
for ws in potential_victims | ||
if combined_occupancy(ws) > 0.2 | ||
and combined_nprocessing(ws) > ws.nthreads | ||
and ws not in potential_thieves | ||
] | ||
elif len(saturated) < 20: | ||
saturated = sorted(saturated, key=combined_occupancy, reverse=True) | ||
if len(idle) < 20: | ||
idle = sorted(idle, key=combined_occupancy) | ||
|
||
for level, cost_multiplier in enumerate(self.cost_multipliers): | ||
if not idle: | ||
if not potential_victims: | ||
# TODO: Unclear how to reach this and what the implications | ||
# are. The return is only an optimization since the for-loop | ||
# below would be a no op but we'd safe ourselves a few loop | ||
# cycles. Unless any measurements about runtime, occupancy, | ||
# etc. changes we'd not get out of this and may have an | ||
# unbalanced cluster | ||
return | ||
if len(potential_victims) < 20: | ||
potential_victims = sorted( | ||
potential_victims, key=combined_occupancy, reverse=True | ||
) | ||
assert potential_victims | ||
assert potential_thieves | ||
avg_occ_per_threads = ( | ||
self.scheduler.total_occupancy / self.scheduler.total_nthreads | ||
) | ||
for level, _ in enumerate(self.cost_multipliers): | ||
if not potential_thieves: | ||
break | ||
for victim in list(saturated): | ||
for victim in list(potential_victims): | ||
|
||
stealable = self.stealable[victim.address][level] | ||
if not stealable or not idle: | ||
if not stealable or not potential_thieves: | ||
continue | ||
|
||
for ts in list(stealable): | ||
if not potential_thieves: | ||
break | ||
if ( | ||
ts not in self.key_stealable | ||
or ts.processing_on is not victim | ||
): | ||
stealable.discard(ts) | ||
continue | ||
i += 1 | ||
if not idle: | ||
break | ||
|
||
thieves = _potential_thieves_for(ts, idle) | ||
if not thieves: | ||
break | ||
thief = thieves[i % len(thieves)] | ||
|
||
duration = victim.processing.get(ts) | ||
if duration is None: | ||
stealable.discard(ts) | ||
if not (thief := _get_thief(s, ts, potential_thieves)): | ||
continue | ||
|
||
maybe_move_task( | ||
level, ts, victim, thief, duration, cost_multiplier | ||
) | ||
|
||
if self.cost_multipliers[level] < 20: # don't steal from public at cost | ||
stealable = self.stealable_all[level] | ||
for ts in list(stealable): | ||
if not idle: | ||
break | ||
if ts not in self.key_stealable: | ||
task_occ_on_victim = victim.processing.get(ts) | ||
if task_occ_on_victim is None: | ||
stealable.discard(ts) | ||
continue | ||
|
||
victim = ts.processing_on | ||
if victim is None: | ||
stealable.discard(ts) | ||
continue | ||
if combined_occupancy(victim) < 0.2: | ||
continue | ||
if len(victim.processing) <= victim.nthreads: | ||
continue | ||
|
||
i += 1 | ||
thieves = _potential_thieves_for(ts, idle) | ||
if not thieves: | ||
continue | ||
thief = thieves[i % len(thieves)] | ||
duration = victim.processing[ts] | ||
occ_thief = combined_occupancy(thief) | ||
occ_victim = combined_occupancy(victim) | ||
comm_cost = self.scheduler.get_comm_cost(ts, thief) | ||
compute = self.scheduler.get_task_duration(ts) | ||
|
||
maybe_move_task( | ||
level, ts, victim, thief, duration, cost_multiplier | ||
) | ||
if ( | ||
occ_thief + comm_cost + compute | ||
<= occ_victim - task_occ_on_victim / 2 | ||
): | ||
self.move_task_request(ts, victim, thief) | ||
log.append( | ||
( | ||
start, | ||
level, | ||
ts.key, | ||
task_occ_on_victim, | ||
victim.address, | ||
occ_victim, | ||
thief.address, | ||
occ_thief, | ||
) | ||
) | ||
|
||
occ_thief = combined_occupancy(thief) | ||
p = len(thief.processing) + self.in_flight_tasks[thief] | ||
|
||
nc = thief.nthreads | ||
# TODO: this is replicating some logic of | ||
# check_idle_saturated | ||
# pending: float = occ_thief * (p - nc) / (p * nc) | ||
if not (p < nc or occ_thief < nc * avg_occ_per_threads / 2): | ||
potential_thieves.discard(thief) | ||
stealable.discard(ts) | ||
self.scheduler.check_idle_saturated( | ||
victim, occ=combined_occupancy(victim) | ||
) | ||
|
||
if log: | ||
self.log(log) | ||
|
@@ -526,8 +510,6 @@ def restart(self, scheduler: Any) -> None: | |
for s in stealable: | ||
s.clear() | ||
|
||
for s in self.stealable_all: | ||
s.clear() | ||
self.key_stealable.clear() | ||
|
||
def story(self, *keys_or_ts: str | TaskState) -> list: | ||
|
@@ -542,51 +524,17 @@ def story(self, *keys_or_ts: str | TaskState) -> list: | |
return out | ||
|
||
|
||
def _potential_thieves_for( | ||
ts: TaskState, | ||
idle: sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState], | ||
) -> sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState]: | ||
"""Return the list of workers from ``idle`` that could steal ``ts``.""" | ||
if _has_restrictions(ts): | ||
return [ws for ws in idle if _can_steal(ws, ts)] | ||
else: | ||
return idle | ||
|
||
|
||
def _can_steal(thief: WorkerState, ts: TaskState) -> bool: | ||
"""Determine whether worker ``thief`` can steal task ``ts``. | ||
|
||
Assumes that `ts` has some restrictions. | ||
""" | ||
if ( | ||
ts.host_restrictions | ||
and get_address_host(thief.address) not in ts.host_restrictions | ||
): | ||
return False | ||
elif ts.worker_restrictions and thief.address not in ts.worker_restrictions: | ||
return False | ||
|
||
if not ts.resource_restrictions: | ||
return True | ||
|
||
for resource, value in ts.resource_restrictions.items(): | ||
try: | ||
supplied = thief.resources[resource] | ||
except KeyError: | ||
return False | ||
else: | ||
if supplied < value: | ||
return False | ||
return True | ||
|
||
|
||
def _has_restrictions(ts: TaskState) -> bool: | ||
"""Determine whether the given task has restrictions and whether these | ||
restrictions are strict. | ||
""" | ||
return not ts.loose_restrictions and bool( | ||
ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions | ||
) | ||
Comment on lines
-545
to
-589
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if I'm missing something, but have we completely dropped checking for restrictions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was duplicated code. I'm reusing |
||
def _get_thief( | ||
scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] | ||
) -> WorkerState | None: | ||
valid_workers = scheduler.valid_workers(ts) | ||
if valid_workers: | ||
subset = potential_thieves & valid_workers | ||
if subset: | ||
return next(iter(subset)) | ||
elif not ts.loose_restrictions: | ||
return None | ||
return next(iter(potential_thieves)) | ||
|
||
|
||
fast_tasks = {"split-shuffle"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should need to add the thief back in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be easier to use something like
_pick_thief
again and only worry about removing the thief at the correct points as opposed to first removing it and then worrying about adding it back in.