Skip to content

Commit

Permalink
Fix co-assignment for binary operations
Browse files Browse the repository at this point in the history
Bit of a hack, but closes dask#6597. I'd like to have a better metric for the batch size, but I think this is about as good as we can get. Any reasonably large number will do here.
  • Loading branch information
gjoseph92 committed Jun 23, 2022
1 parent b86fe0f commit 7ebd1d9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 37 deletions.
31 changes: 12 additions & 19 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,14 +804,6 @@ class TaskGroup:
#: The result types of this TaskGroup
types: set[str]

#: The worker most recently assigned a task from this group, or None when the group
#: is not identified to be root-like by `SchedulerState.decide_worker`.
last_worker: WorkerState | None

#: If `last_worker` is not None, the number of times that worker should be assigned
#: subsequent tasks until a new worker is chosen.
last_worker_tasks_left: int

prefix: TaskPrefix | None
start: float
stop: float
Expand All @@ -831,8 +823,6 @@ def __init__(self, name: str):
self.start = 0.0
self.stop = 0.0
self.all_durations = defaultdict(float)
self.last_worker = None
self.last_worker_tasks_left = 0

def add_duration(self, action: str, start: float, stop: float) -> None:
duration = stop - start
Expand Down Expand Up @@ -1269,6 +1259,8 @@ class SchedulerState:
"extensions",
"host_info",
"idle",
"last_root_worker",
"last_root_worker_tasks_left",
"n_tasks",
"queued",
"resources",
Expand Down Expand Up @@ -1337,6 +1329,8 @@ def __init__(
self.total_nthreads = 0
self.total_occupancy = 0.0
self.unknown_durations: dict[str, set[TaskState]] = {}
self.last_root_worker: WorkerState | None = None
self.last_root_worker_tasks_left: int = 0
self.queued = queued
self.unrunnable = unrunnable
self.validate = validate
Expand Down Expand Up @@ -1807,25 +1801,24 @@ def decide_worker(
recommendations[ts.key] = "no-worker"
return None

lws = tg.last_worker
lws = self.last_root_worker
if not (
lws
and tg.last_worker_tasks_left
and self.last_root_worker_tasks_left
and self.workers.get(lws.address) is lws
):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(pool, key=partial(self.worker_objective, ts))
tg.last_worker_tasks_left = math.floor(
ws = self.last_root_worker = min(
pool, key=lambda ws: len(ws.processing) / ws.nthreads
)
# TODO better batching metric (`len(tg)` is not necessarily the total number of root tasks!)
self.last_root_worker_tasks_left = math.floor(
(len(tg) / self.total_nthreads) * ws.nthreads
)
else:
ws = lws

# Record `last_worker`, or clear it on the final task
tg.last_worker = (
ws if tg.states["released"] + tg.states["waiting"] > 1 else None
)
tg.last_worker_tasks_left -= 1
self.last_root_worker_tasks_left -= 1
return ws

if not self.idle:
Expand Down
39 changes: 21 additions & 18 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
assert x.key in a.data or x.key in b.data


# @pytest.mark.skip("Current queuing does not support co-assignment")
@pytest.mark.parametrize("ndeps", [0, 1, 4])
@pytest.mark.parametrize(
"nthreads",
Expand All @@ -151,10 +150,6 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
"distributed.scheduler.work-stealing": False,
"distributed.scheduler.worker-saturation": float("inf"),
},
scheduler_kwargs=dict( # TODO remove
dashboard=True,
dashboard_address=":8787",
),
)
async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers):
r"""
Expand Down Expand Up @@ -254,6 +249,24 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()


@pytest.mark.parametrize("ngroups", [1, 2, 3, 5])
@gen_cluster(
client=True,
nthreads=[("", 1), ("", 1)],
config={
"distributed.scheduler.worker-saturation": float("inf"),
},
)
async def test_decide_worker_coschedule_order_binary_op(c, s, a, b, ngroups):
roots = [[delayed(i, name=f"x-{n}-{i}") for i in range(8)] for n in range(ngroups)]
zs = [sum(rs) for rs in zip(*roots)]

await c.gather(c.compute(zs))

assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log]
assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log]


@pytest.mark.slow
@gen_cluster(
client=True,
Expand Down Expand Up @@ -381,17 +394,7 @@ async def _test_saturation_factor(c, s, a, b):


@pytest.mark.skip("Current queuing does not support co-assignment")
@pytest.mark.parametrize(
"saturation_factor",
[
1.0,
2.0,
pytest.param(
float("inf"),
marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"),
),
],
)
@pytest.mark.parametrize("saturation_factor", [1.0, 2.0, float("inf")])
@gen_cluster(
client=True,
nthreads=[("", 2), ("", 1)],
Expand All @@ -406,8 +409,8 @@ async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor

assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log]
assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log]
assert len(a.tasks) == 18
assert len(b.tasks) == 9
assert len(a.state.tasks) == 18
assert len(b.state.tasks) == 9


@pytest.mark.slow
Expand Down

0 comments on commit 7ebd1d9

Please sign in to comment.