Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix exception handling on Python 3.12 #3306

Merged
merged 1 commit into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3306.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update `PyErr` for 3.12 betas to avoid deprecated ffi methods.
108 changes: 100 additions & 8 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,37 @@ use crate::{

#[derive(Clone)]
pub(crate) struct PyErrStateNormalized {
#[cfg(not(Py_3_12))]
pub ptype: Py<PyType>,
pub pvalue: Py<PyBaseException>,
#[cfg(not(Py_3_12))]
pub ptraceback: Option<Py<PyTraceback>>,
}

impl PyErrStateNormalized {
#[cfg(not(Py_3_12))]
pub(crate) fn ptype<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.ptype.as_ref(py)
}

#[cfg(Py_3_12)]
pub(crate) fn ptype<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.pvalue.as_ref(py).get_type()
}

#[cfg(not(Py_3_12))]
pub(crate) fn ptraceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
self.ptraceback
.as_ref()
.map(|traceback| traceback.as_ref(py))
}

#[cfg(Py_3_12)]
pub(crate) fn ptraceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
unsafe { py.from_owned_ptr_or_opt(ffi::PyException_GetTraceback(self.pvalue.as_ptr())) }
}
}

pub(crate) struct PyErrStateLazyFnOutput {
pub(crate) ptype: PyObject,
pub(crate) pvalue: PyObject,
Expand All @@ -22,6 +48,7 @@ pub(crate) type PyErrStateLazyFn =

pub(crate) enum PyErrState {
Lazy(Box<PyErrStateLazyFn>),
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
#[cfg(not(Py_3_12))]
FfiTuple {
ptype: PyObject,
pvalue: Option<PyObject>,
Expand Down Expand Up @@ -53,9 +80,23 @@ impl PyErrState {
pvalue: args.arguments(py),
}))
}
}

impl PyErrState {
pub(crate) fn normalized(pvalue: &PyBaseException) -> Self {
Self::Normalized(PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: pvalue.get_type().into(),
pvalue: pvalue.into(),
#[cfg(not(Py_3_12))]
ptraceback: unsafe {
Py::from_owned_ptr_or_opt(
pvalue.py(),
ffi::PyException_GetTraceback(pvalue.as_ptr()),
)
},
})
}

#[cfg(not(Py_3_12))]
pub(crate) fn into_ffi_tuple(
self,
py: Python<'_>,
Expand All @@ -64,7 +105,11 @@ impl PyErrState {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py)
PyErrState::lazy(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
)
.into_ffi_tuple(py)
} else {
(ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut())
}
Expand All @@ -82,10 +127,57 @@ impl PyErrState {
}
}

fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self {
PyErrState::lazy(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
)
#[cfg(not(Py_3_12))]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
let (mut ptype, mut pvalue, mut ptraceback) = self.into_ffi_tuple(py);

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}
}
}

#[cfg(Py_3_12)]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
self.restore(py);
// Safety: self.restore(py) will set the raised exception
let pvalue = unsafe { Py::from_owned_ptr(py, ffi::PyErr_GetRaisedException()) };
PyErrStateNormalized { pvalue }
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = self.into_ffi_tuple(py);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
}

