Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign shutdown process #7718

Merged
merged 21 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/7718.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed keep-alive connections stopping a graceful shutdown.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Dreamsorcerer “fixed” implies a “bugfix” changelog note type, but the filename says “feature”. If it's both, it might be worth splitting into two fragments.

Added ``shutdown_timeout`` parameter to ``BaseRunner``, while
removing ``shutdown_timeout`` parameter from ``BaseSite``. -- by :user:`Dreamsorcerer`
34 changes: 27 additions & 7 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
from functools import partial
from importlib import import_module
from typing import (
Any,
Expand All @@ -18,6 +20,7 @@
Union,
cast,
)
from weakref import WeakSet

from .abc import AbstractAccessLogger
from .helpers import AppKey
Expand Down Expand Up @@ -291,6 +294,23 @@ async def _run_app(
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
async def wait(
starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float
) -> None:
# Wait for pending tasks for a given time limit.
t = asyncio.current_task()
assert t is not None
starting_tasks.add(t)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout)

async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
t = asyncio.current_task()
assert t is not None
exclude.add(t)
while tasks := asyncio.all_tasks().difference(exclude):
await asyncio.wait(tasks)

# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app
Expand All @@ -304,10 +324,17 @@ async def _run_app(
access_log_format=access_log_format,
access_log=access_log,
keepalive_timeout=keepalive_timeout,
shutdown_timeout=shutdown_timeout,
handler_cancellation=handler_cancellation,
)

await runner.setup()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in runner.setup()),
# so we just record all running tasks here and exclude them later.
starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks())
runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout)

sites: List[BaseSite] = []

Expand All @@ -319,7 +346,6 @@ async def _run_app(
runner,
host,
port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
Expand All @@ -333,7 +359,6 @@ async def _run_app(
runner,
h,
port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
Expand All @@ -345,7 +370,6 @@ async def _run_app(
TCPSite(
runner,
port=port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
Expand All @@ -359,7 +383,6 @@ async def _run_app(
UnixSite(
runner,
path,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand All @@ -370,7 +393,6 @@ async def _run_app(
UnixSite(
runner,
p,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand All @@ -382,7 +404,6 @@ async def _run_app(
SockSite(
runner,
sock,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand All @@ -393,7 +414,6 @@ async def _run_app(
SockSite(
runner,
s,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand Down
81 changes: 35 additions & 46 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import signal
import socket
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, List, Optional, Set, Type
from typing import Any, Awaitable, Callable, List, Optional, Set, Type

from yarl import URL

Expand Down Expand Up @@ -45,20 +44,18 @@ def _raise_graceful_exit() -> None:


class BaseSite(ABC):
__slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")
__slots__ = ("_runner", "_ssl_context", "_backlog", "_server")

def __init__(
self,
runner: "BaseRunner",
*,
shutdown_timeout: float = 60.0,
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
if runner.server is None:
raise RuntimeError("Call runner.setup() before making a site")
self._runner = runner
self._shutdown_timeout = shutdown_timeout
self._ssl_context = ssl_context
self._backlog = backlog
self._server: Optional[asyncio.AbstractServer] = None
Expand All @@ -74,30 +71,11 @@ async def start(self) -> None:

async def stop(self) -> None:
self._runner._check_site(self)
if self._server is None:
self._runner._unreg_site(self)
return # not started yet
self._server.close()
# named pipes do not have wait_closed property
if hasattr(self._server, "wait_closed"):
await self._server.wait_closed()
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved

# Wait for pending tasks for a given time limit.
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
self._wait(asyncio.current_task()), timeout=self._shutdown_timeout
)
if self._server is not None: # Maybe not started yet
self._server.close()

await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
self._runner._unreg_site(self)

async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None:
exclude = self._runner.starting_tasks | {asyncio.current_task(), parent_task}
while tasks := asyncio.all_tasks() - exclude:
await asyncio.wait(tasks)


class TCPSite(BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
Expand All @@ -108,15 +86,13 @@ def __init__(
host: Optional[str] = None,
port: Optional[int] = None,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand Down Expand Up @@ -157,13 +133,11 @@ def __init__(
runner: "BaseRunner",
path: PathLike,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand All @@ -190,17 +164,15 @@ async def start(self) -> None:
class NamedPipeSite(BaseSite):
__slots__ = ("_path",)

def __init__(
self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
) -> None:
def __init__(self, runner: "BaseRunner", path: str) -> None:
loop = asyncio.get_event_loop()
if not isinstance(
loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
):
raise RuntimeError(
"Named Pipes only available in proactor" "loop under windows"
)
super().__init__(runner, shutdown_timeout=shutdown_timeout)
super().__init__(runner)
self._path = path

@property
Expand All @@ -226,13 +198,11 @@ def __init__(
runner: "BaseRunner",
sock: socket.socket,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
Expand Down Expand Up @@ -260,13 +230,28 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites")
__slots__ = (
"shutdown_callback",
"_handle_signals",
"_kwargs",
"_server",
"_sites",
"_shutdown_timeout",
)

def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
def __init__(
self,
*,
handle_signals: bool = False,
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
self._sites: List[BaseSite] = []
self._shutdown_timeout = shutdown_timeout

@property
def server(self) -> Optional[Server]:
Expand Down Expand Up @@ -300,28 +285,32 @@ async def setup(self) -> None:
pass

self._server = await self._make_server()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in self._make_server()),
# so we just record all running tasks here and exclude them later.
self.starting_tasks = asyncio.all_tasks()

@abstractmethod
async def shutdown(self) -> None:
pass # pragma: no cover
"""Call any shutdown hooks to help server close gracefully."""

async def cleanup(self) -> None:
loop = asyncio.get_event_loop()

# The loop over sites is intentional, an exception on gather()
# leaves self._sites in unpredictable state.
# The loop guarantees that a site is either deleted on success or
# still present on failure
for site in list(self._sites):
await site.stop()

if self._server: # If setup succeeded
self._server.pre_shutdown()
await self.shutdown()

if self.shutdown_callback:
await self.shutdown_callback()

await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()

self._server = None
if self._handle_signals:
loop = asyncio.get_running_loop()
try:
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)
Expand Down
6 changes: 5 additions & 1 deletion aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ def _make_request(
) -> BaseRequest:
return BaseRequest(message, payload, protocol, writer, task, self._loop)

def pre_shutdown(self) -> None:
for conn in self._connections:
conn.close()

async def shutdown(self, timeout: Optional[float] = None) -> None:
coros = [conn.shutdown(timeout) for conn in self._connections]
coros = (conn.shutdown(timeout) for conn in self._connections)
await asyncio.gather(*coros)
self._connections.clear()

Expand Down
2 changes: 1 addition & 1 deletion aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def _run(self) -> None:
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format
),
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await runner.setup()

Expand All @@ -99,7 +100,6 @@ async def _run(self) -> None:
runner,
sock,
ssl_context=ctx,
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await site.start()

Expand Down
Loading