Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Priority families #6

Draft
wants to merge 3 commits into
base: queue-structural-coassign
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions distributed/families.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Collection, Iterator

if TYPE_CHECKING:
from distributed.scheduler import TaskState

prio_getter = operator.attrgetter("priority")


def group_by_family(
    tasks: Collection[TaskState],
) -> Iterator[tuple[list[TaskState], bool]]:
    """Partition ``tasks`` into consecutive-priority "families".

    The tasks are sorted by ``priority`` and scanned from the lowest one.
    From each family start, the walk repeatedly steps to the dependent with
    the smallest priority:

    * while every step raises ``priority[-1]`` by exactly 1, the walk is
      following a linear chain and keeps going;
    * the first step that does *not* raise it by exactly 1 ends the walk.
      That dependent is the family's max node and the family is flagged as
      "proper".

    Every task from the start up to the max node (inclusive) is emitted as
    one family, and the scan resumes right after it.

    NOTE(review): the extent of a family is computed from the *difference*
    of ``priority[-1]`` values, so this assumes ``priority[-1]`` is an
    integer that grows by exactly 1 between neighbouring entries of the
    sorted input (no gaps, no duplicates) -- confirm against callers that
    pass a subset of the graph.

    Parameters
    ----------
    tasks
        Objects exposing ``priority`` (an indexable, orderable value) and
        ``dependents`` (a possibly empty collection of such objects).

    Yields
    ------
    tuple[list[TaskState], bool]
        ``(family, proper)``: ``family`` is a slice of the sorted tasks,
        ``proper`` is ``True`` only when the walk stopped at a
        non-consecutive priority jump.

    Raises
    ------
    AssertionError
        If the jump target's ``priority[-1]`` is not greater than the start
        priority plus one (only when assertions are enabled).
    """
    ordered = sorted(tasks, key=prio_getter)
    family_start_i = 0
    while family_start_i < len(ordered):
        ts = ordered[family_start_i]
        # A task without dependents is a family of exactly one task, so the
        # max priority starts out equal to the start priority. (Starting at
        # ``start + 1`` would swallow the next, unrelated task and stop it
        # from ever being considered as a family start.)
        prev_prio = start_prio = max_prio = ts.priority[-1]
        proper_family = False
        while ts.dependents:
            ts = min(ts.dependents, key=prio_getter)
            max_prio = ts.priority[-1]
            if max_prio != prev_prio + 1:
                # non-consecutive priority jump. this is our max node.
                assert max_prio > start_prio + 1, (max_prio, start_prio, ts)
                proper_family = True
                break
            # walk linear chains of consecutive priority
            # TODO if this chain is all linear, try again from the start with the next-smallest dependent
            prev_prio = max_prio

        # all tasks from the start to the max (inclusive) belong to the family.
        next_start_i = family_start_i + (max_prio - start_prio) + 1
        yield ordered[family_start_i:next_start_i], proper_family
        family_start_i = next_start_i
13 changes: 12 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from distributed.diagnostics.memory_sampler import MemorySamplerExtension
from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name
from distributed.event import EventExtension
from distributed.families import group_by_family
from distributed.http import get_handlers
from distributed.lock import LockExtension
from distributed.metrics import monotonic, time
Expand Down Expand Up @@ -1113,6 +1114,8 @@ class TaskState:
#: The group of tasks to which this one belongs
group: TaskGroup

family: list[TaskState] | None

#: Same as of group.name
group_key: str

Expand Down Expand Up @@ -1163,6 +1166,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.type = None # type: ignore
self.group_key = key_split_group(key)
self.group = None # type: ignore
self.family = None
self.metadata = {}
self.annotations = {}
self.erred_on = set()
Expand Down Expand Up @@ -4370,7 +4374,14 @@ def update_graph(
# Compute recommendations
recommendations: dict = {}

for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True):
# Calculate families
sorted_runnables = sorted(runnables, key=operator.attrgetter("priority"))
for fam, proper in group_by_family(sorted_runnables):
if proper:
for ts in fam:
ts.family = fam

for ts in sorted_runnables[::-1]:
if ts.state == "released" and ts.run_spec:
recommendations[ts.key] = "waiting"

Expand Down
Loading