#[cfg(Py_3_12)]
pub(crate) fn restore(self, py: Python<'_>) {
match self {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
unsafe {
if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
ffi::PyErr_SetString(
PyTypeError::type_object_raw(py).cast(),
"exceptions must derive from BaseException\0"
.as_ptr()
.cast(),
)
} else {
ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
}
}
}
PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
},
}
}
}
95 changes: 50 additions & 45 deletions src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,7 @@ impl PyErr {
/// ```
pub fn from_value(obj: &PyAny) -> PyErr {
let state = if let Ok(obj) = obj.downcast::<PyBaseException>() {
let pvalue: Py<PyBaseException> = obj.into();

let ptraceback = unsafe {
Py::from_owned_ptr_or_opt(obj.py(), ffi::PyException_GetTraceback(pvalue.as_ptr()))
};

PyErrState::Normalized(PyErrStateNormalized {
ptype: obj.get_type().into(),
pvalue,
ptraceback,
})
PyErrState::normalized(obj)
} else {
// Assume obj is Type[Exception]; let later normalization handle if this
// is not the case
Expand All @@ -209,7 +199,7 @@ impl PyErr {
/// });
/// ```
pub fn get_type<'py>(&'py self, py: Python<'py>) -> &'py PyType {
self.normalized(py).ptype.as_ref(py)
self.normalized(py).ptype(py)
}

/// Returns the value of this exception.
Expand All @@ -236,7 +226,7 @@ impl PyErr {
// complexity.
let normalized = self.normalized(py);
let exc = normalized.pvalue.clone_ref(py);
if let Some(tb) = normalized.ptraceback.as_ref() {
if let Some(tb) = normalized.ptraceback(py) {
unsafe {
ffi::PyException_SetTraceback(exc.as_ptr(), tb.as_ptr());
}
Expand All @@ -256,10 +246,7 @@ impl PyErr {
/// });
/// ```
pub fn traceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
self.normalized(py)
.ptraceback
.as_ref()
.map(|obj| obj.as_ref(py))
self.normalized(py).ptraceback(py)
}

/// Gets whether an error is present in the Python interpreter's global state.
Expand All @@ -278,6 +265,11 @@ impl PyErr {
/// expected to have been set, for example from [`PyErr::occurred`] or by an error return value
/// from a C FFI function, use [`PyErr::fetch`].
pub fn take(py: Python<'_>) -> Option<PyErr> {
Self::_take(py)
}

#[cfg(not(Py_3_12))]
fn _take(py: Python<'_>) -> Option<PyErr> {
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
let (ptype, pvalue, ptraceback) = unsafe {
let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
Expand Down Expand Up @@ -316,17 +308,12 @@ impl PyErr {
.map(|py_str| py_str.to_string_lossy().into())
.unwrap_or_else(|| String::from("Unwrapped panic from Python code"));

eprintln!(
"--- PyO3 is resuming a panic after fetching a PanicException from Python. ---"
);
eprintln!("Python stack trace below:");

unsafe {
ffi::PyErr_Restore(ptype.into_ptr(), pvalue.into_ptr(), ptraceback.into_ptr());
ffi::PyErr_PrintEx(0);
}

std::panic::resume_unwind(Box::new(msg))
let state = PyErrState::FfiTuple {
ptype,
pvalue,
ptraceback,
};
Self::print_panic_and_unwind(py, state, msg)
}

Some(PyErr::from_state(PyErrState::FfiTuple {
Expand All @@ -336,6 +323,35 @@ impl PyErr {
}))
}

#[cfg(Py_3_12)]
fn _take(py: Python<'_>) -> Option<PyErr> {
let pvalue = unsafe {
py.from_owned_ptr_or_opt::<PyBaseException>(ffi::PyErr_GetRaisedException())
}?;
if pvalue.get_type().as_ptr() == PanicException::type_object_raw(py).cast() {
let msg: String = pvalue
.str()
.map(|py_str| py_str.to_string_lossy().into())
.unwrap_or_else(|_| String::from("Unwrapped panic from Python code"));
Self::print_panic_and_unwind(py, PyErrState::normalized(pvalue), msg)
}

Some(PyErr::from_state(PyErrState::normalized(pvalue)))
}

fn print_panic_and_unwind(py: Python<'_>, state: PyErrState, msg: String) -> ! {
eprintln!("--- PyO3 is resuming a panic after fetching a PanicException from Python. ---");
eprintln!("Python stack trace below:");

state.restore(py);

unsafe {
ffi::PyErr_PrintEx(0);
}

std::panic::resume_unwind(Box::new(msg))
}

/// Equivalent to [PyErr::take], but when no error is set:
/// - Panics in debug mode.
/// - Returns a `SystemError` in release mode.
Expand Down Expand Up @@ -457,15 +473,10 @@ impl PyErr {
/// This is the opposite of `PyErr::fetch()`.
#[inline]
pub fn restore(self, py: Python<'_>) {
let state = match self.state.into_inner() {
Some(state) => state,
// Safety: restore takes `self` by value so nothing else is accessing this err
// and the invariant is that state is always defined except during make_normalized
None => unsafe { std::hint::unreachable_unchecked() },
};

let (ptype, pvalue, ptraceback) = state.into_ffi_tuple(py);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
self.state
.into_inner()
.expect("PyErr state should never be invalid outside of normalization")
.restore(py)
}

/// Reports the error as unraisable.
Expand Down Expand Up @@ -649,17 +660,10 @@ impl PyErr {
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
let (mut ptype, mut pvalue, mut ptraceback) = state.into_ffi_tuple(py);

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
let self_state = &mut *self.state.get();
*self_state = Some(PyErrState::Normalized(PyErrStateNormalized {
ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
}));

*self_state = Some(PyErrState::Normalized(state.normalize(py)));
match self_state {
Some(PyErrState::Normalized(n)) => n,
_ => unreachable!(),
Expand Down Expand Up @@ -826,6 +830,7 @@ mod tests {
assert!(err.is_instance_of::<exceptions::PyTypeError>(py));
err.restore(py);
let err = PyErr::fetch(py);

assert!(err.is_instance_of::<exceptions::PyTypeError>(py));
assert_eq!(
err.to_string(),
Expand Down