Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make workers gracefully handle sigint #2844

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import atexit
import functools
import logging
import gc
import os
Expand All @@ -13,7 +14,7 @@
from dask.system import CPU_COUNT
from distributed import Nanny, Worker
from distributed.security import Security
from distributed.cli.utils import check_python_3, install_signal_handlers
from distributed.cli.utils import check_python_3
from distributed.comm import get_address_host_port
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
Expand Down Expand Up @@ -386,26 +387,32 @@ def del_pid_file():
for i in range(nprocs)
]

async def close_all():
# Unregister all workers from scheduler
if nanny:
await asyncio.gather(*[n.close(timeout=2) for n in nannies])

signal_fired = False

def on_signal(signum):
async def _on_signal(signum):
nonlocal signal_fired
signal_fired = True
if signum != signal.SIGINT:
from distributed.utils import log_errors

with log_errors():
logger.info("Exiting on signal %d", signum)
asyncio.ensure_future(close_all())
signal_fired = True
if signum == signal.SIGINT:
logger.info("Gracefully closing worker because of SIGINT call")
await asyncio.gather(*[n.close_gracefully() for n in nannies])
Comment on lines +399 to +401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be reasonable to give SIGTERM the same treatment as SIGINT as well?

logger.info("Closing workers")
await asyncio.gather(*[n.close() for n in nannies])

def on_signal(sig):
asyncio.ensure_future(_on_signal(sig))

async def run():
await asyncio.gather(*nannies)
await asyncio.gather(*[n.finished() for n in nannies])

install_signal_handlers(loop, cleanup=on_signal)

for sig in [signal.SIGINT, signal.SIGTERM]:
asyncio.get_event_loop().add_signal_handler(
sig, functools.partial(on_signal, sig)
)
Comment on lines +412 to +415
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try:
loop.run_sync(run)
except TimeoutError:
Expand Down
37 changes: 35 additions & 2 deletions distributed/cli/tests/test_dask_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import signal

import pytest
from click.testing import CliRunner

Expand All @@ -10,10 +12,11 @@
from time import sleep

import distributed.cli.dask_worker
from distributed import Client, Scheduler
from distributed import Client, Scheduler, Worker, wait
from distributed.compatibility import WINDOWS
from distributed.metrics import time
from distributed.utils import sync, tmpfile
from distributed.utils_test import popen, terminate_process, wait_for_port
from distributed.utils_test import popen, terminate_process, wait_for_port, slowinc
from distributed.utils_test import loop, cleanup # noqa: F401


Expand Down Expand Up @@ -47,6 +50,36 @@ def test_nanny_worker_ports(loop):
)


@pytest.mark.skipif(WINDOWS, reason="Not supported on Windows")
@pytest.mark.asyncio
async def test_sigint(cleanup):
    """SIGINT sent to a dask-worker process triggers a graceful shutdown.

    Starts one worker as a subprocess (via ``popen``) and one in-process
    ``Worker``, scatters data onto both, submits in-flight tasks, then sends
    SIGINT to the subprocess. The test asserts that the cluster drains the
    signalled worker (its count drops out of ``s.workers``) and that no
    scattered data or submitted work is lost in the process.
    """
    async with Scheduler(port=0) as s:
        # Subprocess worker: the one that will receive the POSIX signal.
        with popen(["dask-worker", s.address, "--name", "alice"]) as worker:
            async with Client(s.address, asynchronous=True) as c:
                # Second, in-process worker so data can migrate somewhere
                # when the subprocess worker retires.
                async with Worker(s.address) as w:
                    await c.wait_for_workers(2)
                    a, b = s.workers.values()
                    # Pin half the scattered data on each worker so both
                    # hold keys before the signal arrives.
                    scattered = await asyncio.gather(
                        c.scatter(list(range(0, 10)), workers=[a.address]),
                        c.scatter(list(range(10, 20)), workers=[b.address]),
                    )
                    scattered = scattered[0] + scattered[1]
                    assert a.has_what and b.has_what

                    # In-flight work that must survive the graceful close.
                    submitted = c.map(slowinc, range(10), delay=0.05)
                    # Give tasks a moment to start before signalling.
                    await asyncio.sleep(0.10)

                    worker.send_signal(signal.SIGINT)
                    # Poll until the scheduler has unregistered the
                    # signalled worker.
                    while len(s.workers) > 1:
                        await asyncio.sleep(0.01)

                    # Grace period for data migration to settle.
                    await asyncio.sleep(0.5)

                    await wait(submitted)
                    # No scattered data lost, and all tasks completed.
                    assert all(future.status == "finished" for future in scattered)
                    assert all(future.status == "finished" for future in submitted)


