Ignore widely-shared dependencies in decide_worker
#5325
base: main
Changes from all commits
1c41702
aaa12a3
e1fd58b
e5175ce
d0f0955
fd7e790
@@ -818,6 +818,19 @@ class TaskGroup:
         "last_worker_tasks_left",
     ]

+    name: str
+    prefix: TaskPrefix | None
+    states: dict[str, int]
+    dependencies: set[TaskGroup]
+    nbytes_total: int
+    duration: float
+    types: set[str]
+    start: float
+    stop: float
+    all_durations: defaultdict[str, float]
+    last_worker: WorkerState | None
+    last_worker_tasks_left: int
+
     def __init__(self, name: str):
         self.name = name
         self.prefix: TaskPrefix | None = None
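A brief aside on the pattern this hunk applies: bare class-level annotations don't create class attributes, so they can sit alongside `__slots__` and simply tell type checkers what each slot holds. A minimal, generic sketch of the idea (not the actual TaskGroup definition):

class Point:
    # __slots__ provides the storage; the bare annotations below add static
    # types for checkers without creating class attributes, so they don't
    # conflict with the slot descriptors.
    __slots__ = ["x", "y"]

    x: float
    y: float

    def __init__(self, x: float, y: float):
        self.x = x
        self.y = y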
@@ -830,7 +843,7 @@ def __init__(self, name: str):
         self.start: float = 0.0
         self.stop: float = 0.0
         self.all_durations: defaultdict[str, float] = defaultdict(float)
-        self.last_worker = None  # type: ignore
+        self.last_worker = None
         self.last_worker_tasks_left = 0

     def add_duration(self, action: str, start: float, stop: float):
@@ -1865,7 +1878,7 @@ def transition_no_worker_memory(
                 pdb.set_trace()
             raise

-    def decide_worker(self, ts: TaskState) -> WorkerState:  # -> WorkerState | None
+    def decide_worker(self, ts: TaskState) -> WorkerState | None:
         """
         Decide on a worker for task *ts*. Return a WorkerState.

@@ -1879,9 +1892,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState:  # -> WorkerState | None
         in a round-robin fashion.
         """
         if not self.workers:
-            return None  # type: ignore
+            return None

-        ws: WorkerState
         tg: TaskGroup = ts.group
         valid_workers: set = self.valid_workers(ts)

@@ -1892,15 +1904,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState:  # -> WorkerState | None
         ):
             self.unrunnable.add(ts)
             ts.state = "no-worker"
-            return None  # type: ignore
+            return None

-        # Group is larger than cluster with few dependencies?
-        # Minimize future data transfers.
+        # Group fills the cluster and dependencies are much smaller than cluster? Minimize future data transfers.
+        ndeps_cutoff: int = min(5, len(self.workers))
         if (
             valid_workers is None
-            and len(tg) > self.total_nthreads * 2
-            and len(tg.dependencies) < 5
-            and sum(map(len, tg.dependencies)) < 5
+            and len(tg) >= self.total_nthreads
+            and len(tg.dependencies) < ndeps_cutoff
+            and sum(map(len, tg.dependencies)) < ndeps_cutoff
         ):
             ws = tg.last_worker

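To make the revised heuristic easier to read in isolation, here is a rough standalone sketch of the condition as changed above. The parameter names are simplified stand-ins for the scheduler state used in the real method, not actual API:

def use_round_robin_colocation(
    n_group_tasks: int,      # len(tg): tasks in this TaskGroup
    group_dep_groups: list,  # tg.dependencies: TaskGroups this group depends on
    total_nthreads: int,     # total threads across the cluster
    n_workers: int,          # len(self.workers)
    has_restrictions: bool,  # valid_workers is not None
) -> bool:
    """Rough sketch: should tasks in this group be placed round-robin,
    ignoring dependency locations, instead of scheduled near their deps?"""
    if has_restrictions:
        return False
    # Cap the "few, small dependencies" threshold at the cluster size
    # (previously a hard-coded 5).
    ndeps_cutoff = min(5, n_workers)
    return (
        # The group alone can fill every thread in the cluster...
        n_group_tasks >= total_nthreads
        # ...and it depends on only a handful of small groups, so future
        # transfers stay cheap wherever its tasks land.
        and len(group_dep_groups) < ndeps_cutoff
        and sum(len(g) for g in group_dep_groups) < ndeps_cutoff
    )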
@@ -1955,7 +1967,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState:  # -> WorkerState | None
                     type(ws),
                     ws,
                 )
-            assert ws.address in self.workers
+            if ws:
+                assert ws.address in self.workers

         return ws

Review comment: This is also in #6253
@@ -7478,8 +7491,11 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:


 def decide_worker(
-    ts: TaskState, all_workers, valid_workers: set | None, objective
-) -> WorkerState:  # -> WorkerState | None
+    ts: TaskState,
+    all_workers: Collection[WorkerState],
+    valid_workers: set[WorkerState] | None,
+    objective,
+) -> WorkerState | None:
     """
     Decide which worker should take task *ts*.

@@ -7495,16 +7511,34 @@ def decide_worker(
     of bytes sent between workers. This is determined by calling the
     *objective* function.
     """
-    ws: WorkerState = None  # type: ignore
+    ws: WorkerState | None = None
     wws: WorkerState
     dts: TaskState
     deps: set = ts.dependencies
     candidates: set
+    n_workers: int = len(valid_workers if valid_workers is not None else all_workers)
     assert all([dts.who_has for dts in deps])
     if ts.actor:
         candidates = set(all_workers)
     else:
-        candidates = {wws for dts in deps for wws in dts.who_has}
+        candidates = {
+            wws
+            for dts in deps
+            # Ignore dependencies that will need to be, or already are, copied to all workers
+            if len(dts.who_has) < n_workers
+            and not (
+                len(dts.dependents) >= n_workers
+                and len(dts.group) < n_workers // 2
+                # Really want something like:
+                #     map(len, dts.group.dependents) >= nthreads and len(dts.group) < n_workers // 2
+                # Or at least
+                #     len(dts.dependents) * len(dts.group) >= nthreads and len(dts.group) < n_workers // 2
+                # But `nthreads` is O(k) to calculate if given `valid_workers`.
+                # and the `map(len, dts.group.dependents)` could be extremely expensive since we can't put
+                # much of an upper bound on it.
+            )
+            for wws in dts.who_has
+        }
     if valid_workers is None:
         if not candidates:
             candidates = set(all_workers)

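The candidate-filtering rule above can be summarized in a small standalone sketch. This is a simplification for illustration only, using plain data structures rather than the scheduler's TaskState/WorkerState objects; all names here are hypothetical:

from dataclasses import dataclass

@dataclass
class Dep:
    """Hypothetical stand-in for a dependency TaskState."""
    who_has: set       # workers that already hold this dependency's data
    n_dependents: int  # len(dts.dependents)
    group_size: int    # len(dts.group)

def candidate_workers(deps: list[Dep], n_workers: int) -> set:
    """Sketch of the new candidate selection: workers holding at least one
    dependency, ignoring dependencies that are (or soon will be) everywhere."""
    return {
        ws
        for d in deps
        # Already replicated to every worker: its location tells us nothing.
        if len(d.who_has) < n_workers
        # Likely to be copied everywhere anyway: many dependents, but the
        # dependency's own group is small relative to the cluster.
        and not (d.n_dependents >= n_workers and d.group_size < n_workers // 2)
        for ws in d.who_has
    }

# Example: a "root" dep held on one worker stays a candidate; a tiny dep with
# many dependents (destined to be copied everywhere) is ignored.
root = Dep(who_has={"w0"}, n_dependents=1, group_size=16)
shared = Dep(who_has={"w3"}, n_dependents=16, group_size=1)
assert candidate_workers([root, shared], n_workers=4) == {"w0"}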
@@ -136,6 +136,9 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
     ],
 )
 def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
+    if ndeps >= len(nthreads):
+        pytest.skip()
+
     @gen_cluster(
         client=True,
         nthreads=nthreads,
@@ -239,6 +242,153 @@ def random(**kwargs):
     test_decide_worker_coschedule_order_neighbors_()


+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
+async def test_decide_worker_common_dep_ignored(client, s, *workers):
+    r"""
+    When we have basic linear chains, but all the downstream tasks also share a common dependency, ignore that dependency.
+
+    i j k l m n o p
+    \__\__\__\___/__/__/__/
+     |  |  |  |  |  |  |  |  |
+     |  |  |  |  X  |  |  |  |
+     a  b  c  d  e  f  g  h
+
+    ^ Ignore the location of X when picking a worker for i..p.
+    It will end up being copied to all workers anyway.
+
+    If a dependency will end up on every worker regardless, because many things depend on it,
+    we should ignore it when selecting our candidate workers. Otherwise, we'll end up considering
+    every worker as a candidate, which is 1) slow and 2) often leads to poor choices.
+    """
+    roots = [
+        delayed(slowinc)(1, 0.1 / (i + 1), dask_key_name=f"root-{i}") for i in range(16)
+    ]

Review thread:
- Any reason for choosing delayed over futures?
- Not really. It made it easy to visualize the graph though. Would we prefer futures?
- Makes sense. Typically I would prefer futures since if something goes wrong and we'd need to debug, futures contain fewer layers. However, the visualization argument is strong. I would recommend leaving a comment in-code to avoid overly eager refactoring down the line.
(A rough futures-based sketch of the same graph construction follows after this test.)
+    # This shared dependency will get copied to all workers, eventually making all workers valid candidates for each dep
+    everywhere = delayed(None, name="everywhere")
+    deps = [
+        delayed(lambda x, y: None)(r, everywhere, dask_key_name=f"dep-{i}")
+        for i, r in enumerate(roots)
+    ]
+
+    rs, ds = dask.persist(roots, deps)
+    await wait(ds)
+
+    keys = {
+        worker.name: dict(
+            root_keys=sorted(
+                [int(k.split("-")[1]) for k in worker.data if k.startswith("root")]
+            ),
+            deps_of_root=sorted(
+                [int(k.split("-")[1]) for k in worker.data if k.startswith("dep")]
+            ),
+        )
+        for worker in workers
+    }
+
+    for k in keys.values():
+        assert k["root_keys"] == k["deps_of_root"]
+
+    for worker in workers:
+        log = worker.incoming_transfer_log
+        if log:
+            assert len(log) == 1
+            assert list(log[0]["keys"]) == ["everywhere"]
+
+
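As referenced in the review thread above, a futures-based construction of the same graph might look roughly like the sketch below. This is not part of the PR; it assumes the usual Client.submit API (with its key keyword) and the slowinc test helper, and the build_with_futures name is hypothetical:

def build_with_futures(client):
    # Same shape as the delayed-based graph above: 16 "root" tasks, one shared
    # "everywhere" task, and one "dep" per root combining the two.
    roots = [
        client.submit(slowinc, 1, delay=0.1 / (i + 1), key=f"root-{i}")
        for i in range(16)
    ]
    everywhere = client.submit(lambda: None, key="everywhere")
    deps = [
        client.submit(lambda x, y: None, r, everywhere, key=f"dep-{i}")
        for i, r in enumerate(roots)
    ]
    return roots, deps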
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
+async def test_decide_worker_large_subtrees_colocated(client, s, *workers):
+    r"""
+    Ensure that the above "ignore common dependencies" logic doesn't affect wide (but isolated) subtrees.
+
+    ........ ........ ........ ........
+    \\\\//// \\\\//// \\\\//// \\\\////
+       a        b        c        d
+
+    Each one of a, b, etc. has more dependents than there are workers. But just because a has
+    lots of dependents doesn't necessarily mean it will end up copied to every worker.
+    Because a also has a few siblings, a's dependents shouldn't spread out over the whole cluster.
+    """

Review thread (on the diagram above):
- Haven't put much thought into this question, therefore a short answer is more than enough. Would your logic be impacted if the subtrees flow together again, i.e. they have a common dependent or set of dependents as in a tree reduction? If the answer is "no, this logic doesn't go that deep into the graph" (which is what I'm currently guessing), that's fine.
- No, it doesn't go any further into the graph.
+    roots = [delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers))]
+    deps = [
+        delayed(inc)(r, dask_key_name=f"dep-{i}-{j}")
+        for i, r in enumerate(roots)
+        for j in range(len(workers) * 2)
+    ]
+
+    rs, ds = dask.persist(roots, deps)
+    await wait(ds)
+
+    keys = {
+        worker.name: dict(
+            root_keys={
+                int(k.split("-")[1]) for k in worker.data if k.startswith("root")
+            },
+            deps_of_root={
+                int(k.split("-")[1]) for k in worker.data if k.startswith("dep")
+            },
+        )
+        for worker in workers
+    }
+
+    for k in keys.values():
+        assert k["root_keys"] == k["deps_of_root"]
+        assert len(k["root_keys"]) == len(roots) / len(workers)
+
+    for worker in workers:
+        assert not worker.incoming_transfer_log
+
+
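For this test's shape, the new filter in decide_worker should leave the candidate set untouched, which is what the colocation asserts above rely on. A rough check of the condition with the test's parameters (4 single-threaded workers, 4 roots, 8 dependents per root), assuming the root-* keys fall into a single task group as their shared key prefix suggests:

# Plugging the test's numbers into the filter condition sketched earlier:
n_workers = 4     # 4 single-threaded workers
n_dependents = 8  # each root has len(workers) * 2 dependents
group_size = 4    # the "root-*" tasks are assumed to share one group of 4

# A root is only ignored if it has many dependents AND its group is small
# relative to the cluster. Here the group is not small, so roots stay
# candidates and their dependents remain colocated with them.
ignored = n_dependents >= n_workers and group_size < n_workers // 2
assert ignored is False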
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 4,
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_decide_worker_large_multiroot_subtrees_colocated(client, s, *workers):
+    r"""
+    Same as the above test, but also check isolated trees with multiple roots.
+
+    ........ ........ ........ ........
+    \\\\//// \\\\//// \\\\//// \\\\////
+     a   b    c   d    e   f    g   h
+    """
+    roots = [
+        delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers) * 2)
+    ]
+    deps = [
+        delayed(lambda x, y: None)(
+            r, roots[i * 2 + 1], dask_key_name=f"dep-{i * 2}-{j}"
+        )
+        for i, r in enumerate(roots[::2])
+        for j in range(len(workers) * 2)
+    ]
+
+    rs, ds = dask.persist(roots, deps)
+    await wait(ds)
+
+    keys = {
+        worker.name: dict(
+            root_keys={
+                int(k.split("-")[1]) for k in worker.data if k.startswith("root")
+            },
+            deps_of_root=set().union(
+                *(
+                    (int(k.split("-")[1]), int(k.split("-")[1]) + 1)
+                    for k in worker.data
+                    if k.startswith("dep")
+                )
+            ),
+        )
+        for worker in workers
+    }
+
+    for k in keys.values():
+        assert k["root_keys"] == k["deps_of_root"]
+        assert len(k["root_keys"]) == len(roots) / len(workers)
+
+    for worker in workers:
+        assert not worker.incoming_transfer_log
+
+
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
 async def test_move_data_over_break_restrictions(client, s, a, b, c):
     [x] = await client.scatter([1], workers=b.address)
Review comment: This is more thoroughly dealt with in #6253