Skip to content

Commit

Permalink
fix exception handling on Python 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jul 14, 2023
1 parent 65312b4 commit 3619934
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 76 deletions.
139 changes: 104 additions & 35 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,53 @@ use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
types::{PyTraceback, PyType},
AsPyPointer, IntoPy, IntoPyPointer, Py, PyObject, Python,
AsPyPointer, IntoPy, IntoPyPointer, Py, PyAny, PyObject, PyTypeInfo, Python,
};

#[derive(Clone)]
pub(crate) struct PyErrStateNormalized {
#[cfg(not(Py_3_12))]
pub ptype: Py<PyType>,
pub pvalue: Py<PyBaseException>,
#[cfg(not(Py_3_12))]
pub ptraceback: Option<Py<PyTraceback>>,
}

#[cfg(not(Py_3_12))]
impl PyErrStateNormalized {
pub(crate) fn ptype<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.ptype.as_ref(py)
}

pub(crate) fn ptraceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
self.ptraceback
.as_ref()
.map(|traceback| traceback.as_ref(py))
}
}

#[cfg(Py_3_12)]
impl PyErrStateNormalized {
pub(crate) fn ptype<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.pvalue.as_ref(py).get_type()
}

pub(crate) fn ptraceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
unsafe { py.from_owned_ptr_or_opt(ffi::PyException_GetTraceback(self.pvalue.as_ptr())) }
}
}

pub(crate) struct PyErrStateLazyFnOutput {
pub(crate) ptype: PyObject,
pub(crate) pvalue: PyObject,
}

pub(crate) type PyErrStateLazyFn =
dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;

pub(crate) enum PyErrState {
LazyTypeAndValue {
ptype: for<'py> fn(Python<'py>) -> &PyType,
pvalue: Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync>,
},
LazyValue {
ptype: Py<PyType>,
pvalue: Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync>,
},
Lazy(Box<PyErrStateLazyFn>),
#[cfg(not(Py_3_12))]
FfiTuple {
ptype: PyObject,
pvalue: Option<PyObject>,
Expand All @@ -45,38 +73,33 @@ where
}

pub(crate) fn boxed_args(
ptype: &PyAny,
args: impl PyErrArguments + 'static,
) -> Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync> {
Box::new(|py| args.arguments(py))
) -> Box<PyErrStateLazyFn> {
let ptype = ptype.into();
Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
})
}

impl PyErrState {
#[cfg(not(Py_3_12))]
pub(crate) fn into_ffi_tuple(
self,
py: Python<'_>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
match self {
PyErrState::LazyTypeAndValue { ptype, pvalue } => {
let ty = ptype(py);
if unsafe { ffi::PyExceptionClass_Check(ty.as_ptr()) } == 0 {
Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py)
} else {
(
ptype(py).into_ptr(),
pvalue(py).into_ptr(),
std::ptr::null_mut(),
)
}
}
PyErrState::LazyValue { ptype, pvalue } => {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py)
PyErrState::Lazy(boxed_args(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
))
.into_ffi_tuple(py)
} else {
(
ptype.into_ptr(),
pvalue(py).into_ptr(),
std::ptr::null_mut(),
)
(ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut())
}
}
PyErrState::FfiTuple {
Expand All @@ -92,11 +115,57 @@ impl PyErrState {
}
}

#[inline]
pub(crate) fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self {
PyErrState::LazyValue {
ptype: py.get_type::<PyTypeError>().into(),
pvalue: boxed_args("exceptions must derive from BaseException"),
#[cfg(not(Py_3_12))]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
let (mut ptype, mut pvalue, mut ptraceback) = self.into_ffi_tuple(py);

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}
}
}

#[cfg(Py_3_12)]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
self.restore(py);
// Safety: self.restore(py) will set the raised exception
let pvalue = unsafe { Py::from_owned_ptr(py, ffi::PyErr_GetRaisedException()) };
PyErrStateNormalized { pvalue }
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = self.into_ffi_tuple(py);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
}

#[cfg(Py_3_12)]
pub(crate) fn restore(self, py: Python<'_>) {
match self {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
unsafe {
if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
ffi::PyErr_SetString(
PyTypeError::type_object_raw(py).cast(),
"exceptions must derive from BaseException\0"
.as_ptr()
.cast(),
)
} else {
ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
}
}
}
PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
},
}
}
}
92 changes: 51 additions & 41 deletions src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod impls;
pub use err_state::PyErrArguments;
use err_state::{boxed_args, PyErrState, PyErrStateNormalized};

use self::err_state::PyErrStateLazyFnOutput;

/// Represents a Python exception.
///
/// To avoid needing access to [`Python`] in `Into` conversions to create `PyErr` (thus improving
Expand Down Expand Up @@ -119,10 +121,12 @@ impl PyErr {
T: PyTypeInfo,
A: PyErrArguments + Send + Sync + 'static,
{
PyErr::from_state(PyErrState::LazyTypeAndValue {
ptype: T::type_object,
pvalue: boxed_args(args),
})
PyErr::from_state(PyErrState::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
ptype: T::type_object(py).into(),
pvalue: args.arguments(py),
}
})))
}

