Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer double count transfer cost in stealing #7026

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 96 additions & 148 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
from time import time
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

import sortedcontainers
from tlz import topk
from tornado.ioloop import PeriodicCallback

import dask
from dask.utils import parse_timedelta

from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.core import CommClosedError
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils import log_errors, recursive_to_dict

if TYPE_CHECKING:
# Recursive imports
from distributed.scheduler import Scheduler, TaskState, WorkerState
from distributed.scheduler import Scheduler, SchedulerState, TaskState, WorkerState

# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
Expand Down Expand Up @@ -64,8 +62,6 @@ class InFlightInfo(TypedDict):

class WorkStealing(SchedulerPlugin):
scheduler: Scheduler
# ({ task states for level 0}, ..., {task states for level 14})
stealable_all: tuple[set[TaskState], ...]
# {worker: ({ task states for level 0}, ..., {task states for level 14})}
stealable: dict[str, tuple[set[TaskState], ...]]
# { task state: (worker, level) }
Expand All @@ -80,12 +76,12 @@ class WorkStealing(SchedulerPlugin):
in_flight: dict[TaskState, InFlightInfo]
# { worker state: occupancy }
in_flight_occupancy: defaultdict[WorkerState, float]
in_flight_tasks: defaultdict[WorkerState, int]
_in_flight_event: asyncio.Event
_request_counter: int

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.stealable_all = tuple(set() for _ in range(15))
self.stealable = {}
self.key_stealable = {}

