From cf53191e370d6671e6e3c87c4917aec49223c5c9 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Wed, 4 Dec 2024 18:33:57 +0100 Subject: [PATCH] Enhance futures interop related code --- src/asgi/io.rs | 6 +- src/callbacks.rs | 175 ++++++++++++++++++++++++++++------------------- src/rsgi/io.rs | 28 ++++++-- src/runtime.rs | 28 +------- 4 files changed, 130 insertions(+), 107 deletions(-) diff --git a/src/asgi/io.rs b/src/asgi/io.rs index a4a112f..e7ec806 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -128,7 +128,7 @@ impl ASGIHTTPProtocol { let body_ref = self.request_body.clone(); let flow_ref = self.flow_rx_exhausted.clone(); let flow_hld = self.flow_tx_waiter.clone(); - future_into_py_iter(self.rt.clone(), py, async move { + future_into_py_futlike(self.rt.clone(), py, async move { let mut bodym = body_ref.lock().await; let body = &mut *bodym; let mut more_body = false; @@ -327,7 +327,7 @@ impl ASGIWebsocketProtocol { let tx = self.ws_tx.clone(); let pynone = py.None(); - future_into_py_iter(self.rt.clone(), py, async move { + future_into_py_futlike(self.rt.clone(), py, async move { if let Some(mut upgrade) = upgrade { let upgrade_headers = match subproto { Some(v) => vec![(WS_SUBPROTO_HNAME.to_string(), v)], @@ -347,6 +347,7 @@ impl ASGIWebsocketProtocol { } } } + Python::with_gil(|_| drop(pynone)); error_flow!() }) } @@ -369,6 +370,7 @@ impl ASGIWebsocketProtocol { } }; }; + Python::with_gil(|_| drop(pynone)); error_flow!() }) } diff --git a/src/callbacks.rs b/src/callbacks.rs index 1466475..83e2554 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -35,6 +35,7 @@ impl CallbackScheduler { } } + #[inline] pub(crate) fn send(pyself: Py, py: Python, coro: PyObject) { let rself = pyself.get(); let ptr = pyself.as_ptr(); @@ -106,6 +107,7 @@ impl CallbackScheduler { let corom = pyo3::ffi::PyObject_GetAttr(coro.as_ptr(), rself.pyname_aiothrow.as_ptr()); pyo3::ffi::PyObject_CallOneArg(rself.aio_tenter.as_ptr(), ptr); pyo3::ffi::PyObject_CallOneArg(corom, err.as_ptr()); + pyo3::ffi::PyErr_Clear(); pyo3::ffi::PyObject_CallOneArg(rself.aio_texit.as_ptr(), ptr); } } @@ -203,20 +205,19 @@ impl PyEmptyAwaitable { #[pyclass(frozen, module = "granian._granian")] pub(crate) struct PyIterAwaitable { - result: RwLock>>, + result: OnceLock>, } -#[cfg(not(target_os = "linux"))] impl PyIterAwaitable { pub(crate) fn new() -> Self { Self { - result: RwLock::new(None), + result: OnceLock::new(), } } + #[inline] pub(crate) fn set_result(&self, result: PyResult) { - let mut res = self.result.write().unwrap(); - *res = Some(result); + let _ = self.result.set(result); } } @@ -231,27 +232,28 @@ impl PyIterAwaitable { } fn __next__(&self, py: Python) -> PyResult> { - if let Ok(res) = self.result.try_read() { - if let Some(ref res) = *res { - return res - .as_ref() - .map_err(|err| err.clone_ref(py)) - .map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?; - } - }; + if let Some(res) = py.allow_threads(|| self.result.get()) { + return res + .as_ref() + .map_err(|err| err.clone_ref(py)) + .map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?; + } + Ok(Some(py.None())) } } +#[repr(u8)] enum PyFutureAwaitableState { - Pending, - Completed(PyResult), - Cancelled, + Pending = 0, + Completed = 1, + Cancelled = 2, } #[pyclass(frozen, module = "granian._granian")] pub(crate) struct PyFutureAwaitable { - state: RwLock, + state: atomic::AtomicU8, + result: OnceLock>, event_loop: PyObject, cancel_tx: Arc, py_block: atomic::AtomicBool, @@ -261,7 +263,8 @@ pub(crate) struct PyFutureAwaitable { impl PyFutureAwaitable { pub(crate) fn new(event_loop: PyObject) -> Self { Self { - state: RwLock::new(PyFutureAwaitableState::Pending), + state: atomic::AtomicU8::new(PyFutureAwaitableState::Pending as u8), + result: OnceLock::new(), event_loop, cancel_tx: Arc::new(Notify::new()), py_block: true.into(), @@ -274,23 +277,36 @@ impl PyFutureAwaitable { Ok((Py::new(py, self)?, cancel_tx)) } - pub(crate) fn set_result(&self, result: PyResult, aw: Py) { - Python::with_gil(|py| { - let mut state = self.state.write().unwrap(); - if !matches!(&mut *state, PyFutureAwaitableState::Pending) { - return; - } - *state = PyFutureAwaitableState::Completed(result); + pub(crate) fn set_result(pyself: Py, result: PyResult) { + let rself = pyself.get(); + if rself + .state + .compare_exchange( + PyFutureAwaitableState::Pending as u8, + PyFutureAwaitableState::Completed as u8, + atomic::Ordering::Release, + atomic::Ordering::Relaxed, + ) + .is_err() + { + Python::with_gil(|_| drop(result)); + return; + } - let ack = self.ack.read().unwrap(); + let _ = rself.result.set(result); + + Python::with_gil(|py| { + let ack = pyself.get().ack.read().unwrap(); if let Some((cb, ctx)) = &*ack { - let _ = self.event_loop.clone_ref(py).call_method( + let _ = pyself.get().event_loop.clone_ref(py).call_method( py, pyo3::intern!(py, "call_soon_threadsafe"), - (cb, aw), + (cb, pyself.clone_ref(py)), Some(ctx.bind(py)), ); } + drop(ack); + drop(pyself); }); } } @@ -305,15 +321,17 @@ impl PyFutureAwaitable { } fn __next__(pyself: PyRef<'_, Self>) -> PyResult>> { - let state = pyself.state.read().unwrap(); - if let PyFutureAwaitableState::Completed(res) = &*state { + if pyself.state.load(atomic::Ordering::Acquire) == PyFutureAwaitableState::Completed as u8 { let py = pyself.py(); - return res + return pyself + .result + .get() + .unwrap() .as_ref() .map_err(|err| err.clone_ref(py)) .map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?; - }; - drop(state); + } + Ok(Some(pyself)) } @@ -337,20 +355,16 @@ impl PyFutureAwaitable { let kwctx = pyo3::types::PyDict::new(py); kwctx.set_item(pyo3::intern!(py, "context"), context)?; - let state = pyself.state.read().unwrap(); - match &*state { - PyFutureAwaitableState::Pending => { - let mut ack = pyself.ack.write().unwrap(); - *ack = Some((cb, kwctx.unbind())); - Ok(()) - } - _ => { - drop(state); - let event_loop = pyself.event_loop.clone_ref(py); - event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(&kwctx))?; - Ok(()) - } + let state = pyself.state.load(atomic::Ordering::Acquire); + if state == PyFutureAwaitableState::Pending as u8 { + let mut ack = pyself.ack.write().unwrap(); + *ack = Some((cb, kwctx.unbind())); + } else { + let event_loop = pyself.event_loop.clone_ref(py); + event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(&kwctx))?; } + + Ok(()) } #[allow(unused)] @@ -363,13 +377,20 @@ impl PyFutureAwaitable { #[allow(unused)] #[pyo3(signature = (msg=None))] fn cancel(pyself: PyRef<'_, Self>, msg: Option) -> bool { - let mut state = pyself.state.write().unwrap(); - if !matches!(&mut *state, PyFutureAwaitableState::Pending) { + if pyself + .state + .compare_exchange( + PyFutureAwaitableState::Pending as u8, + PyFutureAwaitableState::Cancelled as u8, + atomic::Ordering::Release, + atomic::Ordering::Relaxed, + ) + .is_err() + { return false; } pyself.cancel_tx.notify_one(); - *state = PyFutureAwaitableState::Cancelled; let ack = pyself.ack.read().unwrap(); if let Some((cb, ctx)) = &*ack { @@ -378,7 +399,6 @@ impl PyFutureAwaitable { let cb = cb.clone_ref(py); let ctx = ctx.clone_ref(py); drop(ack); - drop(state); let _ = event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(ctx.bind(py))); } @@ -387,36 +407,47 @@ impl PyFutureAwaitable { } fn done(&self) -> bool { - let state = self.state.read().unwrap(); - !matches!(&*state, PyFutureAwaitableState::Pending) + self.state.load(atomic::Ordering::Acquire) != PyFutureAwaitableState::Pending as u8 } fn result(&self, py: Python) -> PyResult { - let state = self.state.read().unwrap(); - match &*state { - PyFutureAwaitableState::Completed(res) => { - res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py)) - } - PyFutureAwaitableState::Cancelled => { - Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")) - } - PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( - "Result is not ready.", - )), + let state = self.state.load(atomic::Ordering::Acquire); + + if state == PyFutureAwaitableState::Completed as u8 { + return self + .result + .get() + .unwrap() + .as_ref() + .map(|v| v.clone_ref(py)) + .map_err(|err| err.clone_ref(py)); + } + if state == PyFutureAwaitableState::Cancelled as u8 { + return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")); } + Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( + "Result is not ready.", + )) } fn exception(&self, py: Python) -> PyResult { - let state = self.state.read().unwrap(); - match &*state { - PyFutureAwaitableState::Completed(res) => res.as_ref().map(|_| py.None()).map_err(|err| err.clone_ref(py)), - PyFutureAwaitableState::Cancelled => { - Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")) - } - PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( - "Exception is not set.", - )), + let state = self.state.load(atomic::Ordering::Acquire); + + if state == PyFutureAwaitableState::Completed as u8 { + return self + .result + .get() + .unwrap() + .as_ref() + .map(|_| py.None()) + .map_err(|err| err.clone_ref(py)); + } + if state == PyFutureAwaitableState::Cancelled as u8 { + return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")); } + Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( + "Exception is not set.", + )) } } diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index cbf0f3a..ffae5c0 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -48,7 +48,10 @@ impl RSGIHTTPStreamTransport { future_into_py_futlike(self.rt.clone(), py, async move { match transport.send(Ok(body::Bytes::from(bdata))).await { Ok(()) => Ok(pynone), - _ => error_stream!(), + _ => { + Python::with_gil(|_| drop(pynone)); + error_stream!() + } } }) } @@ -60,7 +63,10 @@ impl RSGIHTTPStreamTransport { future_into_py_futlike(self.rt.clone(), py, async move { match transport.send(Ok(body::Bytes::from(data))).await { Ok(()) => Ok(pynone), - _ => error_stream!(), + _ => { + Python::with_gil(|_| drop(pynone)); + error_stream!() + } } }) } @@ -93,7 +99,7 @@ impl RSGIHTTPProtocol { impl RSGIHTTPProtocol { fn __call__<'p>(&self, py: Python<'p>) -> PyResult> { if let Some(body) = self.body.lock().unwrap().take() { - return future_into_py_iter(self.rt.clone(), py, async move { + return future_into_py_futlike(self.rt.clone(), py, async move { match body.collect().await { Ok(data) => { let bytes = BytesToPy(data.to_bytes()); @@ -121,7 +127,7 @@ impl RSGIHTTPProtocol { return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")); } let body_stream = self.body_stream.clone(); - future_into_py_iter(self.rt.clone(), py, async move { + future_into_py_futlike(self.rt.clone(), py, async move { let guard = &mut *body_stream.lock().await; let bytes = match guard.as_mut().unwrap().next().await { Some(chunk) => { @@ -256,9 +262,13 @@ impl RSGIWebsocketTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send(bdata[..].into()).await { Ok(()) => Ok(pynone), - _ => error_stream!(), + _ => { + Python::with_gil(|_| drop(pynone)); + error_stream!() + } }; } + Python::with_gil(|_| drop(pynone)); error_proto!() }) } @@ -271,9 +281,13 @@ impl RSGIWebsocketTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send(Message::Text(data)).await { Ok(()) => Ok(pynone), - _ => error_stream!(), + _ => { + Python::with_gil(|_| drop(pynone)); + error_stream!() + } }; } + Python::with_gil(|_| drop(pynone)); error_proto!() }) } @@ -384,7 +398,7 @@ impl RSGIWebsocketProtocol { let mut upgrade = self.upgrade.write().unwrap().take().unwrap(); let transport = self.websocket.clone(); let itransport = self.transport.clone(); - future_into_py_iter(self.rt.clone(), py, async move { + future_into_py_futlike(self.rt.clone(), py, async move { let mut ws = transport.lock().await; match upgrade.send(None).await { Ok(()) => match (&mut *ws).await { diff --git a/src/runtime.rs b/src/runtime.rs index aff85c1..81aa583 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -142,7 +142,6 @@ pub(crate) fn init_runtime_st(blocking_threads: usize, py_loop: Arc) - // It consumes more cpu-cycles than `future_into_py_futlike`, // but for "quick" operations it's something like 12% faster. #[allow(unused_must_use)] -#[cfg(not(target_os = "linux"))] pub(crate) fn future_into_py_iter(rt: R, py: Python, fut: F) -> PyResult> where R: Runtime + ContextExt + Clone, @@ -163,20 +162,6 @@ where Ok(py_fut.into_any().into_bound(py)) } -// NOTE: -// for some unknown reasons, it seems on Linux the real implementation -// has performance issues. We just fallback to `futlike` impl on such targets. -// MacOS works best with original impl, Windows still needs further analysis. -#[cfg(target_os = "linux")] -#[inline(always)] -pub(crate) fn future_into_py_iter(rt: R, py: Python, fut: F) -> PyResult> -where - R: Runtime + ContextExt + Clone, - F: Future> + Send + 'static, -{ - future_into_py_futlike(rt, py, fut) -} - // NOTE: // `future_into_py_futlike` relies on an `asyncio.Future` like implementation. // This is generally ~38% faster than `pyo3_asyncio.future_into_py` implementation. @@ -191,25 +176,16 @@ where { let event_loop = rt.py_event_loop(py); let (aw, cancel_tx) = PyFutureAwaitable::new(event_loop).to_spawn(py)?; - let aw_ref = aw.clone_ref(py); let py_fut = aw.clone_ref(py); let rb = rt.blocking(); rt.spawn(async move { tokio::select! { result = fut => { - let _ = rb.run(move || { - aw.get().set_result(result, aw_ref); - Python::with_gil(|_| drop(aw)); - }); + let _ = rb.run(move || PyFutureAwaitable::set_result(aw, result)); }, () = cancel_tx.notified() => { - let _ = rb.run(move || { - Python::with_gil(|_| { - drop(aw_ref); - drop(aw); - }); - }); + let _ = rb.run(move || Python::with_gil(|_| drop(aw))); } } });