Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore widely-shared dependencies in decide_worker #5325

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,19 @@ class TaskGroup:
"last_worker_tasks_left",
]

name: str
prefix: TaskPrefix | None
states: dict[str, int]
dependencies: set[TaskGroup]
nbytes_total: int
duration: float
types: set[str]
start: float
stop: float
all_durations: defaultdict[str, float]
last_worker: WorkerState | None
last_worker_tasks_left: int

Comment on lines +821 to +833
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more thoroughly dealt with in #6253

def __init__(self, name: str):
self.name = name
self.prefix: TaskPrefix | None = None
Expand All @@ -830,7 +843,7 @@ def __init__(self, name: str):
self.start: float = 0.0
self.stop: float = 0.0
self.all_durations: defaultdict[str, float] = defaultdict(float)
self.last_worker = None # type: ignore
self.last_worker = None
self.last_worker_tasks_left = 0

def add_duration(self, action: str, start: float, stop: float):
Expand Down Expand Up @@ -1865,7 +1878,7 @@ def transition_no_worker_memory(
pdb.set_trace()
raise

def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None
def decide_worker(self, ts: TaskState) -> WorkerState | None:
"""
Decide on a worker for task *ts*. Return a WorkerState.

Expand All @@ -1879,9 +1892,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None
in a round-robin fashion.
"""
if not self.workers:
return None # type: ignore
return None

ws: WorkerState
tg: TaskGroup = ts.group
valid_workers: set = self.valid_workers(ts)

Expand All @@ -1892,15 +1904,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None
):
self.unrunnable.add(ts)
ts.state = "no-worker"
return None # type: ignore
return None

# Group is larger than cluster with few dependencies?
# Minimize future data transfers.
# Group fills the cluster and dependencies are much smaller than cluster? Minimize future data transfers.
ndeps_cutoff: int = min(5, len(self.workers))
if (
valid_workers is None
and len(tg) > self.total_nthreads * 2
and len(tg.dependencies) < 5
and sum(map(len, tg.dependencies)) < 5
and len(tg) >= self.total_nthreads
and len(tg.dependencies) < ndeps_cutoff
and sum(map(len, tg.dependencies)) < ndeps_cutoff
):
ws = tg.last_worker

Expand Down Expand Up @@ -1955,7 +1967,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None
type(ws),
ws,
)
assert ws.address in self.workers
if ws:
assert ws.address in self.workers
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also in #6253


return ws

Expand Down Expand Up @@ -7478,8 +7491,11 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:


def decide_worker(
ts: TaskState, all_workers, valid_workers: set | None, objective
) -> WorkerState: # -> WorkerState | None
ts: TaskState,
all_workers: Collection[WorkerState],
valid_workers: set[WorkerState] | None,
objective,
) -> WorkerState | None:
"""
Decide which worker should take task *ts*.

Expand All @@ -7495,16 +7511,34 @@ def decide_worker(
of bytes sent between workers. This is determined by calling the
*objective* function.
"""
ws: WorkerState = None # type: ignore
ws: WorkerState | None = None
wws: WorkerState
dts: TaskState
deps: set = ts.dependencies
candidates: set
n_workers: int = len(valid_workers if valid_workers is not None else all_workers)
assert all([dts.who_has for dts in deps])
if ts.actor:
candidates = set(all_workers)
else:
candidates = {wws for dts in deps for wws in dts.who_has}
candidates = {
wws
for dts in deps
# Ignore dependencies that will need to be, or already are, copied to all workers
if len(dts.who_has) < n_workers
and not (
len(dts.dependents) >= n_workers
and len(dts.group) < n_workers // 2
Comment on lines +7529 to +7531
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
and not (
len(dts.dependents) >= n_workers
and len(dts.group) < n_workers // 2
and (
len(dts.dependents) < n_workers
or len(dts.group) >= n_workers // 2

# Really want something like:
# map(len, dts.group.dependents) >= nthreads and len(dts.group) < n_workers // 2
# Or at least
# len(dts.dependents) * len(dts.group) >= nthreads and len(dts.group) < n_workers // 2
# But `nthreads` is O(k) to calculate if given `valid_workers`.
# and the `map(len, dts.group.dependents)` could be extremely expensive since we can't put
# much of an upper bound on it.
)
for wws in dts.who_has
}
if valid_workers is None:
if not candidates:
candidates = set(all_workers)
Expand Down
150 changes: 150 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
],
)
def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
if ndeps >= len(nthreads):
pytest.skip()

