[WIP] Co assignment groups #7141

Draft · wants to merge 6 commits into main
131 changes: 116 additions & 15 deletions distributed/scheduler.py
@@ -1307,6 +1307,8 @@ class TaskState:
#: Task annotations
annotations: dict[str, Any]

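#: Index of the co-assignment group this task belongs to, if any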
cogroup: int | None

#: Cached hash of :attr:`~TaskState.client_key`
_hash: int

@@ -1351,6 +1353,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.metadata = {}
self.annotations = {}
self.erred_on = set()
self.cogroup = None
TaskState._instances.add(self)

def __hash__(self) -> int:
@@ -1468,6 +1471,7 @@ class SchedulerState:
"""

bandwidth: int
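#: Mapping from co-assignment group index to the tasks in that group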
cogroups: dict[int, set[TaskState]]

#: Clients currently connected to the scheduler
clients: dict[str, ClientState]
@@ -1623,7 +1627,7 @@ def __init__(
ws for ws in self.workers.values() if ws.status == Status.running
}
self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}

self.cogroups = {}
# Variables from dask.config, cached by __init__ for performance
self.UNKNOWN_TASK_DURATION = parse_timedelta(
dask.config.get("distributed.scheduler.unknown-task-duration")
@@ -2066,6 +2070,40 @@ def transition_no_worker_memory(
pdb.set_trace()
raise

    def cogroup_objective(self, cogroup: int, ws: WorkerState) -> tuple:
        # Cogroups are not always connected subgraphs, but if we assume they
        # were, only the top-priority task would need a transfer
        tasks_in_group = self.cogroups[cogroup]
        # TODO: this could be made more efficient / we should remember the
        # max if it is required
        ts_top_prio = max(tasks_in_group, key=lambda ts: ts.priority)
        dts: TaskState
        comm_bytes: int = 0
        cotasks_on_worker = 0
        for ts in tasks_in_group:
            if ts in ws.processing or ws in ts.who_has:
                cotasks_on_worker += 1
        for dts in ts_top_prio.dependencies:
            if (
                # This is new compared to worker_objective
                (dts not in tasks_in_group or dts not in ws.processing)
                and ws not in dts.who_has
            ):
                nbytes = dts.get_nbytes()
                comm_bytes += nbytes

        stack_time: float = ws.occupancy / ws.nthreads
        start_time: float = stack_time + comm_bytes / self.bandwidth

        if ts_top_prio.actor:
            raise NotImplementedError("Cogroup assignment for actors not implemented")
        else:
            return (-cotasks_on_worker, start_time, ws.nbytes)
Comment on lines +2073 to +2100

Member Author
This is a very naive way to decide where to put the task. We could also use a similar approach to #7076, but this felt minimally invasive.

Contributor
So this says: prefer to pick a worker that already has many tasks in the same group, even if the start time and required comms are very large.
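A minimal standalone sketch of that tuple ordering (values purely illustrative, not from this PR):

# Hypothetical objective tuples (-cotasks_on_worker, start_time, ws.nbytes)
# for two candidate workers:
busy_with_cotasks = (-3, 12.0, 4_000_000_000)  # three cotasks, long queue, much stored data
idle_without_cotasks = (0, 0.1, 1_000_000)     # nearly idle, but no cotasks
# Python compares tuples element-wise, so the (negated) cotask count dominates:
assert min([busy_with_cotasks, idle_without_cotasks]) == busy_with_cotasks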

Member Author
I haven't put much thought into this, but basically whenever there are tasks from the same cogroup already assigned to a worker, that worker is preferred; otherwise we use the ordinary logic.

Please consider the entire decide_worker_cogroup + cogroup_objective block highly experimental. This PR is more about the cogroup algorithm itself; this is just a cheap way of integrating it into the existing scheduling.


    def decide_worker_cogroup(self, ts: TaskState) -> WorkerState | None:
        assert ts.cogroup is not None
        pool = self.running
        return min(pool, key=partial(self.cogroup_objective, ts.cogroup))

    def decide_worker_rootish_queuing_disabled(
        self, ts: TaskState
    ) -> WorkerState | None:
@@ -2251,21 +2289,13 @@ def transition_waiting_processing(self, key, stimulus_id):
        """
        try:
            ts: TaskState = self.tasks[key]

-           if self.is_rootish(ts):
-               # NOTE: having two root-ish methods is temporary. When the feature flag is removed,
-               # there should only be one, which combines co-assignment and queuing.
-               # Eventually, special-casing root tasks might be removed entirely, with better heuristics.
-               if math.isinf(self.WORKER_SATURATION):
-                   if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
-                       return {ts.key: "no-worker"}, {}, {}
-               else:
-                   if not (ws := self.decide_worker_rootish_queuing_enabled()):
-                       return {ts.key: "queued"}, {}, {}
-           else:
-               if not (ws := self.decide_worker_non_rootish(ts)):
-                   return {ts.key: "no-worker"}, {}, {}
+           if ts.cogroup is not None:
+               decider = self.decide_worker_cogroup
+           else:
+               decider = self.decide_worker_non_rootish
+
+           if not (ws := decider(ts)):
+               return {ts.key: "no-worker"}, {}, {}
Comment on lines +2292 to +2298

Member Author
As already stated, I haven't dealt with queuing yet. The structure of all the decide functions felt sufficiently confusing that I didn't know where to put the new logic. It should not be too difficult, but it will require some thought. I mostly wanted to verify the core logic quickly.

            worker_msgs = _add_to_processing(self, ts, ws)
            return {}, {}, worker_msgs
        except Exception as e:
@@ -4573,7 +4603,21 @@ def update_graph(
        # Compute recommendations
        recommendations: dict = {}

-       for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True):
+       sorted_tasks = sorted(
+           runnables, key=operator.attrgetter("priority"), reverse=True
+       )
+
+       if self.cogroups:
+           start = max(self.cogroups.keys()) + 1
+       else:
+           start = 0
Comment on lines +4610 to +4613

Contributor
Suggested change:
-       if self.cogroups:
-           start = max(self.cogroups.keys()) + 1
-       else:
-           start = 0
+       start = max(self.cogroups.keys(), default=-1) + 1
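For reference, a minimal standalone sketch of why the suggested one-liner is equivalent (illustrative values only):

# max() with an explicit default collapses the empty-dict special case:
assert max({}.keys(), default=-1) + 1 == 0  # no cogroups yet -> start at 0
assert max({0: None, 4: None}.keys(), default=-1) + 1 == 5  # resume after highest index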

        cogroups = coassignment_groups(sorted_tasks[::-1], start=start)
        self.cogroups.update(cogroups)
        for gr_ix, tss in self.cogroups.items():
            for ts in tss:
                ts.cogroup = gr_ix
Comment on lines +4616 to +4618

Contributor
If one wants to go with a plain dict for maintaining cogroups, I think it would make more sense if this invariant were maintained in coassignment_groups (see below).

Member Author
I haven't decided yet what to use to maintain this. Maintenance of this structure is not implemented yet (e.g. we're not cleaning it up again). For now, I am using a dict for simplicity. I'm also not set on gr_ix being an integer, fwiw.


        for ts in sorted_tasks:
            if ts.state == "released" and ts.run_spec:
                recommendations[ts.key] = "waiting"

@@ -8408,3 +8452,60 @@ def transition(
self.metadata[key] = ts.metadata
self.state[key] = finish
self.keys.discard(key)


def coassignment_groups(
    tasks: Sequence[TaskState], start: int = 0
Contributor
OK, so tasks is a list of TaskStates sorted in increasing priority order.

Member Author
Yes. I wanted to add a TODO to verify this, but it is guaranteed in update_graph, so for this prototype it works.
) -> dict[int, set[TaskState]]:
    groups = {}
    group = start
    ix = 0
    min_prio = None
    max_prio = None

    group_dependents_seen = set()
    while ix < len(tasks):
        current = tasks[ix]
        if min_prio is None:
            min_prio = ix

        if current in group_dependents_seen or not current.dependents:
            min_prio = None
            max_prio = None
            ix += 1
            continue
        # There is a way to implement this faster by just continuing to
        # iterate over ix and checking whether the next task is a dependent
        # or not. I chose to go this route because this is what we wrote
        # down initially.
        next = min(current.dependents, key=lambda ts: ts.priority)
        next_ix = ix
        while tasks[next_ix] is not next:
            next_ix += 1

        # Detect a jump
        if next_ix != ix + 1:
            while len(next.dependents) == 1:
                (dep,) = next.dependents
                if len(dep.dependencies) != 1:
                    # This algorithm has the shortcoming that groups may grow
                    # too large if we walk straight to the dependent of a
                    # group. Especially in staged reductions (e.g. tree
                    # reductions), the next "jump" would combine multiple
                    # cogroups. For now, just ignore these. This means that
                    # we'll practically only have cogroups at the bottom of a
                    # graph, but that is where it matters the most anyhow.
                    group_dependents_seen.add(dep)
                    break
                next = dep
                while tasks[next_ix] is not next:
                    next_ix += 1
            max_prio = next_ix + 1
            groups[group] = set(tasks[min_prio:max_prio])
Contributor
Suggested change:
-            groups[group] = set(tasks[min_prio:max_prio])
+            tasks = set(tasks[min_prio:max_prio])
+            for ts in tasks:
+                ts.cogroup = group
+            groups[group] = tasks

Rationale: this connection between the TaskState and cogroup data structures must be maintained; it is best to do so at construction time rather than having to remember that it happens later.

Member Author
I chose not to do this so that coassignment_groups stays a pure function, which is much easier to test and reason about. It might be slightly worse in performance, but I doubt that will be relevant.

            group += 1
            ix = max_prio
            min_prio = None
            max_prio = None
        else:
            ix = next_ix

    return groups
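Since coassignment_groups is a pure function, it can be exercised in isolation. A minimal sketch of such a test (not part of this PR; it assumes TaskState instances can be constructed standalone and their priority/dependents/dependencies wired up by hand):

def test_coassignment_groups_toy_graph():
    # Two independent roots "a" and "b" feed a single reducer "c": the
    # priority jump from a to c should collect all three tasks into one group.
    a = TaskState("a", None, "released")
    b = TaskState("b", None, "released")
    c = TaskState("c", None, "released")
    for prio, ts in enumerate((a, b, c)):
        ts.priority = (0, prio)
    for root in (a, b):
        root.dependents.add(c)
        c.dependencies.add(root)
    # tasks arrive in increasing priority order, as update_graph guarantees
    assert coassignment_groups([a, b, c], start=0) == {0: {a, b, c}}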
19 changes: 19 additions & 0 deletions distributed/tests/test_client.py
@@ -7628,3 +7628,22 @@ async def test_wait_for_workers_n_workers_value_check(c, s, a, b, value, exception):
    ctx = nullcontext()
    with ctx:
        await c.wait_for_workers(value)


@gen_cluster(client=True)
async def test_repartition_coassignment(c, s, a, b):
    ddf = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-01-17",
        dtypes={"x": float, "y": float},
        freq="1d",
    )
    assert ddf.npartitions == 16
    ddf_repart = ddf.repartition(npartitions=ddf.npartitions // 2)

    fut = c.compute(ddf_repart)

    while not (a.state.tasks and b.state.tasks):
        await asyncio.sleep(0.1)

    assert len(a.state.tasks) == len(b.state.tasks)