Skip to content

Commit

Permalink
Enhance futures interop related code
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Dec 4, 2024
1 parent edbbe39 commit cf53191
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 107 deletions.
6 changes: 4 additions & 2 deletions src/asgi/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl ASGIHTTPProtocol {
let body_ref = self.request_body.clone();
let flow_ref = self.flow_rx_exhausted.clone();
let flow_hld = self.flow_tx_waiter.clone();
future_into_py_iter(self.rt.clone(), py, async move {
future_into_py_futlike(self.rt.clone(), py, async move {
let mut bodym = body_ref.lock().await;
let body = &mut *bodym;
let mut more_body = false;
Expand Down Expand Up @@ -327,7 +327,7 @@ impl ASGIWebsocketProtocol {
let tx = self.ws_tx.clone();
let pynone = py.None();

future_into_py_iter(self.rt.clone(), py, async move {
future_into_py_futlike(self.rt.clone(), py, async move {
if let Some(mut upgrade) = upgrade {
let upgrade_headers = match subproto {
Some(v) => vec![(WS_SUBPROTO_HNAME.to_string(), v)],
Expand All @@ -347,6 +347,7 @@ impl ASGIWebsocketProtocol {
}
}
}
Python::with_gil(|_| drop(pynone));
error_flow!()
})
}
Expand All @@ -369,6 +370,7 @@ impl ASGIWebsocketProtocol {
}
};
};
Python::with_gil(|_| drop(pynone));
error_flow!()
})
}
Expand Down
175 changes: 103 additions & 72 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl CallbackScheduler {
}
}

#[inline]
pub(crate) fn send(pyself: Py<Self>, py: Python, coro: PyObject) {
let rself = pyself.get();
let ptr = pyself.as_ptr();
Expand Down Expand Up @@ -106,6 +107,7 @@ impl CallbackScheduler {
let corom = pyo3::ffi::PyObject_GetAttr(coro.as_ptr(), rself.pyname_aiothrow.as_ptr());
pyo3::ffi::PyObject_CallOneArg(rself.aio_tenter.as_ptr(), ptr);
pyo3::ffi::PyObject_CallOneArg(corom, err.as_ptr());
pyo3::ffi::PyErr_Clear();
pyo3::ffi::PyObject_CallOneArg(rself.aio_texit.as_ptr(), ptr);
}
}
Expand Down Expand Up @@ -203,20 +205,19 @@ impl PyEmptyAwaitable {

#[pyclass(frozen, module = "granian._granian")]
pub(crate) struct PyIterAwaitable {
result: RwLock<Option<PyResult<PyObject>>>,
result: OnceLock<PyResult<PyObject>>,
}

#[cfg(not(target_os = "linux"))]
impl PyIterAwaitable {
pub(crate) fn new() -> Self {
Self {
result: RwLock::new(None),
result: OnceLock::new(),
}
}

#[inline]
pub(crate) fn set_result(&self, result: PyResult<PyObject>) {
let mut res = self.result.write().unwrap();
*res = Some(result);
let _ = self.result.set(result);
}
}

Expand All @@ -231,27 +232,28 @@ impl PyIterAwaitable {
}

fn __next__(&self, py: Python) -> PyResult<Option<PyObject>> {
if let Ok(res) = self.result.try_read() {
if let Some(ref res) = *res {
return res
.as_ref()
.map_err(|err| err.clone_ref(py))
.map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?;
}
};
if let Some(res) = py.allow_threads(|| self.result.get()) {
return res
.as_ref()
.map_err(|err| err.clone_ref(py))
.map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?;
}

Ok(Some(py.None()))
}
}

#[repr(u8)]
enum PyFutureAwaitableState {
Pending,
Completed(PyResult<PyObject>),
Cancelled,
Pending = 0,
Completed = 1,
Cancelled = 2,
}

#[pyclass(frozen, module = "granian._granian")]
pub(crate) struct PyFutureAwaitable {
state: RwLock<PyFutureAwaitableState>,
state: atomic::AtomicU8,
result: OnceLock<PyResult<PyObject>>,
event_loop: PyObject,
cancel_tx: Arc<Notify>,
py_block: atomic::AtomicBool,
Expand All @@ -261,7 +263,8 @@ pub(crate) struct PyFutureAwaitable {
impl PyFutureAwaitable {
pub(crate) fn new(event_loop: PyObject) -> Self {
Self {
state: RwLock::new(PyFutureAwaitableState::Pending),
state: atomic::AtomicU8::new(PyFutureAwaitableState::Pending as u8),
result: OnceLock::new(),
event_loop,
cancel_tx: Arc::new(Notify::new()),
py_block: true.into(),
Expand All @@ -274,23 +277,36 @@ impl PyFutureAwaitable {
Ok((Py::new(py, self)?, cancel_tx))
}

pub(crate) fn set_result(&self, result: PyResult<PyObject>, aw: Py<PyFutureAwaitable>) {
Python::with_gil(|py| {
let mut state = self.state.write().unwrap();
if !matches!(&mut *state, PyFutureAwaitableState::Pending) {
return;
}
*state = PyFutureAwaitableState::Completed(result);
pub(crate) fn set_result(pyself: Py<Self>, result: PyResult<PyObject>) {
let rself = pyself.get();
if rself
.state
.compare_exchange(
PyFutureAwaitableState::Pending as u8,
PyFutureAwaitableState::Completed as u8,
atomic::Ordering::Release,
atomic::Ordering::Relaxed,
)
.is_err()
{
Python::with_gil(|_| drop(result));
return;
}

let ack = self.ack.read().unwrap();
let _ = rself.result.set(result);

Python::with_gil(|py| {
let ack = pyself.get().ack.read().unwrap();
if let Some((cb, ctx)) = &*ack {
let _ = self.event_loop.clone_ref(py).call_method(
let _ = pyself.get().event_loop.clone_ref(py).call_method(
py,
pyo3::intern!(py, "call_soon_threadsafe"),
(cb, aw),
(cb, pyself.clone_ref(py)),
Some(ctx.bind(py)),
);
}
drop(ack);
drop(pyself);
});
}
}
Expand All @@ -305,15 +321,17 @@ impl PyFutureAwaitable {
}

fn __next__(pyself: PyRef<'_, Self>) -> PyResult<Option<PyRef<'_, Self>>> {
let state = pyself.state.read().unwrap();
if let PyFutureAwaitableState::Completed(res) = &*state {
if pyself.state.load(atomic::Ordering::Acquire) == PyFutureAwaitableState::Completed as u8 {
let py = pyself.py();
return res
return pyself
.result
.get()
.unwrap()
.as_ref()
.map_err(|err| err.clone_ref(py))
.map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?;
};
drop(state);
}

Ok(Some(pyself))
}

Expand All @@ -337,20 +355,16 @@ impl PyFutureAwaitable {
let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item(pyo3::intern!(py, "context"), context)?;

let state = pyself.state.read().unwrap();
match &*state {
PyFutureAwaitableState::Pending => {
let mut ack = pyself.ack.write().unwrap();
*ack = Some((cb, kwctx.unbind()));
Ok(())
}
_ => {
drop(state);
let event_loop = pyself.event_loop.clone_ref(py);
event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(&kwctx))?;
Ok(())
}
let state = pyself.state.load(atomic::Ordering::Acquire);
if state == PyFutureAwaitableState::Pending as u8 {
let mut ack = pyself.ack.write().unwrap();
*ack = Some((cb, kwctx.unbind()));
} else {
let event_loop = pyself.event_loop.clone_ref(py);
event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(&kwctx))?;
}

Ok(())
}

#[allow(unused)]
Expand All @@ -363,13 +377,20 @@ impl PyFutureAwaitable {
#[allow(unused)]
#[pyo3(signature = (msg=None))]
fn cancel(pyself: PyRef<'_, Self>, msg: Option<PyObject>) -> bool {
let mut state = pyself.state.write().unwrap();
if !matches!(&mut *state, PyFutureAwaitableState::Pending) {
if pyself
.state
.compare_exchange(
PyFutureAwaitableState::Pending as u8,
PyFutureAwaitableState::Cancelled as u8,
atomic::Ordering::Release,
atomic::Ordering::Relaxed,
)
.is_err()
{
return false;
}

pyself.cancel_tx.notify_one();
*state = PyFutureAwaitableState::Cancelled;

let ack = pyself.ack.read().unwrap();
if let Some((cb, ctx)) = &*ack {
Expand All @@ -378,7 +399,6 @@ impl PyFutureAwaitable {
let cb = cb.clone_ref(py);
let ctx = ctx.clone_ref(py);
drop(ack);
drop(state);

let _ = event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(ctx.bind(py)));
}
Expand All @@ -387,36 +407,47 @@ impl PyFutureAwaitable {
}

fn done(&self) -> bool {
let state = self.state.read().unwrap();
!matches!(&*state, PyFutureAwaitableState::Pending)
self.state.load(atomic::Ordering::Acquire) != PyFutureAwaitableState::Pending as u8
}

fn result(&self, py: Python) -> PyResult<PyObject> {
let state = self.state.read().unwrap();
match &*state {
PyFutureAwaitableState::Completed(res) => {
res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py))
}
PyFutureAwaitableState::Cancelled => {
Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled."))
}
PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err(
"Result is not ready.",
)),
let state = self.state.load(atomic::Ordering::Acquire);

if state == PyFutureAwaitableState::Completed as u8 {
return self
.result
.get()
.unwrap()
.as_ref()
.map(|v| v.clone_ref(py))
.map_err(|err| err.clone_ref(py));
}
if state == PyFutureAwaitableState::Cancelled as u8 {
return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled."));
}
Err(pyo3::exceptions::asyncio::InvalidStateError::new_err(
"Result is not ready.",
))
}

fn exception(&self, py: Python) -> PyResult<PyObject> {
let state = self.state.read().unwrap();
match &*state {
PyFutureAwaitableState::Completed(res) => res.as_ref().map(|_| py.None()).map_err(|err| err.clone_ref(py)),
PyFutureAwaitableState::Cancelled => {
Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled."))
}
PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err(
"Exception is not set.",
)),
let state = self.state.load(atomic::Ordering::Acquire);

if state == PyFutureAwaitableState::Completed as u8 {
return self
.result
.get()
.unwrap()
.as_ref()
.map(|_| py.None())
.map_err(|err| err.clone_ref(py));
}
if state == PyFutureAwaitableState::Cancelled as u8 {
return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled."));
}
Err(pyo3::exceptions::asyncio::InvalidStateError::new_err(
"Exception is not set.",
))
}
}

Expand Down
Loading

0 comments on commit cf53191

Please sign in to comment.