@gen_cluster(
client=True,
nthreads=nthreads,
Expand Down Expand Up @@ -239,6 +242,153 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
async def test_decide_worker_common_dep_ignored(client, s, *workers):
r"""
When we have basic linear chains, but all the downstream tasks also share a common dependency, ignore that dependency.

i j k l m n o p
\__\__\__\___/__/__/__/
| | | | | | | | |
| | | | X | | | |
a b c d e f g h

^ Ignore the location of X when picking a worker for i..p.
It will end up being copied to all workers anyway.

If a dependency will end up on every worker regardless, because many things depend on it,
we should ignore it when selecting our candidate workers. Otherwise, we'll end up considering
every worker as a candidate, which is 1) slow and 2) often leads to poor choices.
"""
roots = [
delayed(slowinc)(1, 0.1 / (i + 1), dask_key_name=f"root-{i}") for i in range(16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for choosing delayed over futures?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. It made it easy to visualize the graph though. Would we prefer futures?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense. Typically I would prefer futures since if something goes wrong and we'd need to debug, futures contain fewer layers. However, the visualization argument is strong. I would recommend leaving a comment in-code to avoid overly eager refactoring down the line

]
# This shared dependency will get copied to all workers, eventually making all workers valid candidates for each dep
everywhere = delayed(None, name="everywhere")
deps = [
delayed(lambda x, y: None)(r, everywhere, dask_key_name=f"dep-{i}")
for i, r in enumerate(roots)
]

rs, ds = dask.persist(roots, deps)
await wait(ds)
gjoseph92 marked this conversation as resolved.
Show resolved Hide resolved

keys = {
worker.name: dict(
root_keys=sorted(
[int(k.split("-")[1]) for k in worker.data if k.startswith("root")]
),
deps_of_root=sorted(
[int(k.split("-")[1]) for k in worker.data if k.startswith("dep")]
),
)
for worker in workers
}

for k in keys.values():
assert k["root_keys"] == k["deps_of_root"]

for worker in workers:
log = worker.incoming_transfer_log
if log:
assert len(log) == 1
assert list(log[0]["keys"]) == ["everywhere"]


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
async def test_decide_worker_large_subtrees_colocated(client, s, *workers):
r"""
Ensure that the above "ignore common dependencies" logic doesn't affect wide (but isolated) subtrees.

........ ........ ........ ........
\\\\//// \\\\//// \\\\//// \\\\////
a b c d
Comment on lines +303 to +305
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't put much thought into this question, therefore a short answer is more than enough.

Would your logic be impacted if the subtrees flow together again, i.e. they have a common dependent or set of dependents as in a tree reduction.

If the answer is "no, this logic doesn't go that deep into the graph" (which is what I'm currently guessing), that's fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it doesn't go any further into the graph.


Each one of a, b, etc. has more dependents than there are workers. But just because a has
lots of dependents doesn't necessarily mean it will end up copied to every worker.
Because a also has a few siblings, a's dependents shouldn't spread out over the whole cluster.
"""
roots = [delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers))]
deps = [
delayed(inc)(r, dask_key_name=f"dep-{i}-{j}")
for i, r in enumerate(roots)
for j in range(len(workers) * 2)
]

rs, ds = dask.persist(roots, deps)
await wait(ds)

keys = {
worker.name: dict(
root_keys={
int(k.split("-")[1]) for k in worker.data if k.startswith("root")
},
deps_of_root={
int(k.split("-")[1]) for k in worker.data if k.startswith("dep")
},
)
for worker in workers
}

for k in keys.values():
assert k["root_keys"] == k["deps_of_root"]
assert len(k["root_keys"]) == len(roots) / len(workers)

for worker in workers:
assert not worker.incoming_transfer_log


@gen_cluster(
    client=True,
    nthreads=[("127.0.0.1", 1)] * 4,
    config={"distributed.scheduler.work-stealing": False},
)
async def test_decide_worker_large_multiroot_subtrees_colocated(client, s, *workers):
    r"""
    Same as the above test, but also check isolated trees with multiple roots.

    ........ ........ ........ ........
    \\\\//// \\\\//// \\\\//// \\\\////
    a b c d e f g h
    """
    n_workers = len(workers)
    roots = [delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(n_workers * 2)]
    # Each subtree depends on a *pair* of adjacent roots (2i, 2i+1) and has
    # more dependents (2 * n_workers) than there are workers.
    deps = []
    for i, r in enumerate(roots[::2]):
        for j in range(n_workers * 2):
            deps.append(
                delayed(lambda x, y: None)(
                    r, roots[i * 2 + 1], dask_key_name=f"dep-{i * 2}-{j}"
                )
            )

    rs, ds = dask.persist(roots, deps)
    await wait(ds)

    # For each worker, collect the root indices it holds and the root indices
    # implied by the dep keys it holds (a dep key "dep-{2i}-{j}" implies the
    # root pair 2i and 2i+1).
    keys = {}
    for worker in workers:
        root_ids = set()
        dep_root_ids = set()
        for k in worker.data:
            if k.startswith("root"):
                root_ids.add(int(k.split("-")[1]))
            elif k.startswith("dep"):
                first = int(k.split("-")[1])
                dep_root_ids.update((first, first + 1))
        keys[worker.name] = dict(root_keys=root_ids, deps_of_root=dep_root_ids)

    for k in keys.values():
        # Dependents must be colocated with their root pair, and the roots
        # must be spread evenly across the cluster.
        assert k["root_keys"] == k["deps_of_root"]
        assert len(k["root_keys"]) == len(roots) / n_workers

    # Colocation means no data ever needed to move between workers.
    for worker in workers:
        assert not worker.incoming_transfer_log


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)
Expand Down