diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index c7a18517876..700765ba6fe 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -1,8 +1,10 @@ +import asyncio import atexit import gc import logging import os import re +import signal import sys import warnings @@ -10,7 +12,7 @@ from tornado.ioloop import IOLoop from distributed import Scheduler -from distributed.cli.utils import install_signal_handlers +from distributed.cli.utils import wait_for_signals from distributed.preloading import validate_preload_argv from distributed.proctitle import ( enable_proctitle_on_children, @@ -183,33 +185,51 @@ def del_pid_file(): limit = max(soft, hard // 2) resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard)) - loop = IOLoop.current() - logger.info("-" * 47) + async def run(): + loop = IOLoop.current() + logger.info("-" * 47) + + scheduler = Scheduler( + loop=loop, + security=sec, + host=host, + port=port, + dashboard=dashboard, + dashboard_address=dashboard_address, + http_prefix=dashboard_prefix, + **kwargs, + ) + logger.info("-" * 47) - scheduler = Scheduler( - loop=loop, - security=sec, - host=host, - port=port, - dashboard=dashboard, - dashboard_address=dashboard_address, - http_prefix=dashboard_prefix, - **kwargs, - ) - logger.info("-" * 47) + async def wait_for_scheduler_to_finish(): + """Wait for the scheduler to initialize and finish""" + await scheduler + await scheduler.finished() - install_signal_handlers(loop) + async def wait_for_signals_and_close(): + """Wait for SIGINT or SIGTERM and close the scheduler upon receiving one of those signals""" + await wait_for_signals([signal.SIGINT, signal.SIGTERM]) + await scheduler.close() - async def run(): - await scheduler - await scheduler.finished() + wait_for_signals_and_close_task = asyncio.create_task( + wait_for_signals_and_close() + ) + wait_for_scheduler_to_finish_task = asyncio.create_task( + wait_for_scheduler_to_finish() + ) + + done, _ = await asyncio.wait( + [wait_for_signals_and_close_task, wait_for_scheduler_to_finish_task], + return_when=asyncio.FIRST_COMPLETED, + ) + # Re-raise exceptions from done tasks + [task.result() for task in done] + logger.info("Stopped scheduler at %r", scheduler.address) try: - loop.run_sync(run) + asyncio.run(run()) finally: - scheduler.stop() - - logger.info("End scheduler at %r", scheduler.address) + logger.info("End scheduler") if __name__ == "__main__": diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 376f2a1c62b..0b22db6ae3c 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -20,7 +20,7 @@ from dask.system import CPU_COUNT from distributed import Nanny -from distributed.cli.utils import install_signal_handlers +from distributed.cli.utils import wait_for_signals from distributed.comm import get_address_host_port from distributed.deploy.utils import nprocesses_nthreads from distributed.preloading import validate_preload_argv @@ -404,8 +404,6 @@ def del_pid_file(): else: resources = None - loop = IOLoop.current() - worker_class = import_term(worker_class) port_kwargs = _apportion_ports(worker_port, nanny_port, n_workers, nanny) @@ -432,56 +430,68 @@ def del_pid_file(): with suppress(TypeError, ValueError): name = int(name) - nannies = [ - t( - scheduler, - scheduler_file=scheduler_file, - nthreads=nthreads, - loop=loop, - resources=resources, - security=sec, - contact_address=contact_address, - host=host, - dashboard=dashboard, - dashboard_address=dashboard_address, - name=name - if n_workers == 1 or name is None or name == "" - else str(name) + "-" + str(i), - **kwargs, - **port_kwargs_i, - ) - for i, port_kwargs_i in enumerate(port_kwargs) - ] + signal_fired = False - async def close_all(): - # Unregister all workers from scheduler - if nanny: - await asyncio.gather(*(n.close(timeout=2) for n in nannies)) + async def run(): + loop = IOLoop.current() + + nannies = [ + t( + scheduler, + scheduler_file=scheduler_file, + nthreads=nthreads, + loop=loop, + resources=resources, + security=sec, + contact_address=contact_address, + host=host, + dashboard=dashboard, + dashboard_address=dashboard_address, + name=name + if n_workers == 1 or name is None or name == "" + else str(name) + "-" + str(i), + **kwargs, + **port_kwargs_i, + ) + for i, port_kwargs_i in enumerate(port_kwargs) + ] - signal_fired = False + async def wait_for_nannies_to_finish(): + """Wait for all nannies to initialize and finish""" + await asyncio.gather(*nannies) + await asyncio.gather(*(n.finished() for n in nannies)) - def on_signal(signum): - nonlocal signal_fired - signal_fired = True - if signum != signal.SIGINT: - logger.info("Exiting on signal %d", signum) - return asyncio.ensure_future(close_all()) + async def wait_for_signals_and_close(): + """Wait for SIGINT or SIGTERM and close all nannies upon receiving one of those signals""" + nonlocal signal_fired + await wait_for_signals([signal.SIGINT, signal.SIGTERM]) - async def run(): - await asyncio.gather(*nannies) - await asyncio.gather(*(n.finished() for n in nannies)) + signal_fired = True + if nanny: + # Unregister all workers from scheduler + await asyncio.gather(*(n.close(timeout=10) for n in nannies)) - install_signal_handlers(loop, cleanup=on_signal) + wait_for_signals_and_close_task = asyncio.create_task( + wait_for_signals_and_close() + ) + wait_for_nannies_to_finish_task = asyncio.create_task( + wait_for_nannies_to_finish() + ) + + done, _ = await asyncio.wait( + [wait_for_signals_and_close_task, wait_for_nannies_to_finish_task], + return_when=asyncio.FIRST_COMPLETED, + ) + # Re-raise exceptions from done tasks + [task.result() for task in done] try: - loop.run_sync(run) - except TimeoutError: + asyncio.run(run()) + except (TimeoutError, asyncio.TimeoutError): # We already log the exception in nanny / worker. Don't do it again. if not signal_fired: logger.info("Timed out starting worker") sys.exit(1) - except KeyboardInterrupt: # pragma: no cover - pass finally: logger.info("End worker") diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index d73905401ff..1b01d1ad355 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -5,6 +5,8 @@ import os import shutil +import signal +import subprocess import sys import tempfile from time import sleep @@ -17,7 +19,7 @@ import distributed import distributed.cli.dask_scheduler from distributed import Client, Scheduler -from distributed.compatibility import LINUX +from distributed.compatibility import LINUX, WINDOWS from distributed.metrics import time from distributed.utils import get_ip, get_ip_interface from distributed.utils_test import ( @@ -414,9 +416,12 @@ def test_version_option(): def test_idle_timeout(loop): start = time() runner = CliRunner() - runner.invoke(distributed.cli.dask_scheduler.main, ["--idle-timeout", "1s"]) + result = runner.invoke( + distributed.cli.dask_scheduler.main, ["--idle-timeout", "1s"] + ) stop = time() assert 1 < stop - start < 10 + assert result.exit_code == 0 def test_multiple_workers_2(loop): @@ -453,3 +458,25 @@ def test_multiple_workers(loop): while len(c.nthreads()) < 2: sleep(0.1) assert time() < start + 10 + + +@pytest.mark.slow +@pytest.mark.skipif(WINDOWS, reason="POSIX only") +@pytest.mark.parametrize("sig", [signal.SIGINT, signal.SIGTERM]) +def test_signal_handling(loop, sig): + with subprocess.Popen( + ["python", "-m", "distributed.cli.dask_scheduler"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as scheduler: + # Wait for scheduler to start + with Client(f"127.0.0.1:{Scheduler.default_port}", loop=loop) as c: + pass + scheduler.send_signal(sig) + stdout, stderr = scheduler.communicate() + logs = stdout.decode().lower() + assert stderr is None + assert scheduler.returncode == 0 + assert sig.name.lower() in logs + assert "scheduler closing" in logs + assert "end scheduler" in logs diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 23080909e1c..a6ba6d45aa6 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -2,6 +2,8 @@ import asyncio import os +import signal +import subprocess import sys from multiprocessing import cpu_count from time import sleep @@ -13,7 +15,7 @@ from distributed import Client from distributed.cli.dask_worker import _apportion_ports, main -from distributed.compatibility import LINUX, to_thread +from distributed.compatibility import LINUX, WINDOWS, to_thread from distributed.deploy.utils import nprocesses_nthreads from distributed.metrics import time from distributed.utils_test import gen_cluster, popen, requires_ipv6 @@ -593,7 +595,7 @@ def test_worker_timeout(no_nanny): if no_nanny: args.append("--no-nanny") result = runner.invoke(main, args) - assert result.exit_code != 0 + assert result.exit_code == 1 def test_bokeh_deprecation(): @@ -682,3 +684,54 @@ def dask_setup(worker): await c.wait_for_workers(1) [foo] = (await c.run(lambda dask_worker: dask_worker.foo)).values() assert foo == "setup" + + +@pytest.mark.slow +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +def test_timeout(nanny): + worker = subprocess.run( + [ + "python", + "-m", + "distributed.cli.dask_worker", + "192.168.1.100:7777", + nanny, + "--death-timeout=1", + ], + text=True, + encoding="utf8", + capture_output=True, + ) + + assert "timed out starting worker" in worker.stderr.lower() + assert "end worker" in worker.stderr.lower() + assert worker.returncode == 1 + + +@pytest.mark.skipif(WINDOWS, reason="POSIX only") +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +@pytest.mark.parametrize("sig", [signal.SIGINT, signal.SIGTERM]) +@gen_cluster(client=True, nthreads=[]) +async def test_signal_handling(c, s, nanny, sig): + with subprocess.Popen( + ["python", "-m", "distributed.cli.dask_worker", s.address, nanny], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as worker: + await c.wait_for_workers(1) + + worker.send_signal(sig) + stdout, stderr = worker.communicate() + logs = stdout.decode().lower() + assert stderr is None + assert worker.returncode == 0 + assert sig.name.lower() in logs + if nanny == "--nanny": + assert "closing nanny" in logs + assert "stopping worker" in logs + else: + assert "nanny" not in logs + assert "end worker" in logs + assert "timed out" not in logs + assert "error" not in logs + assert "exception" not in logs diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index a0191edaa90..f25b7245c41 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -1,30 +1,28 @@ -from tornado.ioloop import IOLoop +from __future__ import annotations +import asyncio +import logging +import signal +from typing import Any -def install_signal_handlers(loop=None, cleanup=None): - """ - Install global signal handlers to halt the Tornado IOLoop in case of - a SIGINT or SIGTERM. *cleanup* is an optional callback called, - before the loop stops, with a single signal number argument. - """ - import signal +logger = logging.getLogger(__name__) - loop = loop or IOLoop.current() - old_handlers = {} +async def wait_for_signals(signals: list[signal.Signals]) -> None: + """Wait for the passed signals by setting global signal handlers""" + loop = asyncio.get_running_loop() + event = asyncio.Event() - def handle_signal(sig, frame): - async def cleanup_and_stop(): - try: - if cleanup is not None: - await cleanup(sig) - finally: - loop.stop() + old_handlers: dict[int, Any] = {} - loop.add_callback_from_signal(cleanup_and_stop) - # Restore old signal handler to allow for a quicker exit + def handle_signal(signum, frame): + # Restore old signal handler to allow for quicker exit # if the user sends the signal again. - signal.signal(sig, old_handlers[sig]) + signal.signal(signum, old_handlers[signum]) + logger.info("Received signal %s (%d)", signal.Signals(signum).name, signum) + loop.call_soon_threadsafe(event.set) - for sig in [signal.SIGINT, signal.SIGTERM]: + for sig in signals: old_handlers[sig] = signal.signal(sig, handle_signal) + + await event.wait()