diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cc875242..e2f9c11c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -118,7 +118,6 @@ jobs: pip install .[test] pip install granian --no-index --no-deps --find-links pgo_wheel --force-reinstall PGO_RUN=y pytest tests - PGO_RUN=y LOOP_OPT=y pytest tests/test_asgi.py tests/test_rsgi.py - name: merge PGO data run: ${{ env.LLVM_PROFDATA }} merge -o ${{ github.workspace }}/merged.profdata ${{ github.workspace }}/profdata - name: Build PGO wheel diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b5b7928e..f2d30443 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -139,7 +139,6 @@ jobs: pip install .[test] pip install granian --no-index --no-deps --find-links pgo_wheel --force-reinstall PGO_RUN=y pytest tests - PGO_RUN=y LOOP_OPT=y pytest tests/test_asgi.py tests/test_rsgi.py - name: merge PGO data run: ${{ env.LLVM_PROFDATA }} merge -o ${{ github.workspace }}/merged.profdata ${{ github.workspace }}/profdata - name: Build PGO wheel diff --git a/README.md b/README.md index 0bb22b1f..087f68d1 100644 --- a/README.md +++ b/README.md @@ -100,8 +100,6 @@ Options: GRANIAN_THREADING_MODE; default: (workers)] --loop [auto|asyncio|uvloop] Event loop implementation [env var: GRANIAN_LOOP; default: (auto)] - --opt / --no-opt Enable loop optimizations [env var: - GRANIAN_LOOP_OPT; default: (disabled)] --backlog INTEGER RANGE Maximum number of connections to hold in backlog (globally) [env var: GRANIAN_BACKLOG; default: 1024; x>=128] diff --git a/granian/_futures.py b/granian/_futures.py index 2abeee9f..3084630f 100644 --- a/granian/_futures.py +++ b/granian/_futures.py @@ -1,4 +1,9 @@ -def future_watcher_wrapper(inner): +from asyncio.tasks import _enter_task, _leave_task + +from ._granian import CallbackScheduler as _BaseCBScheduler + + +def _future_watcher_wrapper(inner): async def future_watcher(watcher): try: await inner(watcher.scope, watcher.proto) @@ -8,3 +13,59 @@ async def future_watcher(watcher): watcher.done() return future_watcher + + +class _CBScheduler(_BaseCBScheduler): + __slots__ = [] + + def __init__(self, loop, ctx, cb): + super().__init__() + self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb) + + def _waker(self, coro): + def _wake(fut): + self._resume(coro, fut) + + return _wake + + def _resume(self, coro, fut): + try: + fut.result() + except BaseException as exc: + self._throw(coro, exc) + else: + self._run(coro) + + def _run(self, coro): + _enter_task(self._loop, self) + try: + try: + result = coro.send(None) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: + pass + else: + if getattr(result, '_asyncio_future_blocking', None): + result._asyncio_future_blocking = False + result.add_done_callback(self._waker(coro), context=self._ctx) + elif result is None: + self._loop.call_soon(self._run, coro, context=self._ctx) + finally: + _leave_task(self._loop, self) + + def _throw(self, coro, exc): + _enter_task(self._loop, self) + try: + coro.throw(exc) + except BaseException: + pass + finally: + _leave_task(self._loop, self) + + +def _cbsched_schedule(loop, ctx, run, cb): + def _schedule(watcher): + loop.call_soon_threadsafe(run, cb(watcher), context=ctx) + + return _schedule diff --git a/granian/_granian.pyi b/granian/_granian.pyi index 55c2134a..0072f7c7 100644 --- a/granian/_granian.pyi +++ b/granian/_granian.pyi @@ -107,3 +107,7 @@ class ListenerHolder: @classmethod def from_address(cls, address: str, port: int, backlog: int) -> ListenerHolder: ... def get_fd(self) -> Any: ... + +class CallbackScheduler: + _loop: Any + _ctx: Any diff --git a/granian/cli.py b/granian/cli.py index 7060e5f3..bb5cbdb9 100644 --- a/granian/cli.py +++ b/granian/cli.py @@ -77,7 +77,6 @@ def option(*param_decls: str, cls: Optional[Type[click.Option]] = None, **attrs: help='Threading mode to use', ) @option('--loop', type=EnumType(Loops), default=Loops.auto, help='Event loop implementation') -@option('--opt/--no-opt', 'loop_opt', default=False, help='Enable loop optimizations') @option( '--backlog', type=click.IntRange(128), @@ -256,7 +255,6 @@ def cli( blocking_threads: Optional[int], threading_mode: ThreadModes, loop: Loops, - loop_opt: bool, backlog: int, backpressure: Optional[int], http1_buffer_size: int, @@ -311,7 +309,6 @@ def cli( blocking_threads=blocking_threads, threading_mode=threading_mode, loop=loop, - loop_opt=loop_opt, http=http, websockets=websockets, backlog=backlog, diff --git a/granian/server.py b/granian/server.py index bf658476..1e78f38b 100644 --- a/granian/server.py +++ b/granian/server.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type -from ._futures import future_watcher_wrapper +from ._futures import _CBScheduler, _future_watcher_wrapper from ._granian import ASGIWorker, RSGIWorker, WSGIWorker from ._imports import setproctitle, watchfiles from ._internal import load_target @@ -78,7 +78,6 @@ def __init__( blocking_threads: Optional[int] = None, threading_mode: ThreadModes = ThreadModes.workers, loop: Loops = Loops.auto, - loop_opt: bool = False, http: HTTPModes = HTTPModes.auto, websockets: bool = True, backlog: int = 1024, @@ -115,7 +114,6 @@ def __init__( self.threads = max(1, threads) self.threading_mode = threading_mode self.loop = loop - self.loop_opt = loop_opt self.http = http self.websockets = websockets self.backlog = max(128, backlog) @@ -189,7 +187,6 @@ def _spawn_asgi_worker( http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], websockets: bool, - loop_opt: bool, log_enabled: bool, log_level: LogLevels, log_config: Dict[str, Any], @@ -207,12 +204,8 @@ def _spawn_asgi_worker( loop = loops.get(loop_impl) sfd = socket.fileno() callback = callback_loader() - shutdown_event = set_loop_signals(loop) - wcallback = _asgi_call_wrap(callback, scope_opts, {}, log_access_fmt) - if not loop_opt: - wcallback = future_watcher_wrapper(wcallback) worker = ASGIWorker( worker_id, @@ -224,11 +217,11 @@ def _spawn_asgi_worker( http1_settings, http2_settings, websockets, - loop_opt, *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - serve(wcallback, loop, contextvars.copy_context(), shutdown_event) + scheduler = _CBScheduler(loop, contextvars.copy_context(), _future_watcher_wrapper(wcallback)) + serve(scheduler, loop, shutdown_event) @staticmethod def _spawn_asgi_lifespan_worker( @@ -245,7 +238,6 @@ def _spawn_asgi_lifespan_worker( http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], websockets: bool, - loop_opt: bool, log_enabled: bool, log_level: LogLevels, log_config: Dict[str, Any], @@ -271,10 +263,7 @@ def _spawn_asgi_lifespan_worker( sys.exit(1) shutdown_event = set_loop_signals(loop) - wcallback = _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt) - if not loop_opt: - wcallback = future_watcher_wrapper(wcallback) worker = ASGIWorker( worker_id, @@ -286,11 +275,11 @@ def _spawn_asgi_lifespan_worker( http1_settings, http2_settings, websockets, - loop_opt, *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - serve(wcallback, loop, contextvars.copy_context(), shutdown_event) + scheduler = _CBScheduler(loop, contextvars.copy_context(), _future_watcher_wrapper(wcallback)) + serve(scheduler, loop, shutdown_event) loop.run_until_complete(lifespan_handler.shutdown()) @staticmethod @@ -308,7 +297,6 @@ def _spawn_rsgi_worker( http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], websockets: bool, - loop_opt: bool, log_enabled: bool, log_level: LogLevels, log_config: Dict[str, Any], @@ -334,7 +322,6 @@ def _spawn_rsgi_worker( getattr(target, '__rsgi_del__') if hasattr(target, '__rsgi_del__') else lambda *args, **kwargs: None ) callback = _rsgi_call_wrap(callback, log_access_fmt) - shutdown_event = set_loop_signals(loop) callback_init(loop) @@ -348,16 +335,11 @@ def _spawn_rsgi_worker( http1_settings, http2_settings, websockets, - loop_opt, *ssl_ctx, ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - serve( - future_watcher_wrapper(callback) if not loop_opt else callback, - loop, - contextvars.copy_context(), - shutdown_event, - ) + scheduler = _CBScheduler(loop, contextvars.copy_context(), _future_watcher_wrapper(callback)) + serve(scheduler, loop, shutdown_event) callback_del(loop) @staticmethod @@ -375,7 +357,6 @@ def _spawn_wsgi_worker( http1_settings: Optional[HTTP1Settings], http2_settings: Optional[HTTP2Settings], websockets: bool, - loop_opt: bool, log_enabled: bool, log_level: LogLevels, log_config: Dict[str, Any], @@ -393,14 +374,16 @@ def _spawn_wsgi_worker( loop = loops.get(loop_impl) sfd = socket.fileno() callback = callback_loader() - shutdown_event = set_sync_signals() worker = WSGIWorker( worker_id, sfd, threads, blocking_threads, backpressure, http_mode, http1_settings, http2_settings, *ssl_ctx ) serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - serve(_wsgi_call_wrap(callback, scope_opts, log_access_fmt), loop, contextvars.copy_context(), shutdown_event) + scheduler = _CBScheduler( + loop, contextvars.copy_context(), _wsgi_call_wrap(callback, scope_opts, log_access_fmt) + ) + serve(scheduler, loop, shutdown_event) shutdown_event.qs.wait() def _init_shared_socket(self): @@ -434,7 +417,6 @@ def _spawn_proc(self, idx, target, callback_loader, socket_loader) -> Worker: self.http1_settings, self.http2_settings, self.websockets, - self.loop_opt, self.log_enabled, self.log_level, self.log_config, diff --git a/pyproject.toml b/pyproject.toml index c80fb897..fb080770 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ test = [ 'httpx~=0.25.0', 'pytest~=7.4.2', 'pytest-asyncio~=0.21.1', + 'sniffio~=1.3', 'websockets~=11.0', ] all = ['granian[pname,reload]'] @@ -95,6 +96,7 @@ extend-ignore = [ 'E501', # leave line length to black 'N818', # leave to us exceptions naming 'S101', # assert is fine + 'S110', # except pass is fine ] flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } mccabe = { max-complexity = 13 } diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 300a60f3..d2cc696e 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -1,9 +1,6 @@ use pyo3::prelude::*; use pyo3::types::PyDict; -use std::{ - net::SocketAddr, - sync::{Arc, Mutex}, -}; +use std::{net::SocketAddr, sync::Arc}; use tokio::sync::oneshot; use super::{ @@ -11,50 +8,13 @@ use super::{ utils::{build_scope_http, build_scope_ws, scope_native_parts}, }; use crate::{ - asyncio::PyContext, - callbacks::{ - callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step, - callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper, - }, + callbacks::ArcCBScheduler, http::{response_500, HTTPResponse}, runtime::RuntimeRef, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; -#[pyclass(frozen)] -pub(crate) struct CallbackRunnerHTTP { - proto: Py, - context: PyContext, - cb: PyObject, -} - -impl CallbackRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Bound) -> Self { - let pyproto = Py::new(py, proto).unwrap(); - Self { - proto: pyproto.clone_ref(py), - context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), - } - } - - callback_impl_run!(); -} - -#[pymethods] -impl CallbackRunnerHTTP { - fn _loop_task<'p>(&self, py: Python<'p>) -> PyResult> { - CallbackTaskHTTP::new( - py, - self.cb.clone_ref(py), - self.proto.clone_ref(py), - self.context.clone(), - )? - .run(py) - } -} - macro_rules! callback_impl_done_http { ($self:expr) => { if let Some(tx) = $self.proto.get().tx() { @@ -63,6 +23,14 @@ macro_rules! callback_impl_done_http { }; } +macro_rules! callback_impl_done_ws { + ($self:expr) => { + if let (Some(tx), res) = $self.proto.get().tx() { + let _ = tx.send(res); + } + }; +} + macro_rules! callback_impl_done_err { ($self:expr, $err:expr) => { $self.done(); @@ -71,211 +39,58 @@ macro_rules! callback_impl_done_err { } #[pyclass(frozen)] -pub(crate) struct CallbackTaskHTTP { - proto: Py, - context: PyContext, - pycontext: PyObject, - cb: PyObject, -} - -impl CallbackTaskHTTP { - pub fn new(py: Python, cb: PyObject, proto: Py, context: PyContext) -> PyResult { - let pyctx = context.context(py); - Ok(Self { - proto, - context, - pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb, - }) - } - - fn done(&self) { - callback_impl_done_http!(self); - } - - fn err(&self, err: &PyErr) { - callback_impl_done_err!(self, err); - } - - callback_impl_loop_run!(); - callback_impl_loop_err!(); -} - -#[pymethods] -impl CallbackTaskHTTP { - fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> { - callback_impl_loop_step!(pyself, py) - } - - fn _loop_wake(pyself: PyRef<'_, Self>, py: Python, fut: PyObject) -> PyResult { - callback_impl_loop_wake!(pyself, py, fut) - } -} - -#[pyclass(frozen)] -pub(crate) struct CallbackWrappedRunnerHTTP { +pub(crate) struct CallbackWatcherHTTP { #[pyo3(get)] proto: Py, - context: PyContext, - cb: PyObject, #[pyo3(get)] scope: PyObject, - pytaskref: Arc>>, } -impl CallbackWrappedRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Bound) -> Self { +impl CallbackWatcherHTTP { + pub fn new(py: Python, proto: HTTPProtocol, scope: Bound) -> Self { Self { proto: Py::new(py, proto).unwrap(), - context: cb.context, - cb: cb.callback.clone_ref(py), scope: scope.into_py(py), - pytaskref: Arc::new(Mutex::new(None)), } } - - callback_impl_run_pytask!(); } #[pymethods] -impl CallbackWrappedRunnerHTTP { - fn _loop_task<'p>(pyself: PyRef<'_, Self>, py: Python<'p>) -> PyResult> { - callback_impl_loop_pytask!(pyself, py) - } - +impl CallbackWatcherHTTP { fn done(&self) { callback_impl_done_http!(self); - self.pytaskref.lock().unwrap().take(); } fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value_bound(err)); - self.pytaskref.lock().unwrap().take(); - } -} - -#[pyclass(frozen)] -pub(crate) struct CallbackRunnerWebsocket { - proto: Py, - context: PyContext, - cb: PyObject, -} - -impl CallbackRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Bound) -> Self { - let pyproto = Py::new(py, proto).unwrap(); - Self { - proto: pyproto.clone_ref(py), - context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), - } } - - callback_impl_run!(); -} - -#[pymethods] -impl CallbackRunnerWebsocket { - fn _loop_task<'p>(&self, py: Python<'p>) -> PyResult> { - CallbackTaskWebsocket::new( - py, - self.cb.clone_ref(py), - self.proto.clone_ref(py), - self.context.clone(), - )? - .run(py) - } -} - -macro_rules! callback_impl_done_ws { - ($self:expr) => { - if let (Some(tx), res) = $self.proto.get().tx() { - let _ = tx.send(res); - } - }; } #[pyclass(frozen)] -pub(crate) struct CallbackTaskWebsocket { - proto: Py, - context: PyContext, - pycontext: PyObject, - cb: PyObject, -} - -impl CallbackTaskWebsocket { - pub fn new(py: Python, cb: PyObject, proto: Py, context: PyContext) -> PyResult { - let pyctx = context.context(py); - Ok(Self { - proto, - context, - pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb, - }) - } - - fn done(&self) { - callback_impl_done_ws!(self); - } - - fn err(&self, err: &PyErr) { - callback_impl_done_err!(self, err); - } - - callback_impl_loop_run!(); - callback_impl_loop_err!(); -} - -#[pymethods] -impl CallbackTaskWebsocket { - fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> { - callback_impl_loop_step!(pyself, py) - } - - fn _loop_wake(pyself: PyRef<'_, Self>, py: Python, fut: PyObject) -> PyResult { - callback_impl_loop_wake!(pyself, py, fut) - } -} - -#[pyclass(frozen)] -pub(crate) struct CallbackWrappedRunnerWebsocket { +pub(crate) struct CallbackWatcherWebsocket { #[pyo3(get)] proto: Py, - context: PyContext, - cb: PyObject, #[pyo3(get)] scope: PyObject, - pytaskref: Arc>>, } -impl CallbackWrappedRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Bound) -> Self { +impl CallbackWatcherWebsocket { + pub fn new(py: Python, proto: WebsocketProtocol, scope: Bound) -> Self { Self { proto: Py::new(py, proto).unwrap(), - context: cb.context, - cb: cb.callback.clone_ref(py), scope: scope.into_py(py), - pytaskref: Arc::new(Mutex::new(None)), } } - - callback_impl_run_pytask!(); } #[pymethods] -impl CallbackWrappedRunnerWebsocket { - fn _loop_task<'p>(pyself: PyRef<'_, Self>, py: Python<'p>) -> PyResult> { - callback_impl_loop_pytask!(pyself, py) - } - +impl CallbackWatcherWebsocket { fn done(&self) { callback_impl_done_ws!(self); - self.pytaskref.lock().unwrap().take(); } fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value_bound(err)); - self.pytaskref.lock().unwrap().take(); } } @@ -302,88 +117,75 @@ impl CallbackWrappedRunnerWebsocket { // } // } -macro_rules! call_impl_http { - ($func_name:ident, $runner:ident) => { - #[inline] - pub(crate) fn $func_name( - cb: CallbackWrapper, - rt: RuntimeRef, - server_addr: SocketAddr, - client_addr: SocketAddr, - scheme: &str, - req: hyper::http::request::Parts, - body: hyper::body::Incoming, - ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); - let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, body, tx); - let scheme: Arc = scheme.into(); - - let _ = brt.run(move || { - scope_native_parts!( - req, - server_addr, - client_addr, - path, - query_string, - version, - server, - client - ); - Python::with_gil(|py| { - let scope = - build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let _ = $runner::new(py, cb, protocol, scope).run(py); - }); - }); - - rx - } - }; -} - -macro_rules! call_impl_ws { - ($func_name:ident, $runner:ident) => { - #[inline] - pub(crate) fn $func_name( - cb: CallbackWrapper, - rt: RuntimeRef, - server_addr: SocketAddr, - client_addr: SocketAddr, - scheme: &str, - ws: HyperWebsocket, - req: hyper::http::request::Parts, - upgrade: UpgradeData, - ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); - let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); - let scheme: Arc = scheme.into(); - - let _ = brt.run(move || { - scope_native_parts!( - req, - server_addr, - client_addr, - path, - query_string, - version, - server, - client - ); - Python::with_gil(|py| { - let scope = - build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let _ = $runner::new(py, cb, protocol, scope).run(py); - }); - }); - - rx - } - }; +#[inline] +pub(crate) fn call_http( + cb: ArcCBScheduler, + rt: RuntimeRef, + server_addr: SocketAddr, + client_addr: SocketAddr, + scheme: &str, + req: hyper::http::request::Parts, + body: hyper::body::Incoming, +) -> oneshot::Receiver { + let brt = rt.innerb.clone(); + let (tx, rx) = oneshot::channel(); + let protocol = HTTPProtocol::new(rt, body, tx); + let scheme: Arc = scheme.into(); + + let _ = brt.run(move || { + scope_native_parts!( + req, + server_addr, + client_addr, + path, + query_string, + version, + server, + client + ); + Python::with_gil(|py| { + let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); + let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); + cb.get().schedule(py, watcher.as_any()); + }); + }); + + rx +} + +#[inline] +pub(crate) fn call_ws( + cb: ArcCBScheduler, + rt: RuntimeRef, + server_addr: SocketAddr, + client_addr: SocketAddr, + scheme: &str, + ws: HyperWebsocket, + req: hyper::http::request::Parts, + upgrade: UpgradeData, +) -> oneshot::Receiver { + let brt = rt.innerb.clone(); + let (tx, rx) = oneshot::channel(); + let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + let scheme: Arc = scheme.into(); + + let _ = brt.run(move || { + scope_native_parts!( + req, + server_addr, + client_addr, + path, + query_string, + version, + server, + client + ); + Python::with_gil(|py| { + let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); + let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); + cb.get().schedule(py, watcher.as_any()); + }); + }); + + rx } - -call_impl_http!(call_http, CallbackRunnerHTTP); -call_impl_http!(call_http_pyw, CallbackWrappedRunnerHTTP); -call_impl_ws!(call_ws, CallbackRunnerWebsocket); -call_impl_ws!(call_ws_pyw, CallbackWrappedRunnerWebsocket); diff --git a/src/asgi/http.rs b/src/asgi/http.rs index 2aed8cbe..40c168c4 100644 --- a/src/asgi/http.rs +++ b/src/asgi/http.rs @@ -3,9 +3,9 @@ use hyper::{header::SERVER as HK_SERVER, http::response::Builder as ResponseBuil use std::net::SocketAddr; use tokio::sync::mpsc; -use super::callbacks::{call_http, call_http_pyw, call_ws, call_ws_pyw}; +use super::callbacks::{call_http, call_ws}; use crate::{ - callbacks::CallbackWrapper, + callbacks::ArcCBScheduler, http::{empty_body, response_500, HTTPRequest, HTTPResponse, HV_SERVER}, runtime::RuntimeRef, ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, @@ -32,7 +32,7 @@ macro_rules! handle_request { #[inline] pub(crate) async fn $func_name( rt: RuntimeRef, - callback: CallbackWrapper, + callback: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, req: HTTPRequest, @@ -58,7 +58,7 @@ macro_rules! handle_request_with_ws { #[inline] pub(crate) async fn $func_name( rt: RuntimeRef, - callback: CallbackWrapper, + callback: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, mut req: HTTPRequest, @@ -154,10 +154,4 @@ macro_rules! handle_request_with_ws { } handle_request!(handle, call_http); -// handle_request!(handle_rtb, call_rtb_http); -handle_request!(handle_pyw, call_http_pyw); -// handle_request!(handle_rtb_pyw, call_rtb_http_pyw); handle_request_with_ws!(handle_ws, call_http, call_ws); -// handle_request_with_ws!(handle_rtb_ws, call_rtb_http, call_rtb_ws); -handle_request_with_ws!(handle_ws_pyw, call_http_pyw, call_ws_pyw); -// handle_request_with_ws!(handle_rtb_ws_pyw, call_rtb_http_pyw, call_rtb_ws_pyw); diff --git a/src/asgi/serve.rs b/src/asgi/serve.rs index 4c8b9c14..dd7d1380 100644 --- a/src/asgi/serve.rs +++ b/src/asgi/serve.rs @@ -1,7 +1,8 @@ use pyo3::prelude::*; -use super::http::{handle, handle_pyw, handle_ws, handle_ws_pyw}; +use super::http::{handle, handle_ws}; +use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal, WorkerSignals}; @@ -12,21 +13,13 @@ pub struct ASGIWorker { impl ASGIWorker { serve_rth!(_serve_rth, handle); - serve_rth!(_serve_rth_pyw, handle_pyw); serve_rth!(_serve_rth_ws, handle_ws); - serve_rth!(_serve_rth_ws_pyw, handle_ws_pyw); serve_wth!(_serve_wth, handle); - serve_wth!(_serve_wth_pyw, handle_pyw); serve_wth!(_serve_wth_ws, handle_ws); - serve_wth!(_serve_wth_ws_pyw, handle_ws_pyw); serve_rth_ssl!(_serve_rth_ssl, handle); - serve_rth_ssl!(_serve_rth_ssl_pyw, handle_pyw); serve_rth_ssl!(_serve_rth_ssl_ws, handle_ws); - serve_rth_ssl!(_serve_rth_ssl_ws_pyw, handle_ws_pyw); serve_wth_ssl!(_serve_wth_ssl, handle); - serve_wth_ssl!(_serve_wth_ssl_pyw, handle_pyw); serve_wth_ssl!(_serve_wth_ssl_ws, handle_ws); - serve_wth_ssl!(_serve_wth_ssl_ws_pyw, handle_ws_pyw); } #[pymethods] @@ -43,7 +36,6 @@ impl ASGIWorker { http1_opts=None, http2_opts=None, websockets_enabled=false, - opt_enabled=true, ssl_enabled=false, ssl_cert=None, ssl_key=None, @@ -61,7 +53,6 @@ impl ASGIWorker { http1_opts: Option, http2_opts: Option, websockets_enabled: bool, - opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, ssl_key: Option<&str>, @@ -78,7 +69,6 @@ impl ASGIWorker { worker_http1_config_from_py(py, http1_opts)?, worker_http2_config_from_py(py, http2_opts)?, websockets_enabled, - opt_enabled, ssl_enabled, ssl_cert, ssl_key, @@ -87,57 +77,21 @@ impl ASGIWorker { }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { - match ( - self.config.websockets_enabled, - self.config.ssl_enabled, - self.config.opt_enabled, - ) { - (false, false, true) => self._serve_rth(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, true) => self._serve_rth_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, false) => self._serve_rth_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, false) => { - self._serve_rth_ssl_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } - (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, true, false) => { - self._serve_rth_ssl_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } + fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { + match (self.config.websockets_enabled, self.config.ssl_enabled) { + (false, false) => self._serve_rth(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, false) => self._serve_rth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, true) => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, true) => self._serve_rth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { - match ( - self.config.websockets_enabled, - self.config.ssl_enabled, - self.config.opt_enabled, - ) { - (false, false, true) => self._serve_wth(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, true) => self._serve_wth_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, false) => self._serve_wth_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, false) => { - self._serve_wth_ssl_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } - (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, true, false) => { - self._serve_wth_ssl_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } + fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { + match (self.config.websockets_enabled, self.config.ssl_enabled) { + (false, false) => self._serve_wth(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, false) => self._serve_wth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, true) => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, true) => self._serve_wth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), } } } diff --git a/src/asyncio.rs b/src/asyncio.rs index 7767ed9d..f6b1b38a 100644 --- a/src/asyncio.rs +++ b/src/asyncio.rs @@ -1,46 +1,16 @@ use pyo3::{prelude::*, sync::GILOnceCell}; -use std::{convert::Into, sync::Arc}; +use std::convert::Into; static CONTEXTVARS: GILOnceCell = GILOnceCell::new(); static CONTEXT: GILOnceCell = GILOnceCell::new(); -#[derive(Clone, Debug)] -pub struct PyContext { - event_loop: Arc, - context: Arc, -} - -impl PyContext { - pub fn new(event_loop: Bound) -> Self { - let pynone = event_loop.py().None(); - Self { - event_loop: Arc::new(event_loop.unbind()), - context: Arc::new(pynone), - } - } - - pub fn with_context(self, context: Bound) -> Self { - Self { - context: Arc::new(context.unbind()), - ..self - } - } - - pub fn event_loop<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> { - self.event_loop.clone_ref(py).into_bound(py) - } - - pub fn context<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> { - self.context.clone_ref(py).into_bound(py) - } -} - fn contextvars(py: Python) -> PyResult<&Bound> { Ok(CONTEXTVARS .get_or_try_init(py, || py.import_bound("contextvars").map(Into::into))? .bind(py)) } +#[allow(dead_code)] pub(crate) fn empty_context(py: Python) -> PyResult<&Bound> { Ok(CONTEXT .get_or_try_init(py, || { diff --git a/src/callbacks.rs b/src/callbacks.rs index 07a6bfb4..2944fbff 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -3,21 +3,48 @@ use pyo3::{exceptions::PyStopIteration, prelude::*}; use std::sync::{atomic, Arc, RwLock}; use tokio::sync::Notify; -use super::asyncio::PyContext; +pub(crate) type ArcCBScheduler = Arc>; + +#[pyclass(frozen, subclass)] +pub(crate) struct CallbackScheduler { + #[pyo3(get)] + _loop: PyObject, + #[pyo3(get)] + _ctx: PyObject, + schedule_fn: Arc>, + pub cb: PyObject, +} -#[derive(Clone)] -pub(crate) struct CallbackWrapper { - pub callback: Arc, - pub context: PyContext, +impl CallbackScheduler { + #[inline] + pub(crate) fn schedule(&self, _py: Python, watcher: &PyObject) { + // // let cb = self.cb.as_ptr(); + let cbarg = watcher.as_ptr(); + let sched = self.schedule_fn.read().unwrap().as_ptr(); + unsafe { + // let coro = pyo3::ffi::PyObject_CallOneArg(cb, cbarg); + pyo3::ffi::PyObject_CallOneArg(sched, cbarg); + } + } } -impl CallbackWrapper { - pub(crate) fn new(callback: PyObject, event_loop: Bound, context: Bound) -> Self { +#[pymethods] +impl CallbackScheduler { + #[new] + fn new(py: Python, event_loop: PyObject, ctx: PyObject, cb: PyObject) -> Self { Self { - callback: Arc::new(callback), - context: PyContext::new(event_loop).with_context(context), + _loop: event_loop, + _ctx: ctx, + schedule_fn: Arc::new(RwLock::new(py.None())), + cb, } } + + #[setter(_schedule_fn)] + fn _set_schedule_fn(&self, val: PyObject) { + let mut guard = self.schedule_fn.write().unwrap(); + *guard = val; + } } #[pyclass(frozen)] @@ -284,135 +311,3 @@ impl PyFutureResultSetter { let _ = target.call1((value,)); } } - -macro_rules! callback_impl_run { - () => { - pub fn run(self, py: Python<'_>) -> PyResult> { - let event_loop = self.context.event_loop(py); - let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; - let kwctx = pyo3::types::PyDict::new_bound(py); - kwctx.set_item(pyo3::intern!(py, "context"), crate::asyncio::empty_context(py)?)?; - event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(&kwctx)) - } - }; -} - -macro_rules! callback_impl_run_pytask { - () => { - pub fn run(self, py: Python<'_>) -> PyResult> { - let taskref = self.pytaskref.clone(); - let event_loop = self.context.event_loop(py); - let context = self.context.context(py); - let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; - let kwctx = pyo3::types::PyDict::new_bound(py); - { - let mut taskref_guard = taskref.lock().unwrap(); - *taskref_guard = Some(target.clone_ref(py)); - } - kwctx.set_item(pyo3::intern!(py, "context"), context)?; - event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(&kwctx)) - } - }; -} - -macro_rules! callback_impl_loop_run { - () => { - pub fn run(self, py: Python<'_>) -> PyResult> { - let context = self.pycontext.clone_ref(py).into_bound(py); - context.call_method1( - pyo3::intern!(py, "run"), - (self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,), - ) - } - }; -} - -macro_rules! callback_impl_loop_pytask { - ($pyself:expr, $py:expr) => { - $pyself.context.event_loop($py).call_method1( - pyo3::intern!($py, "create_task"), - ($pyself - .cb - .clone_ref($py) - .into_bound($py) - .call1(($pyself.into_py($py),))?,), - ) - }; -} - -macro_rules! callback_impl_loop_step { - ($pyself:expr, $py:expr) => { - match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) { - Ok(res) => { - let blocking: bool = match res.getattr($py, pyo3::intern!($py, "_asyncio_future_blocking")) { - Ok(v) => v.extract($py)?, - _ => false, - }; - - let ctx = $pyself.pycontext.clone_ref($py); - let kwctx = pyo3::types::PyDict::new_bound($py); - kwctx.set_item(pyo3::intern!($py, "context"), ctx)?; - - match blocking { - true => { - res.setattr($py, pyo3::intern!($py, "_asyncio_future_blocking"), false)?; - res.call_method_bound( - $py, - pyo3::intern!($py, "add_done_callback"), - ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_wake"))?,), - Some(&kwctx), - )?; - } - false => { - let event_loop = $pyself.context.event_loop($py); - event_loop.call_method( - pyo3::intern!($py, "call_soon"), - ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_step"))?,), - Some(&kwctx), - )?; - } - } - Ok(()) - } - Err(err) - if (err.is_instance_of::($py) - || err.is_instance_of::($py) - || err.is_instance_of::($py)) => - { - $pyself.done(); - Ok(()) - } - Err(err) => { - $pyself.err(&err); - Ok(()) - } - } - }; -} - -macro_rules! callback_impl_loop_wake { - ($pyself:expr, $py:expr, $fut:expr) => { - match $fut.call_method0($py, pyo3::intern!($py, "result")) { - Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")), - Err(err) => $pyself._loop_err($py, err), - } - }; -} - -macro_rules! callback_impl_loop_err { - () => { - pub fn _loop_err(&self, py: Python, err: PyErr) -> PyResult { - self.err(&err); - let cberr = self.cb.call_method1(py, pyo3::intern!(py, "throw"), (err,)); - cberr - } - }; -} - -pub(crate) use callback_impl_loop_err; -pub(crate) use callback_impl_loop_pytask; -pub(crate) use callback_impl_loop_run; -pub(crate) use callback_impl_loop_step; -pub(crate) use callback_impl_loop_wake; -pub(crate) use callback_impl_run; -pub(crate) use callback_impl_run_pytask; diff --git a/src/lib.rs b/src/lib.rs index d639101b..b26a7349 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,7 @@ pub fn get_granian_version() -> &'static str { #[pymodule] fn _granian(py: Python, module: &Bound) -> PyResult<()> { module.add("__version__", get_granian_version())?; + module.add_class::()?; asgi::init_pymodule(module)?; rsgi::init_pymodule(py, module)?; tcp::init_pymodule(module)?; diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index bdc61f2e..b67081b7 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -1,5 +1,4 @@ use pyo3::prelude::*; -use std::sync::{Arc, Mutex}; use tokio::sync::oneshot; use super::{ @@ -7,49 +6,12 @@ use super::{ types::{PyResponse, PyResponseBody, RSGIHTTPScope as HTTPScope, RSGIWebsocketScope as WebsocketScope}, }; use crate::{ - asyncio::PyContext, - callbacks::{ - callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step, - callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper, - }, + callbacks::ArcCBScheduler, runtime::RuntimeRef, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; -#[pyclass(frozen)] -pub(crate) struct CallbackRunnerHTTP { - proto: Py, - context: PyContext, - cb: PyObject, -} - -impl CallbackRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { - let pyproto = Py::new(py, proto).unwrap(); - Self { - proto: pyproto.clone_ref(py), - context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), - } - } - - callback_impl_run!(); -} - -#[pymethods] -impl CallbackRunnerHTTP { - fn _loop_task<'p>(&self, py: Python<'p>) -> PyResult> { - CallbackTaskHTTP::new( - py, - self.cb.clone_ref(py), - self.proto.clone_ref(py), - self.context.clone(), - )? - .run(py) - } -} - macro_rules! callback_impl_done_http { ($self:expr) => { if let Some(tx) = $self.proto.get().tx() { @@ -58,6 +20,12 @@ macro_rules! callback_impl_done_http { }; } +macro_rules! callback_impl_done_ws { + ($self:expr) => { + let _ = $self.proto.get().close(None); + }; +} + macro_rules! callback_impl_done_err { ($self:expr, $err:expr) => { $self.done(); @@ -66,262 +34,100 @@ macro_rules! callback_impl_done_err { } #[pyclass(frozen)] -pub(crate) struct CallbackTaskHTTP { - proto: Py, - context: PyContext, - pycontext: PyObject, - cb: PyObject, -} - -impl CallbackTaskHTTP { - pub fn new(py: Python, cb: PyObject, proto: Py, context: PyContext) -> PyResult { - let pyctx = context.context(py); - Ok(Self { - proto, - context, - pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb, - }) - } - - fn done(&self) { - callback_impl_done_http!(self); - } - - fn err(&self, err: &PyErr) { - callback_impl_done_err!(self, err); - } - - callback_impl_loop_run!(); - callback_impl_loop_err!(); -} - -#[pymethods] -impl CallbackTaskHTTP { - fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> { - callback_impl_loop_step!(pyself, py) - } - - fn _loop_wake(pyself: PyRef<'_, Self>, py: Python, fut: PyObject) -> PyResult { - callback_impl_loop_wake!(pyself, py, fut) - } -} - -#[pyclass(frozen)] -pub(crate) struct CallbackWrappedRunnerHTTP { +pub(crate) struct CallbackWatcherHTTP { #[pyo3(get)] proto: Py, - context: PyContext, - cb: PyObject, #[pyo3(get)] scope: PyObject, - pytaskref: Arc>>, } -impl CallbackWrappedRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { +impl CallbackWatcherHTTP { + pub fn new(py: Python, proto: HTTPProtocol, scope: HTTPScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), - context: cb.context, - cb: cb.callback.clone_ref(py), scope: scope.into_py(py), - pytaskref: Arc::new(Mutex::new(None)), } } - - callback_impl_run_pytask!(); } #[pymethods] -impl CallbackWrappedRunnerHTTP { - fn _loop_task<'p>(pyself: PyRef<'_, Self>, py: Python<'p>) -> PyResult> { - callback_impl_loop_pytask!(pyself, py) - } - +impl CallbackWatcherHTTP { fn done(&self) { callback_impl_done_http!(self); - self.pytaskref.lock().unwrap().take(); } fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value_bound(err)); - self.pytaskref.lock().unwrap().take(); } } #[pyclass(frozen)] -pub(crate) struct CallbackRunnerWebsocket { - proto: Py, - context: PyContext, - cb: PyObject, -} - -impl CallbackRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { - let pyproto = Py::new(py, proto).unwrap(); - Self { - proto: pyproto.clone_ref(py), - context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), - } - } - - callback_impl_run!(); -} - -#[pymethods] -impl CallbackRunnerWebsocket { - fn _loop_task<'p>(&self, py: Python<'p>) -> PyResult> { - CallbackTaskWebsocket::new( - py, - self.cb.clone_ref(py), - self.proto.clone_ref(py), - self.context.clone(), - )? - .run(py) - } -} - -macro_rules! callback_impl_done_ws { - ($self:expr) => { - let _ = $self.proto.get().close(None); - }; -} - -#[pyclass(frozen)] -pub(crate) struct CallbackTaskWebsocket { - proto: Py, - context: PyContext, - pycontext: PyObject, - cb: PyObject, -} - -impl CallbackTaskWebsocket { - pub fn new(py: Python, cb: PyObject, proto: Py, context: PyContext) -> PyResult { - let pyctx = context.context(py); - Ok(Self { - proto, - context, - pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb, - }) - } - - fn done(&self) { - callback_impl_done_ws!(self); - } - - fn err(&self, err: &PyErr) { - callback_impl_done_err!(self, err); - } - - callback_impl_loop_run!(); - callback_impl_loop_err!(); -} - -#[pymethods] -impl CallbackTaskWebsocket { - fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> { - callback_impl_loop_step!(pyself, py) - } - - fn _loop_wake(pyself: PyRef<'_, Self>, py: Python, fut: PyObject) -> PyResult { - callback_impl_loop_wake!(pyself, py, fut) - } -} - -#[pyclass(frozen)] -pub(crate) struct CallbackWrappedRunnerWebsocket { +pub(crate) struct CallbackWatcherWebsocket { #[pyo3(get)] proto: Py, - context: PyContext, - cb: PyObject, #[pyo3(get)] scope: PyObject, - pytaskref: Arc>>, } -impl CallbackWrappedRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { +impl CallbackWatcherWebsocket { + pub fn new(py: Python, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), - context: cb.context, - cb: cb.callback.clone_ref(py), scope: scope.into_py(py), - pytaskref: Arc::new(Mutex::new(None)), } } - - callback_impl_run_pytask!(); } #[pymethods] -impl CallbackWrappedRunnerWebsocket { - fn _loop_task<'p>(pyself: PyRef<'_, Self>, py: Python<'p>) -> PyResult> { - callback_impl_loop_pytask!(pyself, py) - } - +impl CallbackWatcherWebsocket { fn done(&self) { callback_impl_done_ws!(self); - self.pytaskref.lock().unwrap().take(); } fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value_bound(err)); - self.pytaskref.lock().unwrap().take(); } } -macro_rules! call_impl_http { - ($func_name:ident, $runner:ident) => { - #[inline] - pub(crate) fn $func_name( - cb: CallbackWrapper, - rt: RuntimeRef, - body: hyper::body::Incoming, - scope: HTTPScope, - ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); - let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, tx, body); - - let _ = brt.run(|| { - Python::with_gil(|py| { - let _ = $runner::new(py, cb, protocol, scope).run(py); - }); - }); - - rx - } - }; -} - -macro_rules! call_impl_ws { - ($func_name:ident, $runner:ident) => { - #[inline] - pub(crate) fn $func_name( - cb: CallbackWrapper, - rt: RuntimeRef, - ws: HyperWebsocket, - upgrade: UpgradeData, - scope: WebsocketScope, - ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); - let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); - - let _ = brt.run(|| { - Python::with_gil(|py| { - let _ = $runner::new(py, cb, protocol, scope).run(py); - }); - }); - - rx - } - }; +#[inline] +pub(crate) fn call_http( + cb: ArcCBScheduler, + rt: RuntimeRef, + body: hyper::body::Incoming, + scope: HTTPScope, +) -> oneshot::Receiver { + let brt = rt.innerb.clone(); + let (tx, rx) = oneshot::channel(); + let protocol = HTTPProtocol::new(rt, tx, body); + + let _ = brt.run(move || { + Python::with_gil(|py| { + let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); + cb.get().schedule(py, watcher.as_any()); + }); + }); + + rx +} + +#[inline] +pub(crate) fn call_ws( + cb: ArcCBScheduler, + rt: RuntimeRef, + ws: HyperWebsocket, + upgrade: UpgradeData, + scope: WebsocketScope, +) -> oneshot::Receiver { + let brt = rt.innerb.clone(); + let (tx, rx) = oneshot::channel(); + let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + + let _ = brt.run(move || { + Python::with_gil(|py| { + let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); + cb.get().schedule(py, watcher.as_any()); + }); + }); + + rx } - -call_impl_http!(call_http, CallbackRunnerHTTP); -call_impl_http!(call_http_pyw, CallbackWrappedRunnerHTTP); -call_impl_ws!(call_ws, CallbackRunnerWebsocket); -call_impl_ws!(call_ws_pyw, CallbackWrappedRunnerWebsocket); diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index 50f52361..ee3278ba 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -4,11 +4,11 @@ use std::net::SocketAddr; use tokio::sync::mpsc; use super::{ - callbacks::{call_http, call_http_pyw, call_ws, call_ws_pyw}, + callbacks::{call_http, call_ws}, types::{PyResponse, RSGIHTTPScope as HTTPScope, RSGIWebsocketScope as WebsocketScope}, }; use crate::{ - callbacks::CallbackWrapper, + callbacks::ArcCBScheduler, http::{empty_body, response_500, HTTPRequest, HTTPResponse, HV_SERVER}, runtime::RuntimeRef, ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, @@ -46,7 +46,7 @@ macro_rules! handle_request { #[inline] pub(crate) async fn $func_name( rt: RuntimeRef, - callback: CallbackWrapper, + callback: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, req: HTTPRequest, @@ -59,12 +59,30 @@ macro_rules! handle_request { }; } +// macro_rules! handle_request2 { +// ($func_name:ident, $handler:expr) => { +// #[inline] +// pub(crate) async fn $func_name( +// rt: RuntimeRef, +// callback: ArcCBScheduler, +// server_addr: SocketAddr, +// client_addr: SocketAddr, +// req: HTTPRequest, +// scheme: &str, +// ) -> HTTPResponse { +// let (parts, body) = req.into_parts(); +// let scope = build_scope!(HTTPScope, server_addr, client_addr, parts, scheme); +// handle_http_response!($handler, rt, callback, body, scope) +// } +// }; +// } + macro_rules! handle_request_with_ws { ($func_name:ident, $handler_req:expr, $handler_ws:expr) => { #[inline] pub(crate) async fn $func_name( rt: RuntimeRef, - callback: CallbackWrapper, + callback: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, mut req: HTTPRequest, @@ -136,6 +154,7 @@ macro_rules! handle_request_with_ws { } handle_request!(handle, call_http); -handle_request!(handle_pyw, call_http_pyw); +// handle_request!(handle_pyw, call_http_pyw); +// handle_request2!(handle2, call_http2); handle_request_with_ws!(handle_ws, call_http, call_ws); -handle_request_with_ws!(handle_ws_pyw, call_http_pyw, call_ws_pyw); +// handle_request_with_ws!(handle_ws_pyw, call_http_pyw, call_ws_pyw); diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index e83c2246..cd813ea1 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -96,29 +96,37 @@ impl RSGIHTTPProtocol { error_proto!() } - fn __aiter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { - if let Some(body) = pyself.body.lock().unwrap().take() { - let mut stream = pyself.body_stream.blocking_lock(); + fn __aiter__(pyself: Py, py: Python) -> PyResult> { + let inner = pyself.get(); + if let Some(body) = inner.body.lock().unwrap().take() { + let mut stream = inner.body_stream.blocking_lock(); *stream = Some(http_body_util::BodyStream::new(body)); + return Ok(pyself.clone_ref(py)); } - pyself + error_proto!() } - fn __anext__<'p>(&self, py: Python<'p>) -> PyResult>> { + fn __anext__<'p>(&self, py: Python<'p>) -> PyResult> { + if self.body_stream.blocking_lock().is_none() { + return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")); + } let body_stream = self.body_stream.clone(); - let pyfut = future_into_py_iter(self.rt.clone(), py, async move { - if let Some(stream) = &mut *body_stream.lock().await { - if let Some(chunk) = stream.next().await { + future_into_py_iter(self.rt.clone(), py, async move { + let guard = &mut *body_stream.lock().await; + let bytes = match guard.as_mut().unwrap().next().await { + Some(chunk) => { let chunk = chunk .map(|buf| buf.into_data().unwrap_or_default()) .unwrap_or(body::Bytes::new()); - return Ok(BytesToPy(chunk)); - }; - return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")); - } - error_proto!() - })?; - Ok(Some(pyfut)) + BytesToPy(chunk) + } + _ => { + let _ = guard.take(); + BytesToPy(body::Bytes::new()) + } + }; + Ok(bytes) + }) } #[pyo3(signature = (status=200, headers=vec![]))] diff --git a/src/rsgi/serve.rs b/src/rsgi/serve.rs index 1f13dca4..bc904b91 100644 --- a/src/rsgi/serve.rs +++ b/src/rsgi/serve.rs @@ -1,7 +1,8 @@ use pyo3::prelude::*; -use super::http::{handle, handle_pyw, handle_ws, handle_ws_pyw}; +use super::http::{handle, handle_ws}; +use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal, WorkerSignals}; @@ -12,21 +13,13 @@ pub struct RSGIWorker { impl RSGIWorker { serve_rth!(_serve_rth, handle); - serve_rth!(_serve_rth_pyw, handle_pyw); serve_rth!(_serve_rth_ws, handle_ws); - serve_rth!(_serve_rth_ws_pyw, handle_ws_pyw); serve_wth!(_serve_wth, handle); - serve_wth!(_serve_wth_pyw, handle_pyw); serve_wth!(_serve_wth_ws, handle_ws); - serve_wth!(_serve_wth_ws_pyw, handle_ws_pyw); serve_rth_ssl!(_serve_rth_ssl, handle); - serve_rth_ssl!(_serve_rth_ssl_pyw, handle_pyw); serve_rth_ssl!(_serve_rth_ssl_ws, handle_ws); - serve_rth_ssl!(_serve_rth_ssl_ws_pyw, handle_ws_pyw); serve_wth_ssl!(_serve_wth_ssl, handle); - serve_wth_ssl!(_serve_wth_ssl_pyw, handle_pyw); serve_wth_ssl!(_serve_wth_ssl_ws, handle_ws); - serve_wth_ssl!(_serve_wth_ssl_ws_pyw, handle_ws_pyw); } #[pymethods] @@ -43,7 +36,6 @@ impl RSGIWorker { http1_opts=None, http2_opts=None, websockets_enabled=false, - opt_enabled=true, ssl_enabled=false, ssl_cert=None, ssl_key=None, @@ -61,7 +53,6 @@ impl RSGIWorker { http1_opts: Option, http2_opts: Option, websockets_enabled: bool, - opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, ssl_key: Option<&str>, @@ -78,7 +69,6 @@ impl RSGIWorker { worker_http1_config_from_py(py, http1_opts)?, worker_http2_config_from_py(py, http2_opts)?, websockets_enabled, - opt_enabled, ssl_enabled, ssl_cert, ssl_key, @@ -87,57 +77,21 @@ impl RSGIWorker { }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { - match ( - self.config.websockets_enabled, - self.config.ssl_enabled, - self.config.opt_enabled, - ) { - (false, false, true) => self._serve_rth(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, true) => self._serve_rth_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, false) => self._serve_rth_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, false) => { - self._serve_rth_ssl_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } - (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, true, false) => { - self._serve_rth_ssl_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } + fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { + match (self.config.websockets_enabled, self.config.ssl_enabled) { + (false, false) => self._serve_rth(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, false) => self._serve_rth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, true) => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, true) => self._serve_rth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { - match ( - self.config.websockets_enabled, - self.config.ssl_enabled, - self.config.opt_enabled, - ) { - (false, false, true) => self._serve_wth(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, true) => self._serve_wth_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, false, false) => self._serve_wth_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (false, true, false) => { - self._serve_wth_ssl_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } - (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, WorkerSignals::Tokio(signal)), - (true, true, false) => { - self._serve_wth_ssl_ws_pyw(callback, event_loop, context, WorkerSignals::Tokio(signal)); - } + fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { + match (self.config.websockets_enabled, self.config.ssl_enabled) { + (false, false) => self._serve_wth(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, false) => self._serve_wth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, true) => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), + (true, true) => self._serve_wth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), } } } diff --git a/src/workers.rs b/src/workers.rs index 5eb4f316..ebb63a5a 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -100,7 +100,6 @@ pub(crate) struct WorkerConfig { pub http1_opts: HTTP1Config, pub http2_opts: HTTP2Config, pub websockets_enabled: bool, - pub opt_enabled: bool, pub ssl_enabled: bool, ssl_cert: Option, ssl_key: Option, @@ -118,7 +117,6 @@ impl WorkerConfig { http1_opts: HTTP1Config, http2_opts: HTTP2Config, websockets_enabled: bool, - opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, ssl_key: Option<&str>, @@ -134,7 +132,6 @@ impl WorkerConfig { http1_opts, http2_opts, websockets_enabled, - opt_enabled, ssl_enabled, ssl_cert: ssl_cert.map(std::convert::Into::into), ssl_key: ssl_key.map(std::convert::Into::into), @@ -593,9 +590,8 @@ macro_rules! serve_rth { ($func_name:ident, $target:expr) => { fn $func_name( &self, - callback: PyObject, + callback: Py, event_loop: &Bound, - context: Bound, signal: crate::workers::WorkerSignals, ) { pyo3_log::init(); @@ -609,7 +605,7 @@ macro_rules! serve_rth { let http1_opts = self.config.http1_opts.clone(); let http2_opts = self.config.http2_opts.clone(); let backpressure = self.config.backpressure.clone(); - let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop.clone(), context); + let callback_wrapper = std::sync::Arc::new(callback); let rt = crate::runtime::init_runtime_mt( self.config.threads, @@ -673,9 +669,9 @@ macro_rules! serve_rth_ssl { ($func_name:ident, $target:expr) => { fn $func_name( &self, - callback: PyObject, + callback: Py, event_loop: &Bound, - context: Bound, + // context: Bound, signal: crate::workers::WorkerSignals, ) { pyo3_log::init(); @@ -690,7 +686,8 @@ macro_rules! serve_rth_ssl { let http2_opts = self.config.http2_opts.clone(); let backpressure = self.config.backpressure.clone(); let tls_cfg = self.config.tls_cfg(); - let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop.clone(), context); + // let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop.clone(), context); + let callback_wrapper = std::sync::Arc::new(callback); let rt = crate::runtime::init_runtime_mt( self.config.threads, @@ -752,8 +749,9 @@ macro_rules! serve_rth_ssl { } macro_rules! serve_wth_inner { - ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $context:expr, $wid:expr, $workers:expr, $srx:expr) => { - let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); + ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $wid:expr, $workers:expr, $srx:expr) => { + // let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); + let callback_wrapper = std::sync::Arc::new($callback); let py_loop = std::sync::Arc::new($event_loop.clone().unbind()); for thread_id in 0..$self.config.threads { @@ -807,9 +805,9 @@ macro_rules! serve_wth { ($func_name: ident, $target:expr) => { fn $func_name( &self, - callback: PyObject, + callback: Py, event_loop: &Bound, - context: Bound, + // context: Bound, signal: crate::workers::WorkerSignals, ) { pyo3_log::init(); @@ -819,7 +817,7 @@ macro_rules! serve_wth { let (stx, srx) = tokio::sync::watch::channel(false); let mut workers = vec![]; - crate::workers::serve_wth_inner!(self, $target, callback, event_loop, context, worker_id, workers, srx); + crate::workers::serve_wth_inner!(self, $target, callback, event_loop, worker_id, workers, srx); match signal { crate::workers::WorkerSignals::Tokio(sig) => { @@ -865,8 +863,9 @@ macro_rules! serve_wth { } macro_rules! serve_wth_ssl_inner { - ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $context:expr, $wid:expr, $workers:expr, $srx:expr) => { - let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); + ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $wid:expr, $workers:expr, $srx:expr) => { + // let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); + let callback_wrapper = std::sync::Arc::new($callback); let py_loop = std::sync::Arc::new($event_loop.clone().unbind()); for thread_id in 0..$self.config.threads { @@ -918,9 +917,9 @@ macro_rules! serve_wth_ssl { ($func_name: ident, $target:expr) => { fn $func_name( &self, - callback: PyObject, + callback: Py, event_loop: &Bound, - context: Bound, + // context: Bound, signal: crate::workers::WorkerSignals, ) { pyo3_log::init(); @@ -930,7 +929,7 @@ macro_rules! serve_wth_ssl { let (stx, srx) = tokio::sync::watch::channel(false); let mut workers = vec![]; - crate::workers::serve_wth_ssl_inner!(self, $target, callback, event_loop, context, worker_id, workers, srx); + crate::workers::serve_wth_ssl_inner!(self, $target, callback, event_loop, worker_id, workers, srx); match signal { crate::workers::WorkerSignals::Tokio(sig) => { diff --git a/src/wsgi/callbacks.rs b/src/wsgi/callbacks.rs index 4a674946..49ca894f 100644 --- a/src/wsgi/callbacks.rs +++ b/src/wsgi/callbacks.rs @@ -9,12 +9,12 @@ use pyo3::{ prelude::*, types::{IntoPyDict, PyBytes, PyDict}, }; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use tokio::sync::oneshot; use super::{io::WSGIProtocol, types::WSGIBody}; use crate::{ - callbacks::CallbackWrapper, + callbacks::ArcCBScheduler, http::{empty_body, HTTPResponseBody}, runtime::RuntimeRef, utils::log_application_callable_exception, @@ -24,7 +24,7 @@ use crate::{ fn run_callback( rt: RuntimeRef, tx: oneshot::Sender<(u16, HeaderMap, HTTPResponseBody)>, - callback: Arc, + cbs: ArcCBScheduler, mut parts: request::Parts, server_addr: SocketAddr, client_addr: SocketAddr, @@ -66,7 +66,7 @@ fn run_callback( let _ = Python::with_gil(|py| -> PyResult<()> { let proto = Py::new(py, WSGIProtocol::new(tx))?; - let callback = callback.clone_ref(py); + let callback = cbs.get().cb.clone_ref(py); let environ = PyDict::new_bound(py); environ.set_item(pyo3::intern!(py, "SERVER_PROTOCOL"), version)?; environ.set_item(pyo3::intern!(py, "SERVER_NAME"), server.0)?; @@ -108,7 +108,7 @@ fn run_callback( #[inline(always)] pub(crate) fn call_http( rt: RuntimeRef, - cb: CallbackWrapper, + cb: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, scheme: &str, @@ -118,7 +118,7 @@ pub(crate) fn call_http( let scheme: std::sync::Arc = scheme.into(); let (tx, rx) = oneshot::channel(); tokio::task::spawn_blocking(move || { - run_callback(rt, tx, cb.callback, req, server_addr, client_addr, &scheme, body); + run_callback(rt, tx, cb, req, server_addr, client_addr, &scheme, body); }); rx } diff --git a/src/wsgi/http.rs b/src/wsgi/http.rs index 030af3bb..287a0e07 100644 --- a/src/wsgi/http.rs +++ b/src/wsgi/http.rs @@ -3,7 +3,7 @@ use std::net::SocketAddr; use super::callbacks::call_http; use crate::{ - callbacks::CallbackWrapper, + callbacks::ArcCBScheduler, http::{response_500, HTTPRequest, HTTPResponse, HTTPResponseBody}, runtime::RuntimeRef, }; @@ -19,7 +19,7 @@ fn build_response(status: u16, pyheaders: hyper::HeaderMap, body: HTTPResponseBo #[inline] pub(crate) async fn handle( rt: RuntimeRef, - callback: CallbackWrapper, + callback: ArcCBScheduler, server_addr: SocketAddr, client_addr: SocketAddr, req: HTTPRequest, diff --git a/src/wsgi/serve.rs b/src/wsgi/serve.rs index 5a295637..f10922ee 100644 --- a/src/wsgi/serve.rs +++ b/src/wsgi/serve.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; use super::http::handle; +use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{ serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignalSync, WorkerSignals, @@ -64,7 +65,6 @@ impl WSGIWorker { worker_http1_config_from_py(py, http1_opts)?, worker_http2_config_from_py(py, http2_opts)?, false, - true, ssl_enabled, ssl_cert, ssl_key, @@ -73,29 +73,17 @@ impl WSGIWorker { }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { + fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { match self.config.ssl_enabled { - false => self._serve_rth(callback, event_loop, context, WorkerSignals::Crossbeam(signal)), - true => self._serve_rth_ssl(callback, event_loop, context, WorkerSignals::Crossbeam(signal)), + false => self._serve_rth(callback, event_loop, WorkerSignals::Crossbeam(signal)), + true => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Crossbeam(signal)), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &Bound, - context: Bound, - signal: Py, - ) { + fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { match self.config.ssl_enabled { - false => self._serve_wth(callback, event_loop, context, WorkerSignals::Crossbeam(signal)), - true => self._serve_wth_ssl(callback, event_loop, context, WorkerSignals::Crossbeam(signal)), + false => self._serve_wth(callback, event_loop, WorkerSignals::Crossbeam(signal)), + true => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Crossbeam(signal)), } } } diff --git a/tests/apps/asgi.py b/tests/apps/asgi.py index 02efb66f..6871977e 100644 --- a/tests/apps/asgi.py +++ b/tests/apps/asgi.py @@ -1,6 +1,8 @@ import json import pathlib +import sniffio + PLAINTEXT_RESPONSE = { 'type': 'http.response.start', @@ -39,6 +41,13 @@ async def info(scope, receive, send): ) +async def sniff_aio_impl(scope, receive, send): + await send(PLAINTEXT_RESPONSE) + await send( + {'type': 'http.response.body', 'body': sniffio.current_async_library().encode('utf8'), 'more_body': False} + ) + + async def echo(scope, receive, send): await send(PLAINTEXT_RESPONSE) more_body = True @@ -129,6 +138,7 @@ def app(scope, receive, send): return lifespan(scope, receive, send) return { '/info': info, + '/sniffio': sniff_aio_impl, '/echo': echo, '/file': pathsend, '/ws_reject': ws_reject, diff --git a/tests/conftest.py b/tests/conftest.py index b5ba9ba6..537599c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import asyncio import multiprocessing as mp -import os import socket from contextlib import asynccontextmanager, closing from functools import partial @@ -23,7 +22,6 @@ async def _server(interface, port, threading_mode, tls=False): 'interface': interface, 'port': port, 'threading_mode': threading_mode, - 'loop_opt': bool(os.getenv('LOOP_OPT')), } if tls: if tls == 'private': diff --git a/tests/test_asgi.py b/tests/test_asgi.py index a17a0996..d7fef745 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -77,3 +77,14 @@ async def test_file(asgi_server, threading_mode): assert res.status_code == 200 assert res.headers['content-type'] == 'image/png' assert res.headers['content-length'] == '95' + + +@pytest.mark.asyncio +@pytest.mark.skipif(bool(os.getenv('PGO_RUN')), reason='PGO build') +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) +async def test_sniffio(asgi_server, threading_mode): + async with asgi_server(threading_mode) as port: + res = httpx.get(f'http://localhost:{port}/sniffio') + + assert res.status_code == 200 + assert res.text == 'asyncio'