Skip to content

Commit

Permalink
always normalize exceptions before raising
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 23, 2023
1 parent aeb7a95 commit 6e88e59
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 52 deletions.
1 change: 1 addition & 0 deletions newsfragments/3471.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `IterNextOutput::Return` not returning a value on PyPy.
1 change: 1 addition & 0 deletions pytests/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
hypothesis>=3.55
pytest>=6.0
pytest-asyncio>=0.21
pytest-benchmark>=3.4
psutil>=5.6
typing_extensions>=4.0.0
87 changes: 87 additions & 0 deletions pytests/src/awaitable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! The following classes are examples of objects which implement Python's
//! awaitable protocol.
//!
//! Both IterAwaitable and FutureAwaitable will return a value immediately
//! when awaited, see guide examples related to pyo3-asyncio for ways
//! to suspend tasks and await results.

use pyo3::{prelude::*, pyclass::IterNextOutput};

#[pyclass]
#[derive(Debug)]
pub(crate) struct IterAwaitable {
result: Option<PyResult<PyObject>>,
}

#[pymethods]
impl IterAwaitable {
#[new]
fn new(result: PyObject) -> Self {
IterAwaitable {
result: Some(Ok(result)),
}
}

fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
match self.result.take() {
Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(py.None())),
}
}
}

#[pyclass]
pub(crate) struct FutureAwaitable {
#[pyo3(get, set, name = "_asyncio_future_blocking")]
py_block: bool,
result: Option<PyResult<PyObject>>,
}

#[pymethods]
impl FutureAwaitable {
#[new]
fn new(result: PyObject) -> Self {
FutureAwaitable {
py_block: false,
result: Some(Ok(result)),
}
}

fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __next__(
mut pyself: PyRefMut<'_, Self>,
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
match pyself.result {
Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(pyself)),
}
}
}

#[pymodule]
pub fn awaitable(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<IterAwaitable>()?;
m.add_class::<FutureAwaitable>()?;
Ok(())
}
3 changes: 3 additions & 0 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pymodule;

pub mod awaitable;
pub mod buf_and_str;
pub mod comparisons;
pub mod datetime;
Expand All @@ -17,6 +18,7 @@ pub mod subclassing;

#[pymodule]
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(awaitable::awaitable))?;
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
Expand All @@ -37,6 +39,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {

let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.awaitable", m.getattr("awaitable")?)?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
Expand Down
13 changes: 13 additions & 0 deletions pytests/tests/test_awaitable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from pyo3_pytests.awaitable import IterAwaitable, FutureAwaitable


@pytest.mark.asyncio
async def test_iter_awaitable():
assert await IterAwaitable(5) == 5


@pytest.mark.asyncio
async def test_future_awaitable():
assert await FutureAwaitable(5) == 5
126 changes: 79 additions & 47 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
types::{PyTraceback, PyType},
types::{PyString, PyTraceback, PyType},
IntoPy, Py, PyAny, PyObject, PyTypeInfo, Python,
};

#[derive(Clone)]
pub(crate) struct PyErrStateNormalized {
#[cfg(not(Py_3_12))]
pub ptype: Py<PyType>,
ptype: Py<PyType>,
pub pvalue: Py<PyBaseException>,
#[cfg(not(Py_3_12))]
pub ptraceback: Option<Py<PyTraceback>>,
ptraceback: Option<Py<PyTraceback>>,
}

impl PyErrStateNormalized {
Expand All @@ -36,6 +36,26 @@ impl PyErrStateNormalized {
pub(crate) fn ptraceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
unsafe { py.from_owned_ptr_or_opt(ffi::PyException_GetTraceback(self.pvalue.as_ptr())) }
}

#[cfg(Py_3_12)]
pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
unsafe { Py::from_owned_ptr_or_opt(py, ffi::PyErr_GetRaisedException()) }
.map(|pvalue| PyErrStateNormalized { pvalue })
}

#[cfg(not(Py_3_12))]
unsafe fn from_normalized_ffi_tuple(
py: Python<'_>,
ptype: *mut ffi::PyObject,
pvalue: *mut ffi::PyObject,
ptraceback: *mut ffi::PyObject,
) -> Self {
PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}
}
}

pub(crate) struct PyErrStateLazyFnOutput {
Expand Down Expand Up @@ -96,24 +116,45 @@ impl PyErrState {
})
}

#[cfg(not(Py_3_12))]
pub(crate) fn into_ffi_tuple(
self,
py: Python<'_>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
match self {
#[cfg(not(Py_3_12))]
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
PyErrState::lazy(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
)
.into_ffi_tuple(py)
} else {
(ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut())
let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
unsafe {
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
#[cfg(Py_3_12)]
PyErrState::Lazy(lazy) => {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
self.restore(py);
PyErrStateNormalized::take(py)
.expect("exception missing after writing to the interpreter")
}
#[cfg(not(Py_3_12))]
PyErrState::FfiTuple {
ptype,
pvalue,
ptraceback,
} => {
let mut ptype = ptype.into_ptr();
let mut pvalue = pvalue.map_or(std::ptr::null_mut(), Py::into_ptr);
let mut ptraceback = ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr);
unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
PyErrState::Normalized(normalized) => normalized,
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = match self {
PyErrState::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
PyErrState::FfiTuple {
ptype,
pvalue,
Expand All @@ -132,36 +173,7 @@ impl PyErrState {
pvalue.into_ptr(),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
),
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
let (mut ptype, mut pvalue, mut ptraceback) = self.into_ffi_tuple(py);

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}
}
}

#[cfg(Py_3_12)]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
self.restore(py);
// Safety: self.restore(py) will set the raised exception
let pvalue = unsafe { Py::from_owned_ptr(py, ffi::PyErr_GetRaisedException()) };
PyErrStateNormalized { pvalue }
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = self.into_ffi_tuple(py);
};
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
}

Expand Down Expand Up @@ -189,3 +201,23 @@ impl PyErrState {
}
}
}

fn lazy_into_normalized_ffi_tuple(
py: Python<'_>,
lazy: Box<PyErrStateLazyFn>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
let (mut ptype, mut pvalue) = if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
(
PyTypeError::type_object_raw(py).cast(),
PyString::new(py, "exceptions must derive from BaseException").into_ptr(),
)
} else {
(ptype.into_ptr(), pvalue.into_ptr())
};
let mut ptraceback = std::ptr::null_mut();
unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
}
(ptype, pvalue, ptraceback)
}
9 changes: 4 additions & 5 deletions src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,18 +325,17 @@ impl PyErr {

#[cfg(Py_3_12)]
fn _take(py: Python<'_>) -> Option<PyErr> {
let pvalue = unsafe {
py.from_owned_ptr_or_opt::<PyBaseException>(ffi::PyErr_GetRaisedException())
}?;
let state = PyErrStateNormalized::take(py)?;
let pvalue = state.pvalue.as_ref(py);
if pvalue.get_type().as_ptr() == PanicException::type_object_raw(py).cast() {
let msg: String = pvalue
.str()
.map(|py_str| py_str.to_string_lossy().into())
.unwrap_or_else(|_| String::from("Unwrapped panic from Python code"));
Self::print_panic_and_unwind(py, PyErrState::normalized(pvalue), msg)
Self::print_panic_and_unwind(py, PyErrState::Normalized(state), msg)
}

Some(PyErr::from_state(PyErrState::normalized(pvalue)))
Some(PyErr::from_state(PyErrState::Normalized(state)))
}

fn print_panic_and_unwind(py: Python<'_>, state: PyErrState, msg: String) -> ! {
Expand Down

0 comments on commit 6e88e59

Please sign in to comment.