
Fix Scheduler.restart logic #6504

Merged · 8 commits · Jun 10, 2022
44 changes: 25 additions & 19 deletions distributed/scheduler.py
@@ -5091,19 +5091,19 @@ async def restart(self, client=None, timeout=30):
                 stimulus_id=stimulus_id,
             )
 
-        nannies = {addr: ws.nanny for addr, ws in self.workers.items()}
+        nanny_workers = {
+            addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
+        }
 
-        for addr in list(self.workers):
-            try:
-                # Ask the worker to close if it doesn't have a nanny,
-                # otherwise the nanny will kill it anyway
-                await self.remove_worker(
-                    address=addr, close=addr not in nannies, stimulus_id=stimulus_id
-                )
-            except Exception:
-                logger.info(
-                    "Exception while restarting. This is normal", exc_info=True
-                )
+        # Close non-Nanny workers. We have no way to restart them, so we just
+        # let them go, and assume a deployment system is going to restart them
+        # for us.
+        await asyncio.gather(
+            *(
+                self.remove_worker(address=addr, stimulus_id=stimulus_id)
+                for addr in self.workers
+                if addr not in nanny_workers
+            )
+        )
 
         self.clear_task_state()

Collaborator (author): Key change: before, nannies contained all workers, so we were removing all workers immediately. Now we only remove non-nanny workers, and leave nanny workers around to be restarted via RPC to the Nanny a few lines below.
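
The key change above hinges on dict-comprehension semantics. Here is a minimal, self-contained sketch in plain Python (the addresses are made up) of why the `if ws.nanny` filter matters: a dict comprehension keeps a key even when its value is None, so membership tests against the unfiltered mapping treated every worker as nanny-managed.

    # addr -> nanny address; the second worker runs without a nanny
    workers = {
        "tcp://10.0.0.1:40000": "tcp://10.0.0.1:40001",
        "tcp://10.0.0.2:40000": None,
    }

    # Old approach: every address is a key, even those mapping to None,
    # so `addr not in nannies` was never true
    nannies = {addr: nanny for addr, nanny in workers.items()}
    assert "tcp://10.0.0.2:40000" in nannies

    # New approach: workers without a nanny are filtered out up front
    nanny_workers = {addr: nanny for addr, nanny in workers.items() if nanny}
    assert "tcp://10.0.0.2:40000" not in nanny_workers
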
@@ -5113,21 +5113,27 @@ async def restart(self, client=None, timeout=30):
             except Exception as e:
                 logger.exception(e)
 
-        logger.debug("Send kill signal to nannies: %s", nannies)
+        logger.debug("Send kill signal to nannies: %s", nanny_workers)
         async with contextlib.AsyncExitStack() as stack:
             nannies = [
                 await stack.enter_async_context(
                     rpc(nanny_address, connection_args=self.connection_args)
                 )
-                for nanny_address in nannies.values()
-                if nanny_address is not None
+                for nanny_address in nanny_workers.values()
             ]
 
-            resps = All(
-                [nanny.restart(close=True, timeout=timeout * 0.8) for nanny in nannies]
-            )
             try:
-                resps = await asyncio.wait_for(resps, timeout)
+                resps = await asyncio.wait_for(
+                    asyncio.gather(
+                        *(
+                            nanny.restart(close=True, timeout=timeout * 0.8)
+                            for nanny in nannies
+                        )
+                    ),
+                    timeout,
+                )
+                # NOTE: the `WorkerState` entries for these workers will be removed
+                # naturally when they disconnect from the scheduler.
             except TimeoutError:
                 logger.error(
                     "Nannies didn't report back restarted within "

Collaborator (author), on the removed All(...) call: Refactor for style, using asyncio.gather instead of All.
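
A runnable sketch of the timeout layering above, in plain asyncio with hypothetical names: the per-nanny budget is deliberately smaller (timeout * 0.8) than the outer wait_for, so individual restarts should fail before the whole gather is cancelled. Note this sketch models the inner budget with a nested wait_for, whereas the real code passes it as an argument to the Nanny's restart RPC.

    import asyncio

    async def restart_one(name: str) -> str:
        await asyncio.sleep(0.01)  # stand-in for the real restart work
        return f"{name} restarted"

    async def main() -> None:
        timeout = 1.0
        resps = await asyncio.wait_for(
            asyncio.gather(
                *(
                    asyncio.wait_for(restart_one(n), timeout * 0.8)
                    for n in ("nanny-1", "nanny-2")
                )
            ),
            timeout,
        )
        print(resps)  # ['nanny-1 restarted', 'nanny-2 restarted']

    asyncio.run(main())
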
89 changes: 89 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import logging
@@ -625,6 +627,93 @@ async def test_restart(c, s, a, b):
     assert not s.tasks


+@gen_cluster(client=True, Worker=Nanny, timeout=60)
+async def test_restart_some_nannies_some_not(c, s, a, b):
+    original_procs = {a.process.process, b.process.process}
+    original_workers = dict(s.workers)
+    async with Worker(s.address, nthreads=1) as w:
+        await c.wait_for_workers(3)
+
+        # Halfway through `Scheduler.restart`, only the non-Nanny workers
+        # should be removed. Nanny-based workers should be kept around so we
+        # can call their `restart` RPC.
+        class ValidateRestartPlugin(SchedulerPlugin):
+            error: Exception | None
+
+            def restart(self, scheduler: Scheduler) -> None:
+                try:
+                    assert scheduler.workers.keys() == {
+                        a.worker_address,
+                        b.worker_address,
+                    }
+                    assert all(ws.nanny for ws in scheduler.workers.values())
+                except Exception as e:
+                    # `Scheduler.restart` swallows exceptions within plugins
+                    self.error = e
+                    raise
+                else:
+                    self.error = None
+
+        plugin = ValidateRestartPlugin()
+        s.add_plugin(plugin)
+        await s.restart()
+
+        if plugin.error:
+            raise plugin.error
+
+        assert w.status == Status.closed
+
+    assert len(s.workers) == 2
+    # Confirm they restarted
+    # NOTE: == for `psutil.Process` compares PID and creation time
+    new_procs = {a.process.process, b.process.process}
+    assert new_procs != original_procs
+    # The workers should have new addresses
+    assert s.workers.keys().isdisjoint(original_workers.keys())
+    # The old WorkerState instances should be replaced
+    assert set(s.workers.values()).isdisjoint(original_workers.values())

Collaborator (author), on the "Halfway through" comment: For reference, the plugin is triggered here:

    for plugin in list(self.plugins.values()):
        try:
            plugin.restart(self)
        except Exception as e:
            logger.exception(e)

Collaborator, on `raise plugin.error`: I don't understand the purpose of this complication - can't you just put the assertions you wrote in the plugin here instead?

Collaborator (author): See the exception handling above: `Scheduler.restart` catches and only logs exceptions raised by plugins, so an assertion failing inside the plugin would never fail the test on its own; the test has to capture and re-raise the error itself.
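
On the NOTE in the test about psutil: two psutil.Process objects compare equal only when both PID and process creation time match, so the set comparison detects fresh processes even if the OS recycles a PID. A tiny sketch (assumes psutil is installed):

    import os
    import psutil

    p1 = psutil.Process()             # current process
    p2 = psutil.Process(os.getpid())  # same PID, looked up independently
    assert p1 == p2                   # equal: same PID and creation time
    assert p1 is not p2               # but distinct objects
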


+class SlowRestartNanny(Nanny):
+    def __init__(self, *args, **kwargs):
+        self.restart_proceed = asyncio.Event()
+        self.restart_called = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def restart(self, **kwargs):
+        self.restart_called.set()
+        await self.restart_proceed.wait()
+        return await super().restart(**kwargs)


+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    Worker=SlowRestartNanny,
+    worker_kwargs={"heartbeat_interval": "1ms"},
+)
+async def test_restart_heartbeat_before_closing(c, s: Scheduler, n: SlowRestartNanny):
+    """
+    Ensure that if workers heartbeat in the middle of `Scheduler.restart`,
+    they don't close themselves.
+    https://github.com/dask/distributed/issues/6494
+    """
+    prev_workers = dict(s.workers)
+    restart_task = asyncio.create_task(s.restart())
+
+    await n.restart_called.wait()
+    await asyncio.sleep(0.5)  # significantly longer than the heartbeat interval
+
+    # WorkerState should not be removed yet, because the worker
+    # hasn't been told to close
+    assert s.workers
+
+    n.restart_proceed.set()
+    # Wait until the worker has left (possibly until it's come back too)
+    while s.workers == prev_workers:
+        await asyncio.sleep(0.01)
+
+    await restart_task
+    await c.wait_for_workers(1)
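
SlowRestartNanny above uses a two-Event handshake to pause an async operation mid-flight so the test can assert on intermediate state. A self-contained sketch of the same pattern in plain asyncio (names are hypothetical):

    import asyncio

    async def slow_op(called: asyncio.Event, proceed: asyncio.Event) -> str:
        called.set()          # signal: the operation has started
        await proceed.wait()  # block until the test releases it
        return "done"

    async def main() -> None:
        called, proceed = asyncio.Event(), asyncio.Event()
        task = asyncio.create_task(slow_op(called, proceed))
        await called.wait()   # operation is now paused mid-flight
        # ...assertions about intermediate state would go here...
        proceed.set()
        assert await task == "done"

    asyncio.run(main())
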


 @gen_cluster()
 async def test_broadcast(s, a, b):
     result = await s.broadcast(msg={"op": "ping"})