Expand All @@ -105,6 +101,7 @@ def __init__(self, scheduler: Scheduler):
self.count = 0
self.in_flight = {}
self.in_flight_occupancy = defaultdict(lambda: 0)
self.in_flight_tasks = defaultdict(lambda: 0)
self._in_flight_event = asyncio.Event()
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
Expand Down Expand Up @@ -183,6 +180,8 @@ def transition(
victim = d["victim"]
self.in_flight_occupancy[thief] -= d["thief_duration"]
self.in_flight_occupancy[victim] += d["victim_duration"]
self.in_flight_tasks[victim] += 1
self.in_flight_tasks[thief] -= 1
if not self.in_flight:
self.in_flight_occupancy.clear()
self._in_flight_event.set()
Expand All @@ -199,7 +198,6 @@ def put_key_in_stealable(self, ts: TaskState) -> None:
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable_all[level].add(ts)
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

Expand All @@ -213,10 +211,6 @@ def remove_key_from_stealable(self, ts: TaskState) -> None:
self.stealable[worker][level].remove(ts)
except KeyError:
pass
try:
self.stealable_all[level].remove(ts)
except KeyError:
pass

def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
"""The compute to communication time ratio of a key
Expand All @@ -234,14 +228,13 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
if not ts.dependencies: # no dependencies fast path
return 0, 0

assert ts.processing_on
ws = ts.processing_on
compute_time = ws.processing[ts]
compute_time = self.scheduler.get_task_duration(ts)

if not compute_time:
# occupancy/ws.processing[ts] is only allowed to be zero for
# long-running tasks, which cannot be stolen
assert ts in ws.long_running
assert ts.processing_on
assert ts in ts.processing_on.long_running
return None, None

nbytes = ts.get_nbytes_deps()
Expand Down Expand Up @@ -298,6 +291,8 @@ def move_task_request(

self.in_flight_occupancy[victim] -= victim_duration
self.in_flight_occupancy[thief] += thief_duration
self.in_flight_tasks[victim] -= 1
self.in_flight_tasks[thief] += 1
return stimulus_id
except CommClosedError:
logger.info("Worker comm %r closed while stealing: %r", victim, ts)
Expand Down Expand Up @@ -403,116 +398,105 @@ def balance(self) -> None:
def combined_occupancy(ws: WorkerState) -> float:
    """Return the occupancy of ``ws`` plus its in-flight occupancy delta.

    The delta is read from ``self.in_flight_occupancy``, which is adjusted
    whenever a steal request is issued or resolved.
    """
    current = ws.occupancy
    in_flight_delta = self.in_flight_occupancy[ws]
    return current + in_flight_delta

def maybe_move_task(
    level: int,
    ts: TaskState,
    victim: WorkerState,
    thief: WorkerState,
    duration: float,
    cost_multiplier: float,
) -> None:
    """Request that ``ts`` move from ``victim`` to ``thief`` if worthwhile.

    The steal is requested only when the thief's combined occupancy plus
    ``cost_multiplier * duration`` does not exceed the victim's combined
    occupancy minus half of ``duration``. When a request is made, one entry
    is appended to ``log`` and both workers are re-evaluated through
    ``check_idle_saturated``.

    NOTE(review): ``self``, ``s``, ``log`` and ``start`` are taken from the
    enclosing scope, which is not fully visible here.
    """
    thief_occ = combined_occupancy(thief)
    victim_occ = combined_occupancy(victim)

    thief_after_steal = thief_occ + cost_multiplier * duration
    victim_after_steal = victim_occ - duration / 2
    if not (thief_after_steal <= victim_after_steal):
        # Stealing would not leave the thief better off than the victim.
        return

    self.move_task_request(ts, victim, thief)
    entry = (
        start,
        level,
        ts.key,
        duration,
        victim.address,
        victim_occ,
        thief.address,
        thief_occ,
    )
    log.append(entry)
    # Both checks receive the occupancies measured before the request.
    s.check_idle_saturated(victim, occ=victim_occ)
    s.check_idle_saturated(thief, occ=thief_occ)
def combined_nprocessing(ws: WorkerState) -> int:
    """Return the number of tasks on ``ws`` adjusted by in-flight steals.

    The result is compared against ``ws.nthreads``, so it has to be a task
    count: the number of tasks currently processing on the worker plus the
    net number of tasks being stolen to (positive) or from (negative) it, as
    tracked in ``self.in_flight_tasks``.

    Adding ``ws.occupancy`` here would mix seconds with a task count; the
    thief-side check in this same function already uses
    ``len(thief.processing) + self.in_flight_tasks[thief]``.
    """
    return len(ws.processing) + self.in_flight_tasks[ws]

with log_errors():
i = 0
# Paused and closing workers must never become thieves
idle = [ws for ws in s.idle.values() if ws.status == Status.running]
if not idle or len(idle) == len(s.workers):
potential_thieves = set(s.idle.values())
if not potential_thieves or len(potential_thieves) == len(s.workers):
return

victim: WorkerState | None
saturated: set[WorkerState] | list[WorkerState] = s.saturated
if not saturated:
saturated = topk(10, s.workers.values(), key=combined_occupancy)
saturated = [
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
if not potential_victims:
potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
potential_victims = [
ws
for ws in saturated
if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
for ws in potential_victims
if combined_occupancy(ws) > 0.2
and combined_nprocessing(ws) > ws.nthreads
and ws not in potential_thieves
]
elif len(saturated) < 20:
saturated = sorted(saturated, key=combined_occupancy, reverse=True)
if len(idle) < 20:
idle = sorted(idle, key=combined_occupancy)

for level, cost_multiplier in enumerate(self.cost_multipliers):
if not idle:
if not potential_victims:
# TODO: Unclear how to reach this and what the implications
# are. The return is only an optimization since the for-loop
# below would be a no-op, but we'd save ourselves a few loop
# cycles. Unless any measurements of runtime, occupancy,
# etc. change, we'd not get out of this and may have an
# unbalanced cluster
return
if len(potential_victims) < 20:
potential_victims = sorted(
potential_victims, key=combined_occupancy, reverse=True
)
assert potential_victims
assert potential_thieves
avg_occ_per_threads = (
self.scheduler.total_occupancy / self.scheduler.total_nthreads
)
for level, _ in enumerate(self.cost_multipliers):
if not potential_thieves:
break
for victim in list(saturated):
for victim in list(potential_victims):

stealable = self.stealable[victim.address][level]
if not stealable or not idle:
if not stealable or not potential_thieves:
continue

for ts in list(stealable):
if not potential_thieves:
break
if (
ts not in self.key_stealable
or ts.processing_on is not victim
):
stealable.discard(ts)
continue
i += 1
if not idle:
break

thieves = _potential_thieves_for(ts, idle)
if not thieves:
break
thief = thieves[i % len(thieves)]

duration = victim.processing.get(ts)
if duration is None:
stealable.discard(ts)
if not (thief := _get_thief(s, ts, potential_thieves)):
continue

maybe_move_task(
level, ts, victim, thief, duration, cost_multiplier
)

if self.cost_multipliers[level] < 20: # don't steal from public at cost
stealable = self.stealable_all[level]
for ts in list(stealable):
if not idle:
break
if ts not in self.key_stealable:
task_occ_on_victim = victim.processing.get(ts)
if task_occ_on_victim is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to add the thief back in

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easier to use something like _pick_thief again and only worry about removing the thief at the correct points as opposed to first removing it and then worrying about adding it back in.

stealable.discard(ts)
continue

victim = ts.processing_on
if victim is None:
stealable.discard(ts)
continue
if combined_occupancy(victim) < 0.2:
continue
if len(victim.processing) <= victim.nthreads:
continue

i += 1
thieves = _potential_thieves_for(ts, idle)
if not thieves:
continue
thief = thieves[i % len(thieves)]
duration = victim.processing[ts]
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)
comm_cost = self.scheduler.get_comm_cost(ts, thief)
compute = self.scheduler.get_task_duration(ts)

maybe_move_task(
level, ts, victim, thief, duration, cost_multiplier
)
if (
occ_thief + comm_cost + compute
<= occ_victim - task_occ_on_victim / 2
):
self.move_task_request(ts, victim, thief)
log.append(
(
start,
level,
ts.key,
task_occ_on_victim,
victim.address,
occ_victim,
thief.address,
occ_thief,
)
)

occ_thief = combined_occupancy(thief)
p = len(thief.processing) + self.in_flight_tasks[thief]

nc = thief.nthreads
# TODO: this is replicating some logic of
# check_idle_saturated
# pending: float = occ_thief * (p - nc) / (p * nc)
if not (p < nc or occ_thief < nc * avg_occ_per_threads / 2):
potential_thieves.discard(thief)
stealable.discard(ts)
self.scheduler.check_idle_saturated(
victim, occ=combined_occupancy(victim)
)

if log:
self.log(log)
Expand All @@ -526,8 +510,6 @@ def restart(self, scheduler: Any) -> None:
for s in stealable:
s.clear()

for s in self.stealable_all:
s.clear()
self.key_stealable.clear()

def story(self, *keys_or_ts: str | TaskState) -> list:
Expand All @@ -542,51 +524,17 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
return out


def _potential_thieves_for(
    ts: TaskState,
    idle: sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState],
) -> sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState]:
    """Return the workers from ``idle`` that are allowed to steal ``ts``.

    When ``ts`` has no strict restrictions, ``idle`` itself is returned
    unchanged (not a copy); otherwise a new list with only the eligible
    workers is built.
    """
    if not _has_restrictions(ts):
        return idle
    return [candidate for candidate in idle if _can_steal(candidate, ts)]


def _can_steal(thief: WorkerState, ts: TaskState) -> bool:
    """Tell whether worker ``thief`` satisfies the restrictions of ``ts``.

    Assumes that ``ts`` has some restrictions. Host, worker and resource
    restrictions are checked in that order; the first violated one yields
    ``False``.
    """
    allowed_hosts = ts.host_restrictions
    if allowed_hosts and get_address_host(thief.address) not in allowed_hosts:
        return False

    allowed_workers = ts.worker_restrictions
    if allowed_workers and thief.address not in allowed_workers:
        return False

    required = ts.resource_restrictions
    if not required:
        return True

    for name, needed in required.items():
        try:
            offered = thief.resources[name]
        except KeyError:
            # The worker does not provide this resource at all.
            return False
        if offered < needed:
            return False
    return True


def _has_restrictions(ts: TaskState) -> bool:
    """Tell whether ``ts`` carries restrictions that must be honoured.

    Returns ``True`` only if the restrictions are strict (``ts`` is not
    flagged with ``loose_restrictions``) and at least one of the host,
    worker or resource restrictions is non-empty.
    """
    if ts.loose_restrictions:
        return False
    restrictions = (
        ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
    )
    return bool(restrictions)
Comment on lines -545 to -589
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I'm missing something, but have we completely dropped checking for restrictions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was duplicated code. I'm reusing Scheduler.valid_workers instead.

def _get_thief(
    scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
) -> WorkerState | None:
    """Pick a worker from ``potential_thieves`` that may run ``ts``.

    Parameters
    ----------
    scheduler
        Object providing ``valid_workers(ts)``.
    ts
        The task that is a candidate for stealing.
    potential_thieves
        Workers that are currently considered as thieves.

    Returns
    -------
    A worker from ``potential_thieves`` or ``None`` if there is no suitable
    thief: either ``potential_thieves`` is empty, or ``ts`` has strict
    restrictions that none of the potential thieves satisfies.
    """
    if not potential_thieves:
        # Nothing to pick from; avoid StopIteration from next(iter(...)).
        return None

    valid_workers = scheduler.valid_workers(ts)
    # NOTE(review): assumes ``valid_workers`` returns None when the task is
    # unrestricted and a (possibly empty) set otherwise -- confirm against
    # the scheduler. A plain truthiness test would treat "restricted, but no
    # worker currently qualifies" (empty set) like "unrestricted" and hand a
    # strictly restricted task to an arbitrary thief.
    if valid_workers is not None:
        valid_thieves = potential_thieves & valid_workers
        if valid_thieves:
            return next(iter(valid_thieves))
        if not ts.loose_restrictions:
            return None
    # Unrestricted task, or loose restrictions nobody satisfies: any thief
    # will do.
    return next(iter(potential_thieves))


fast_tasks = {"split-shuffle"}
Loading