Skip to content

Commit

Permalink
Add --task-impl option (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored Dec 17, 2024
1 parent 693af47 commit 02401e5
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 12 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ Options:
--loop [auto|asyncio|rloop|uvloop]
Event loop implementation [env var:
GRANIAN_LOOP; default: (auto)]
--task-impl [auto|rust|asyncio]
Async task implementation to use [env var:
GRANIAN_TASK_IMPL; default: (auto)]
--backlog INTEGER RANGE Maximum number of connections to hold in
backlog (globally) [env var:
GRANIAN_BACKLOG; default: 1024; x>=128]
Expand Down
15 changes: 11 additions & 4 deletions granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb)


def _new_cbscheduler(loop, cb):
return _CBScheduler(
loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop)
)
class _CBSchedulerAIO(_BaseCBScheduler):
__slots__ = []

def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, loop.create_task, cb)


def _new_cbscheduler(loop, cb, impl_asyncio=False):
_cls = _CBSchedulerAIO if impl_asyncio else _CBScheduler
return _cls(loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop))


def _cbsched_schedule(loop, ctx, run, cb):
Expand Down
5 changes: 5 additions & 0 deletions granian/_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
try:
import anyio
except ImportError:
anyio = None

try:
import setproctitle
except ImportError:
Expand Down
10 changes: 9 additions & 1 deletion granian/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import click

from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes
from .errors import FatalError
from .http import HTTP1Settings, HTTP2Settings
from .log import LogLevels
Expand Down Expand Up @@ -77,6 +77,12 @@ def option(*param_decls: str, cls: Optional[Type[click.Option]] = None, **attrs:
help='Threading mode to use',
)
@option('--loop', type=EnumType(Loops), default=Loops.auto, help='Event loop implementation')
@option(
'--task-impl',
type=EnumType(TaskImpl),
default=TaskImpl.auto,
help='Async task implementation to use',
)
@option(
'--backlog',
type=click.IntRange(128),
Expand Down Expand Up @@ -261,6 +267,7 @@ def cli(
blocking_threads: Optional[int],
threading_mode: ThreadModes,
loop: Loops,
task_impl: TaskImpl,
backlog: int,
backpressure: Optional[int],
http1_buffer_size: int,
Expand Down Expand Up @@ -316,6 +323,7 @@ def cli(
blocking_threads=blocking_threads,
threading_mode=threading_mode,
loop=loop,
task_impl=task_impl,
http=http,
websockets=websockets,
backlog=backlog,
Expand Down
6 changes: 6 additions & 0 deletions granian/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ class Loops(StrEnum):
asyncio = 'asyncio'
rloop = 'rloop'
uvloop = 'uvloop'


class TaskImpl(StrEnum):
auto = 'auto'
rust = 'rust'
asyncio = 'asyncio'
30 changes: 24 additions & 6 deletions granian/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from ._futures import _future_watcher_wrapper, _new_cbscheduler
from ._granian import ASGIWorker, RSGIWorker, WSGIWorker
from ._imports import setproctitle, watchfiles
from ._imports import anyio, setproctitle, watchfiles
from ._internal import load_target
from ._signals import set_main_signals
from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes
from .errors import ConfigurationError, PidFileError
from .http import HTTP1Settings, HTTP2Settings
from .log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
blocking_threads: Optional[int] = None,
threading_mode: ThreadModes = ThreadModes.workers,
loop: Loops = Loops.auto,
task_impl: TaskImpl = TaskImpl.auto,
http: HTTPModes = HTTPModes.auto,
websockets: bool = True,
backlog: int = 1024,
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
self.threads = max(1, threads)
self.threading_mode = threading_mode
self.loop = loop
self.task_impl = task_impl
self.http = http
self.websockets = websockets
self.backlog = max(128, backlog)
Expand Down Expand Up @@ -188,6 +190,7 @@ def _spawn_asgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -225,7 +228,9 @@ def _spawn_asgi_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)

@staticmethod
Expand All @@ -239,6 +244,7 @@ def _spawn_asgi_lifespan_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -283,7 +289,9 @@ def _spawn_asgi_lifespan_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
loop.run_until_complete(lifespan_handler.shutdown())

Expand All @@ -298,6 +306,7 @@ def _spawn_rsgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -343,7 +352,9 @@ def _spawn_rsgi_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(callback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
callback_del(loop)

Expand All @@ -358,6 +369,7 @@ def _spawn_wsgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -385,7 +397,9 @@ def _spawn_wsgi_worker(
worker_id, sfd, threads, blocking_threads, backpressure, http_mode, http1_settings, http2_settings, *ssl_ctx
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt))
scheduler = _new_cbscheduler(
loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
shutdown_event.qs.wait()

Expand Down Expand Up @@ -416,6 +430,7 @@ def _spawn_proc(self, idx, target, callback_loader, socket_loader) -> Worker:
self.blocking_threads,
self.backpressure,
self.threading_mode,
self.task_impl,
self.http,
self.http1_settings,
self.http2_settings,
Expand Down Expand Up @@ -713,5 +728,8 @@ def serve(
logger.error('Workers lifetime cannot be less than 60 seconds')
raise ConfigurationError('workers_lifetime')

if self.task_impl == TaskImpl.auto:
self.task_impl = TaskImpl.asyncio if anyio is not None else TaskImpl.rust

serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
serve_method(spawn_target, target_loader)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ extend-ignore = [
'S110', # except pass is fine
]
flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
mccabe = { max-complexity = 13 }
mccabe = { max-complexity = 14 }

[tool.ruff.format]
quote-style = 'single'
Expand Down

0 comments on commit 02401e5

Please sign in to comment.