Schedule a queued task when a task secedes #7224

Merged (6 commits) on Oct 31, 2022
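Background, for readers unfamiliar with the term in the title: a task running on a worker can call secede() to leave the worker's thread pool, typically just before it blocks on other work, which frees a task slot on that worker. The diff below makes the scheduler use that freed slot to start a queued task, much as it already does when a task completes. A minimal client-side sketch of the pattern follows; the function name, argument, and cluster sizing are illustrative, not taken from the PR:

    from dask.distributed import Client, secede


    def poll_external_service(job_id: str) -> str:
        # Leave the worker's thread pool before blocking, so the slot this
        # task occupied can be handed to a queued task while we wait.
        secede()
        ...  # e.g. wait on an external system; the worker thread is free meanwhile
        return job_id


    if __name__ == "__main__":
        client = Client(n_workers=1, threads_per_worker=1)
        future = client.submit(poll_external_service, "job-123")
        print(future.result())
        client.close()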
distributed/scheduler.py (28 additions, 22 deletions)
@@ -3288,20 +3288,7 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs:

        Returns priority-ordered recommendations.
        """
-        maybe_runnable: list[TaskState] = []
-        # Schedule any queued tasks onto the new worker
-        if not math.isinf(self.WORKER_SATURATION) and self.queued:
-            for qts in reversed(
-                list(
-                    self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION))
-                )
-            ):
-                if self.validate:
-                    assert qts.state == "queued"
-                    assert not qts.processing_on
-                    assert not qts.waiting_on
-
-                maybe_runnable.append(qts)
+        maybe_runnable = list(_next_queued_tasks_for_worker(self, ws))[::-1]

        # Schedule any restricted tasks onto the new worker, if the worker can run them
        for ts in self.unrunnable:
@@ -5338,6 +5325,14 @@ def handle_long_running(
        ws.add_to_long_running(ts)
        self.check_idle_saturated(ws)

+        recommendations = {
+            qts.key: "processing" for qts in _next_queued_tasks_for_worker(self, ws)
+        }
+        if self.validate:
+            assert len(recommendations) <= 1, (ws, recommendations)
+
+        self.transitions(recommendations, stimulus_id)
+
    def handle_worker_status_change(
        self, status: str | Status, worker: str | WorkerState, stimulus_id: str
    ) -> None:
@@ -7886,21 +7881,32 @@ def _exit_processing_common(
    state.check_idle_saturated(ws)
    state.release_resources(ts, ws)

-    # If a slot has opened up for a queued task, schedule it.
-    if state.queued and not _worker_full(ws, state.WORKER_SATURATION):
-        qts = state.queued.peek()
+    for qts in _next_queued_tasks_for_worker(state, ws):
        if state.validate:
            assert qts.state == "queued", qts.state
            assert qts.key not in recommendations, recommendations[qts.key]

-        # NOTE: we don't need to schedule more than one task at once here. Since this is
-        # called each time 1 task completes, multiple tasks must complete for multiple
-        # slots to open up.
        recommendations[qts.key] = "processing"

    return ws


+def _next_queued_tasks_for_worker(
+    state: SchedulerState, ws: WorkerState
+) -> Iterator[TaskState]:
+    """Queued tasks to run, in priority order, on all open slots on a worker"""
+    if not state.queued or ws.status != Status.running:
+        return
+
+    # NOTE: this is called most frequently because a single task has completed, so there
+    # are <= 1 task slots available on the worker.
+    # `peekn` has fast paths for the cases N<=0 and N==1.
+    for qts in state.queued.peekn(_task_slots_available(ws, state.WORKER_SATURATION)):
+        if state.validate:
+            assert qts.state == "queued", qts.state
+            assert not qts.processing_on
+            assert not qts.waiting_on
+        yield qts


def _add_to_memory(
    state: SchedulerState,
    ts: TaskState,
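To illustrate the pattern _next_queued_tasks_for_worker implements (peek at the queue, without popping, for at most as many tasks as the worker has open slots, in priority order), here is a toy, self-contained sketch. It is not the scheduler's code: ToyWorker, slots_available, and next_queued are hypothetical stand-ins, and the real _task_slots_available accounts for details (such as seceded tasks) that are omitted here.

    from __future__ import annotations

    import heapq
    from dataclasses import dataclass
    from math import ceil
    from typing import Iterator


    @dataclass
    class ToyWorker:
        # Hypothetical stand-in for WorkerState: total threads, and threads
        # currently busy with tasks that have not seceded.
        nthreads: int
        occupied: int


    def slots_available(ws: ToyWorker, saturation: float) -> int:
        # Allowed occupancy is nthreads scaled by the saturation factor,
        # minus the threads already in use.
        return max(ceil(ws.nthreads * saturation) - ws.occupied, 0)


    def next_queued(queue: list[tuple[int, str]], ws: ToyWorker, saturation: float) -> Iterator[str]:
        # Peek (without popping) at most one queued key per open slot, best priority first.
        for _priority, key in heapq.nsmallest(slots_available(ws, saturation), queue):
            yield key


    queue = [(2, "z"), (0, "x"), (1, "y")]
    heapq.heapify(queue)

    # One of two threads just freed up (e.g. because a task seceded):
    # exactly one queued task should be suggested.
    print(list(next_queued(queue, ToyWorker(nthreads=2, occupied=1), saturation=1.0)))  # ['x']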
distributed/tests/test_scheduler.py (21 additions, 1 deletion)
@@ -64,7 +64,7 @@
    varying,
    wait_for_state,
)
-from distributed.worker import dumps_function, dumps_task, get_worker
+from distributed.worker import dumps_function, dumps_task, get_worker, secede

pytestmark = pytest.mark.ci1
@@ -479,6 +479,26 @@ async def test_queued_remove_add_worker(c, s, a, b):
    await wait(fs)


+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_secede_opens_slot(c, s, a):
+    first = Event()
+    second = Event()
+
+    def func(first, second):
+        first.wait()
+        secede()
+        second.wait()
+
+    fs = c.map(func, [first] * 5, [second] * 5)
+    await async_wait_for(lambda: a.state.executing, timeout=5)
+
+    await first.set()
+    await async_wait_for(lambda: len(a.state.long_running) == len(fs), timeout=5)
+
+    await second.set()
+    await c.gather(fs)
+
+
@pytest.mark.parametrize(
    "saturation_config, expected_task_counts",
    [
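The new test drives five tasks through a single-threaded worker: each task blocks on the first event, secedes (which, with this change, lets the next queued task start), then blocks on the second event until the client releases everything. The same pattern can be sketched against the public API outside the gen_cluster harness; Event, secede, Client.map, and Event.set are real dask.distributed APIs, while the function name, task count, and cluster sizing here are illustrative:

    from dask.distributed import Client, Event, secede


    def blocked_task(i: int, first: Event, second: Event) -> None:
        first.wait()   # hold the worker's only thread until the client releases it
        secede()       # leave the thread pool; the freed slot lets a queued task start
        second.wait()  # stay "long-running" without occupying a slot


    if __name__ == "__main__":
        client = Client(n_workers=1, threads_per_worker=1)
        first, second = Event("first"), Event("second")

        # Five tasks funnel through one thread; distinct `i` values keep task keys unique.
        futures = client.map(blocked_task, range(5), first=first, second=second)

        first.set()   # tasks secede one by one, each opening the slot for the next queued task
        second.set()  # let the seceded tasks finish
        client.gather(futures)
        client.close()

The test's middle assertion (all five tasks simultaneously in long_running on a one-thread worker) is the part that depends on this PR: without the new scheduling on secede, queued tasks would only start as earlier ones complete.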