def test_memory_limit(loop):
with popen(["dask-scheduler", "--no-dashboard"]) as sched:
with popen(
Expand Down
11 changes: 9 additions & 2 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
# cannot call it 'close' on the rpc side for naming conflict
"get_logs": self.get_logs,
"terminate": self.close,
"close_gracefully": self.close_gracefully,
"close_gracefully": self.close_gracefully_signal,
"run": self.run,
}

Expand Down Expand Up @@ -423,14 +423,21 @@ def _close(self, *args, **kwargs):
warnings.warn("Worker._close has moved to Worker.close", stacklevel=2)
return self.close(*args, **kwargs)

def close_gracefully(self, comm=None):
def close_gracefully_signal(self, comm=None):
    """
    A signal that we shouldn't try to restart workers if they go away

    This only flips ``self.status`` to ``"closing-gracefully"``; it does not
    itself close anything. It is used as part of the cluster shutdown
    process, and is registered as the ``close_gracefully`` RPC handler.

    Parameters
    ----------
    comm : optional
        The comm over which the RPC arrived; unused here.
    """
    self.status = "closing-gracefully"

async def close_gracefully(self):
    """Ask the managed worker to close gracefully, then close this nanny.

    First issues a ``close_gracefully`` RPC to the worker process so it can
    hand its data off before shutting down, then closes the nanny itself.
    """
    try:
        await self.rpc(self.worker_address).close_gracefully()
    except CommClosedError:  # worker will have closed connection
        pass
    # Close the nanny regardless of whether the worker RPC succeeded.
    await self.close()

async def close(self, comm=None, timeout=5, report=None):
"""
Close the worker process, stop all comms.
Expand Down
11 changes: 11 additions & 0 deletions distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,14 @@ async def test_config(cleanup):
async with Client(s.address, asynchronous=True) as client:
config = await client.run(dask.config.get, "foo")
assert config[n.worker_address] == "bar"


@gen_cluster(client=True, Worker=Nanny)
async def test_close_gracefully(c, s, a, b):
    """Nanny.close_gracefully retires its worker without losing data.

    With data scattered across two nanny-managed workers, gracefully closing
    one should leave the scheduler with a single registered worker while all
    scattered futures remain finished (data migrated, not dropped).
    """
    futures = await c.scatter(list(range(10)))
    # Both workers hold some of the scattered keys before the close.
    assert all(ws.has_what for ws in s.workers.values())

    await a.close_gracefully()
    assert a.status == "closed"
    # Scheduler is left with only the surviving worker.
    assert len(s.workers) == 1
    # No scattered data was lost by the graceful close.
    assert all(f.status == "finished" for f in futures)
2 changes: 1 addition & 1 deletion distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def terminate_process(proc):
if sys.platform.startswith("win"):
proc.send_signal(signal.CTRL_BREAK_EVENT)
else:
proc.send_signal(signal.SIGINT)
proc.send_signal(signal.SIGTERM)
try:
if sys.version_info[0] == 3:
proc.wait(10)
Expand Down
24 changes: 14 additions & 10 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bisect
from collections import defaultdict, deque, namedtuple
from collections.abc import MutableMapping
import concurrent.futures
from datetime import timedelta
import heapq
from inspect import isawaitable
Expand Down Expand Up @@ -619,6 +620,7 @@ def __init__(
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"close_gracefully": self.close_gracefully,
}

stream_handlers = {
Expand Down Expand Up @@ -1103,27 +1105,29 @@ async def close(
self.rpc.close()

self.status = "closed"
await ServerNode.close(self)
with ignoring(concurrent.futures.CancelledError):
await ServerNode.close(self)

setproctitle("dask-worker [closed]")
return "OK"

async def close_gracefully(self):
async def close_gracefully(self, comm=None):
""" Gracefully shut down a worker

This first informs the scheduler that we're shutting down, and asks it
to move our data elsewhere. Afterwards, we close as normal
"""
if self.status.startswith("closing"):
await self.finished()
with log_errors():
if self.status.startswith("closing"):
await self.finished()

if self.status == "closed":
return
if self.status == "closed":
return

logger.info("Closing worker gracefully: %s", self.address)
self.status = "closing-gracefully"
await self.scheduler.retire_workers(workers=[self.address], remove=False)
await self.close(safe=True, nanny=not self.lifetime_restart)
logger.info("Closing worker gracefully: %s", self.address)
self.status = "closing-gracefully"
await self.scheduler.retire_workers(workers=[self.address], remove=False)
await self.close(safe=True, nanny=not self.lifetime_restart)

async def terminate(self, comm, report=True, **kwargs):
await self.close(report=report, **kwargs)
Expand Down