/// Constructs a new PyErr from the given Python type and arguments.
Expand All @@ -139,10 +143,7 @@ impl PyErr {
where
A: PyErrArguments + Send + Sync + 'static,
{
PyErr::from_state(PyErrState::LazyValue {
ptype: ty.into(),
pvalue: boxed_args(args),
})
PyErr::from_state(PyErrState::Lazy(boxed_args(ty, args)))
}

/// Creates a new PyErr.
Expand Down Expand Up @@ -179,18 +180,16 @@ impl PyErr {
pub fn from_value(obj: &PyAny) -> PyErr {
let state = if let Ok(obj) = obj.downcast::<PyBaseException>() {
PyErrState::Normalized(PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: obj.get_type().into(),
pvalue: obj.into(),
#[cfg(not(Py_3_12))]
ptraceback: None,
})
} else if unsafe { ffi::PyExceptionClass_Check(obj.as_ptr()) } != 0 {
PyErrState::FfiTuple {
ptype: obj.into(),
pvalue: None,
ptraceback: None,
}
} else {
return exceptions_must_derive_from_base_exception(obj.py());
// Assume obj is Type[Exception]; let later normalization handle if this
// is not the case
PyErrState::Lazy(boxed_args(obj, obj.py().None()))
};

PyErr::from_state(state)
Expand All @@ -208,7 +207,7 @@ impl PyErr {
/// });
/// ```
pub fn get_type<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.normalized(py).ptype.as_ref(py)
self.normalized(py).ptype(py)
}

/// Returns the value of this exception.
Expand Down Expand Up @@ -248,10 +247,7 @@ impl PyErr {
/// });
/// ```
pub fn traceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
self.normalized(py)
.ptraceback
.as_ref()
.map(|obj| obj.as_ref(py))
self.normalized(py).ptraceback(py)
}

/// Gets whether an error is present in the Python interpreter's global state.
Expand All @@ -270,16 +266,21 @@ impl PyErr {
/// expected to have been set, for example from [`PyErr::occurred`] or by an error return value
/// from a C FFI function, use [`PyErr::fetch`].
pub fn take(py: Python<'_>) -> Option<PyErr> {
Self::_take(py)
}

#[cfg(not(Py_3_12))]
fn _take(py: Python<'_>) -> Option<PyErr> {
let (ptype, pvalue, ptraceback) = unsafe {
let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);

// Convert to Py immediately so that any references are freed by early return.
let ptype = Py::from_owned_ptr_or_opt(py, ptype);
let pvalue = Py::from_owned_ptr_or_opt(py, pvalue);
let ptraceback = Py::from_owned_ptr_or_opt(py, ptraceback);
let ptype = PyObject::from_owned_ptr_or_opt(py, ptype);
let pvalue = PyObject::from_owned_ptr_or_opt(py, pvalue);
let ptraceback = PyObject::from_owned_ptr_or_opt(py, ptraceback);

// A valid exception state should always have a non-null ptype, but the other two may be
// null.
Expand Down Expand Up @@ -327,6 +328,29 @@ impl PyErr {
}))
}

#[cfg(Py_3_12)]
fn _take(py: Python<'_>) -> Option<PyErr> {
let pvalue = unsafe { Py::from_owned_ptr_or_opt(py, ffi::PyErr_GetRaisedException()) }?;
let state = PyErrStateNormalized { pvalue };
if state.ptype(py).as_ptr() == PanicException::type_object_raw(py).cast() {
let msg: String = state.pvalue.as_ref(py).to_string();

eprintln!(
"--- PyO3 is resuming a panic after fetching a PanicException from Python. ---"
);
eprintln!("Python stack trace below:");

unsafe {
ffi::PyErr_SetRaisedException(state.pvalue.into_ptr());
ffi::PyErr_PrintEx(0);
}

std::panic::resume_unwind(Box::new(msg))
}

Some(PyErr::from_state(PyErrState::Normalized(state)))
}

/// Equivalent to [PyErr::take], but when no error is set:
/// - Panics in debug mode.
/// - Returns a `SystemError` in release mode.
Expand Down Expand Up @@ -448,15 +472,12 @@ impl PyErr {
/// This is the opposite of `PyErr::fetch()`.
#[inline]
pub fn restore(self, py: Python<'_>) {
let state = match self.state.into_inner() {
Some(state) => state,
match self.state.into_inner() {
Some(state) => state.restore(py),
// Safety: restore takes `self` by value so nothing else is accessing this err
// and the invariant is that state is always defined except during make_normalized
None => unsafe { std::hint::unreachable_unchecked() },
};

let (ptype, pvalue, ptraceback) = state.into_ffi_tuple(py);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
}

/// Reports the error as unraisable.
Expand Down Expand Up @@ -640,17 +661,10 @@ impl PyErr {
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
let (mut ptype, mut pvalue, mut ptraceback) = state.into_ffi_tuple(py);

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
let self_state = &mut *self.state.get();
*self_state = Some(PyErrState::Normalized(PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}));

*self_state = Some(PyErrState::Normalized(state.normalize(py)));
match self_state {
Some(PyErrState::Normalized(n)) => n,
_ => unreachable!(),
Expand Down Expand Up @@ -786,11 +800,6 @@ impl_signed_integer!(i64);
impl_signed_integer!(i128);
impl_signed_integer!(isize);

#[inline]
fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> PyErr {
PyErr::from_state(PyErrState::exceptions_must_derive_from_base_exception(py))
}

#[cfg(test)]
mod tests {
use super::PyErrState;
Expand Down Expand Up @@ -822,6 +831,7 @@ mod tests {
assert!(err.is_instance_of::<exceptions::PyTypeError>(py));
err.restore(py);
let err = PyErr::fetch(py);

assert!(err.is_instance_of::<exceptions::PyTypeError>(py));
assert_eq!(
err.to_string(),
Expand Down

0 comments on commit 3619934

Please sign in to comment.