From 02401e51483e01c7c5252efcad827af2fc40ff4f Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Tue, 17 Dec 2024 09:57:20 +0100 Subject: [PATCH] Add `--task-impl` option (#468) --- README.md | 3 +++ granian/_futures.py | 15 +++++++++++---- granian/_imports.py | 5 +++++ granian/cli.py | 10 +++++++++- granian/constants.py | 6 ++++++ granian/server.py | 30 ++++++++++++++++++++++++------ pyproject.toml | 2 +- 7 files changed, 59 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index fb7b3c66..988c5670 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,9 @@ Options: --loop [auto|asyncio|rloop|uvloop] Event loop implementation [env var: GRANIAN_LOOP; default: (auto)] + --task-impl [auto|rust|asyncio] + Async task implementation to use [env var: + GRANIAN_TASK_IMPL; default: (auto)] --backlog INTEGER RANGE Maximum number of connections to hold in backlog (globally) [env var: GRANIAN_BACKLOG; default: 1024; x>=128] diff --git a/granian/_futures.py b/granian/_futures.py index 8630a200..a819d068 100644 --- a/granian/_futures.py +++ b/granian/_futures.py @@ -25,10 +25,17 @@ def __init__(self, loop, ctx, cb, aio_tenter, aio_texit): self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb) -def _new_cbscheduler(loop, cb): - return _CBScheduler( - loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop) - ) +class _CBSchedulerAIO(_BaseCBScheduler): + __slots__ = [] + + def __init__(self, loop, ctx, cb, aio_tenter, aio_texit): + super().__init__() + self._schedule_fn = _cbsched_schedule(loop, ctx, loop.create_task, cb) + + +def _new_cbscheduler(loop, cb, impl_asyncio=False): + _cls = _CBSchedulerAIO if impl_asyncio else _CBScheduler + return _cls(loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop)) def _cbsched_schedule(loop, ctx, run, cb): diff --git a/granian/_imports.py b/granian/_imports.py index b07ce1b3..62711fc2 100644 --- a/granian/_imports.py +++ b/granian/_imports.py @@ -1,3 +1,8 @@ +try: + import anyio +except ImportError: + anyio = None + try: import setproctitle except ImportError: diff --git a/granian/cli.py b/granian/cli.py index 87b6ff76..882dda00 100644 --- a/granian/cli.py +++ b/granian/cli.py @@ -5,7 +5,7 @@ import click -from .constants import HTTPModes, Interfaces, Loops, ThreadModes +from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes from .errors import FatalError from .http import HTTP1Settings, HTTP2Settings from .log import LogLevels @@ -77,6 +77,12 @@ def option(*param_decls: str, cls: Optional[Type[click.Option]] = None, **attrs: help='Threading mode to use', ) @option('--loop', type=EnumType(Loops), default=Loops.auto, help='Event loop implementation') +@option( + '--task-impl', + type=EnumType(TaskImpl), + default=TaskImpl.auto, + help='Async task implementation to use', +) @option( '--backlog', type=click.IntRange(128), @@ -261,6 +267,7 @@ def cli( blocking_threads: Optional[int], threading_mode: ThreadModes, loop: Loops, + task_impl: TaskImpl, backlog: int, backpressure: Optional[int], http1_buffer_size: int, @@ -316,6 +323,7 @@ def cli( blocking_threads=blocking_threads, threading_mode=threading_mode, loop=loop, + task_impl=task_impl, http=http, websockets=websockets, backlog=backlog, diff --git a/granian/constants.py b/granian/constants.py index d7e98c44..221c8a0c 100644 --- a/granian/constants.py +++ b/granian/constants.py @@ -29,3 +29,9 @@ class Loops(StrEnum): asyncio = 'asyncio' rloop = 'rloop' uvloop = 'uvloop' + + +class TaskImpl(StrEnum): + auto = 'auto' + rust = 'rust' + asyncio = 'asyncio' diff --git a/granian/server.py b/granian/server.py index 641b6a4b..77294a75 100644 --- a/granian/server.py +++ b/granian/server.py @@ -14,11 +14,11 @@ from ._futures import _future_watcher_wrapper, _new_cbscheduler from ._granian import ASGIWorker, RSGIWorker, WSGIWorker -from ._imports import setproctitle, watchfiles +from ._imports import anyio, setproctitle, watchfiles from ._internal import load_target from ._signals import set_main_signals from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap -from .constants import HTTPModes, Interfaces, Loops, ThreadModes +from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes from .errors import ConfigurationError, PidFileError from .http import HTTP1Settings, HTTP2Settings from .log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger @@ -81,6 +81,7 @@ def __init__( blocking_threads: Optional[int] = None, threading_mode: ThreadModes = ThreadModes.workers, loop: Loops = Loops.auto, + task_impl: TaskImpl = TaskImpl.auto, http: HTTPModes = HTTPModes.auto, websockets: bool = True, backlog: int = 1024, @@ -118,6 +119,7 @@ def __init__( self.threads = max(1, threads) self.threading_mode = threading_mode self.loop = loop + self.task_impl = task_impl self.http = http self.websockets = websockets self.backlog = max(128, backlog) @@ -188,6 +190,7 @@ def _spawn_asgi_worker( blocking_threads: int, backpressure: int, threading_mode: ThreadModes, + task_impl: TaskImpl, http_mode: HTTPModes, http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], @@ -225,7 +228,9 @@ def _spawn_asgi_worker( *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback)) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) serve(scheduler, loop, shutdown_event) @staticmethod @@ -239,6 +244,7 @@ def _spawn_asgi_lifespan_worker( blocking_threads: int, backpressure: int, threading_mode: ThreadModes, + task_impl: TaskImpl, http_mode: HTTPModes, http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], @@ -283,7 +289,9 @@ def _spawn_asgi_lifespan_worker( *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback)) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) serve(scheduler, loop, shutdown_event) loop.run_until_complete(lifespan_handler.shutdown()) @@ -298,6 +306,7 @@ def _spawn_rsgi_worker( blocking_threads: int, backpressure: int, threading_mode: ThreadModes, + task_impl: TaskImpl, http_mode: HTTPModes, http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], @@ -343,7 +352,9 @@ def _spawn_rsgi_worker( *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(callback)) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio + ) serve(scheduler, loop, shutdown_event) callback_del(loop) @@ -358,6 +369,7 @@ def _spawn_wsgi_worker( blocking_threads: int, backpressure: int, threading_mode: ThreadModes, + task_impl: TaskImpl, http_mode: HTTPModes, http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], @@ -385,7 +397,9 @@ def _spawn_wsgi_worker( worker_id, sfd, threads, blocking_threads, backpressure, http_mode, http1_settings, http2_settings, *ssl_ctx ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler(loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt)) + scheduler = _new_cbscheduler( + loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio + ) serve(scheduler, loop, shutdown_event) shutdown_event.qs.wait() @@ -416,6 +430,7 @@ def _spawn_proc(self, idx, target, callback_loader, socket_loader) -> Worker: self.blocking_threads, self.backpressure, self.threading_mode, + self.task_impl, self.http, self.http1_settings, self.http2_settings, @@ -713,5 +728,8 @@ def serve( logger.error('Workers lifetime cannot be less than 60 seconds') raise ConfigurationError('workers_lifetime') + if self.task_impl == TaskImpl.auto: + self.task_impl = TaskImpl.asyncio if anyio is not None else TaskImpl.rust + serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve serve_method(spawn_target, target_loader) diff --git a/pyproject.toml b/pyproject.toml index a3c9f6a7..ccd00c76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ extend-ignore = [ 'S110', # except pass is fine ] flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } -mccabe = { max-complexity = 13 } +mccabe = { max-complexity = 14 } [tool.ruff.format] quote-style = 'single'