diff --git a/CHANGES/7718.feature b/CHANGES/7718.feature new file mode 100644 index 00000000000..abeac12b19a --- /dev/null +++ b/CHANGES/7718.feature @@ -0,0 +1,3 @@ +Fixed keep-alive connections stopping a graceful shutdown. +Added ``shutdown_timeout`` parameter to ``BaseRunner``, while +removing ``shutdown_timeout`` parameter from ``BaseSite``. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/web.py b/aiohttp/web.py index f87d57988be..581fd17a26f 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -5,6 +5,8 @@ import sys from argparse import ArgumentParser from collections.abc import Iterable +from contextlib import suppress +from functools import partial from importlib import import_module from typing import ( Any, @@ -18,6 +20,7 @@ Union, cast, ) +from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey @@ -291,6 +294,23 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: + async def wait( + starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float + ) -> None: + # Wait for pending tasks for a given time limit. + t = asyncio.current_task() + assert t is not None + starting_tasks.add(t) + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) + + async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: + t = asyncio.current_task() + assert t is not None + exclude.add(t) + while tasks := asyncio.all_tasks().difference(exclude): + await asyncio.wait(tasks) + # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -304,10 +324,17 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + shutdown_timeout=shutdown_timeout, handler_cancellation=handler_cancellation, ) await runner.setup() + # On shutdown we want to avoid waiting on tasks which run forever. + # It's very likely that all tasks which run forever will have been created by + # the time we have completed the application startup (in runner.setup()), + # so we just record all running tasks here and exclude them later. + starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) + runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] @@ -319,7 +346,6 @@ async def _run_app( runner, host, port, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, @@ -333,7 +359,6 @@ async def _run_app( runner, h, port, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, @@ -345,7 +370,6 @@ async def _run_app( TCPSite( runner, port=port, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, @@ -359,7 +383,6 @@ async def _run_app( UnixSite( runner, path, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -370,7 +393,6 @@ async def _run_app( UnixSite( runner, p, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -382,7 +404,6 @@ async def _run_app( SockSite( runner, sock, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -393,7 +414,6 @@ async def _run_app( SockSite( runner, s, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 3063dce36ec..96db9541a4b 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,8 +2,7 @@ import signal import socket from abc import ABC, abstractmethod -from contextlib import suppress -from typing import Any, List, Optional, Set, Type +from typing import Any, Awaitable, Callable, List, Optional, Set, Type from yarl import URL @@ -45,20 +44,18 @@ def _raise_graceful_exit() -> None: class BaseSite(ABC): - __slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server") + __slots__ = ("_runner", "_ssl_context", "_backlog", "_server") def __init__( self, runner: "BaseRunner", *, - shutdown_timeout: float = 60.0, ssl_context: Optional[SSLContext] = None, backlog: int = 128, ) -> None: if runner.server is None: raise RuntimeError("Call runner.setup() before making a site") self._runner = runner - self._shutdown_timeout = shutdown_timeout self._ssl_context = ssl_context self._backlog = backlog self._server: Optional[asyncio.AbstractServer] = None @@ -74,30 +71,11 @@ async def start(self) -> None: async def stop(self) -> None: self._runner._check_site(self) - if self._server is None: - self._runner._unreg_site(self) - return # not started yet - self._server.close() - # named pipes do not have wait_closed property - if hasattr(self._server, "wait_closed"): - await self._server.wait_closed() - - # Wait for pending tasks for a given time limit. - with suppress(asyncio.TimeoutError): - await asyncio.wait_for( - self._wait(asyncio.current_task()), timeout=self._shutdown_timeout - ) + if self._server is not None: # Maybe not started yet + self._server.close() - await self._runner.shutdown() - assert self._runner.server - await self._runner.server.shutdown(self._shutdown_timeout) self._runner._unreg_site(self) - async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None: - exclude = self._runner.starting_tasks | {asyncio.current_task(), parent_task} - while tasks := asyncio.all_tasks() - exclude: - await asyncio.wait(tasks) - class TCPSite(BaseSite): __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port") @@ -108,7 +86,6 @@ def __init__( host: Optional[str] = None, port: Optional[int] = None, *, - shutdown_timeout: float = 60.0, ssl_context: Optional[SSLContext] = None, backlog: int = 128, reuse_address: Optional[bool] = None, @@ -116,7 +93,6 @@ def __init__( ) -> None: super().__init__( runner, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -157,13 +133,11 @@ def __init__( runner: "BaseRunner", path: PathLike, *, - shutdown_timeout: float = 60.0, ssl_context: Optional[SSLContext] = None, backlog: int = 128, ) -> None: super().__init__( runner, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -190,9 +164,7 @@ async def start(self) -> None: class NamedPipeSite(BaseSite): __slots__ = ("_path",) - def __init__( - self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0 - ) -> None: + def __init__(self, runner: "BaseRunner", path: str) -> None: loop = asyncio.get_event_loop() if not isinstance( loop, asyncio.ProactorEventLoop # type: ignore[attr-defined] @@ -200,7 +172,7 @@ def __init__( raise RuntimeError( "Named Pipes only available in proactor" "loop under windows" ) - super().__init__(runner, shutdown_timeout=shutdown_timeout) + super().__init__(runner) self._path = path @property @@ -226,13 +198,11 @@ def __init__( runner: "BaseRunner", sock: socket.socket, *, - shutdown_timeout: float = 60.0, ssl_context: Optional[SSLContext] = None, backlog: int = 128, ) -> None: super().__init__( runner, - shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog, ) @@ -260,13 +230,28 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites") + __slots__ = ( + "shutdown_callback", + "_handle_signals", + "_kwargs", + "_server", + "_sites", + "_shutdown_timeout", + ) - def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None: + def __init__( + self, + *, + handle_signals: bool = False, + shutdown_timeout: float = 60.0, + **kwargs: Any, + ) -> None: + self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None self._sites: List[BaseSite] = [] + self._shutdown_timeout = shutdown_timeout @property def server(self) -> Optional[Server]: @@ -300,28 +285,32 @@ async def setup(self) -> None: pass self._server = await self._make_server() - # On shutdown we want to avoid waiting on tasks which run forever. - # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in self._make_server()), - # so we just record all running tasks here and exclude them later. - self.starting_tasks = asyncio.all_tasks() @abstractmethod async def shutdown(self) -> None: - pass # pragma: no cover + """Call any shutdown hooks to help server close gracefully.""" async def cleanup(self) -> None: - loop = asyncio.get_event_loop() - # The loop over sites is intentional, an exception on gather() # leaves self._sites in unpredictable state. # The loop guarantees that a site is either deleted on success or # still present on failure for site in list(self._sites): await site.stop() + + if self._server: # If setup succeeded + self._server.pre_shutdown() + await self.shutdown() + + if self.shutdown_callback: + await self.shutdown_callback() + + await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() + self._server = None if self._handle_signals: + loop = asyncio.get_running_loop() try: loop.remove_signal_handler(signal.SIGINT) loop.remove_signal_handler(signal.SIGTERM) diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index a3d658afbff..fcde46a3482 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -61,8 +61,12 @@ def _make_request( ) -> BaseRequest: return BaseRequest(message, payload, protocol, writer, task, self._loop) + def pre_shutdown(self) -> None: + for conn in self._connections: + conn.close() + async def shutdown(self, timeout: Optional[float] = None) -> None: - coros = [conn.shutdown(timeout) for conn in self._connections] + coros = (conn.shutdown(timeout) for conn in self._connections) await asyncio.gather(*coros) self._connections.clear() diff --git a/aiohttp/worker.py b/aiohttp/worker.py index c1c45f192a5..1c734ea56c4 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -86,6 +86,7 @@ async def _run(self) -> None: access_log_format=self._get_valid_log_format( self.cfg.access_log_format ), + shutdown_timeout=self.cfg.graceful_timeout / 100 * 95, ) await runner.setup() @@ -99,7 +100,6 @@ async def _run(self) -> None: runner, sock, ssl_context=ctx, - shutdown_timeout=self.cfg.graceful_timeout / 100 * 95, ) await site.start() diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 6055ddaf319..7ee0cd372a6 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -927,25 +927,38 @@ Graceful shutdown Stopping *aiohttp web server* by just closing all connections is not always satisfactory. -The first thing aiohttp will do is to stop listening on the sockets, -so new connections will be rejected. It will then wait a few -seconds to allow any pending tasks to complete before continuing -with application shutdown. The timeout can be adjusted with -``shutdown_timeout`` in :func:`run_app`. +When aiohttp is run with :func:`run_app`, it will attempt a graceful shutdown +by following these steps (if using a :ref:`runner `, +then calling :meth:`AppRunner.cleanup` will perform these steps, excluding +steps 4 and 7). + +1. Stop each site listening on sockets, so new connections will be rejected. +2. Close idle keep-alive connections (and set active ones to close upon completion). +3. Call the :attr:`Application.on_shutdown` signal. This should be used to shutdown + long-lived connections, such as websockets (see below). +4. Wait a short time for running tasks to complete. This allows any pending handlers + or background tasks to complete successfully. The timeout can be adjusted with + ``shutdown_timeout`` in :func:`run_app`. +5. Close any remaining connections and cancel their handlers. It will wait on the + canceling handlers for a short time, again adjustable with ``shutdown_timeout``. +6. Call the :attr:`Application.on_cleanup` signal. This should be used to cleanup any + resources (such as DB connections). This includes completing the + :ref:`cleanup contexts`. +7. Cancel any remaining tasks and wait on them to complete. + +Websocket shutdown +^^^^^^^^^^^^^^^^^^ -Another problem is if the application supports :term:`websockets ` or -*data streaming* it most likely has open connections at server -shutdown time. +One problem is if the application supports :term:`websockets ` or +*data streaming* it most likely has open connections at server shutdown time. -The *library* has no knowledge how to close them gracefully but -developer can help by registering :attr:`Application.on_shutdown` -signal handler and call the signal on *web server* closing. +The *library* has no knowledge how to close them gracefully but a developer can +help by registering an :attr:`Application.on_shutdown` signal handler. -Developer should keep a list of opened connections +A developer should keep a list of opened connections (:class:`Application` is a good candidate). -The following :term:`websocket` snippet shows an example for websocket -handler:: +The following :term:`websocket` snippet shows an example of a websocket handler:: from aiohttp import web import weakref @@ -967,20 +980,16 @@ handler:: return ws -Signal handler may look like:: +Then the signal handler may look like:: from aiohttp import WSCloseCode async def on_shutdown(app): for ws in set(app[websockets]): - await ws.close(code=WSCloseCode.GOING_AWAY, - message='Server shutdown') + await ws.close(code=WSCloseCode.GOING_AWAY, message="Server shutdown") app.on_shutdown.append(on_shutdown) -Both :func:`run_app` and :meth:`AppRunner.cleanup` call shutdown -signal handlers. - .. _aiohttp-web-ceil-absolute-timeout: Ceil of absolute timeout value diff --git a/tests/test_run_app.py b/tests/test_run_app.py index c3e52ef361b..c2945614c63 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -10,13 +10,13 @@ import subprocess import sys import time -from typing import Any, Callable, NoReturn +from typing import Any, Callable, NoReturn, Set from unittest import mock from uuid import uuid4 import pytest -from aiohttp import ClientConnectorError, ClientSession, web +from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -1113,3 +1113,90 @@ async def handler(request: web.Request) -> web.Response: web.run_app(app, port=port, shutdown_timeout=5) assert t.exception() is None assert finished is True + + def test_shutdown_close_idle_keepalive( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + + async def test() -> None: + await asyncio.sleep(1) + async with ClientSession() as sess: + async with sess.get(f"http://localhost:{port}/stop"): + pass + + # Hold on to keep-alive connection. + await asyncio.sleep(5) + + async def run_test(app: web.Application) -> None: + nonlocal t + t = asyncio.create_task(test()) + yield + t.cancel() + with contextlib.suppress(asyncio.CancelledError): + await t + + t = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=10) + # If connection closed, then test() will be cancelled in cleanup_ctx. + # If not, then shutdown_timeout will allow it to sleep until complete. + assert t.cancelled() + + def test_shutdown_close_websockets( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + WS = web.AppKey("ws", Set[web.WebSocketResponse]) + client_finished = server_finished = False + + async def ws_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + request.app[WS].add(ws) + async for msg in ws: + pass + nonlocal server_finished + server_finished = True + return ws + + async def close_websockets(app: web.Application) -> None: + for ws in app[WS]: + await ws.close(code=WSCloseCode.GOING_AWAY) + + async def test() -> None: + await asyncio.sleep(1) + async with ClientSession() as sess: + async with sess.ws_connect(f"http://localhost:{port}/ws") as ws: + async with sess.get(f"http://localhost:{port}/stop"): + pass + + async for msg in ws: + pass + nonlocal client_finished + client_finished = True + + async def run_test(app: web.Application) -> None: + nonlocal t + t = asyncio.create_task(test()) + yield + t.cancel() + with contextlib.suppress(asyncio.CancelledError): + await t + + t = None + app = web.Application() + app[WS] = set() + app.on_shutdown.append(close_websockets) + app.cleanup_ctx.append(run_test) + app.router.add_get("/ws", ws_handler) + app.router.add_get("/stop", self.stop) + + start = time.time() + web.run_app(app, port=port, shutdown_timeout=10) + assert time.time() - start < 5 + assert client_finished + assert server_finished diff --git a/tests/test_worker.py b/tests/test_worker.py index e59efe17ea9..2797ef4c08c 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -79,6 +79,7 @@ def test_run( worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False + worker.cfg.graceful_timeout = 100 worker.sockets = [] worker.loop = loop @@ -95,6 +96,7 @@ def test_run_async_factory( worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False + worker.cfg.graceful_timeout = 100 worker.sockets = [] app = worker.wsgi