Skip to content

Commit

Permalink
Merge pull request #3323 from davidhewitt/pyerr-simplification
Browse files Browse the repository at this point in the history
merge PyErr internal states for simplicity
  • Loading branch information
davidhewitt authored Jul 17, 2023
2 parents e5a7400 + 2d1b8e0 commit 421e13a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 61 deletions.
63 changes: 26 additions & 37 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
types::{PyTraceback, PyType},
AsPyPointer, IntoPy, IntoPyPointer, Py, PyObject, Python,
AsPyPointer, IntoPy, IntoPyPointer, Py, PyAny, PyObject, PyTypeInfo, Python,
};

#[derive(Clone)]
Expand All @@ -12,15 +12,16 @@ pub(crate) struct PyErrStateNormalized {
pub ptraceback: Option<Py<PyTraceback>>,
}

pub(crate) struct PyErrStateLazyFnOutput {
pub(crate) ptype: PyObject,
pub(crate) pvalue: PyObject,
}

pub(crate) type PyErrStateLazyFn =
dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;

pub(crate) enum PyErrState {
LazyTypeAndValue {
ptype: for<'py> fn(Python<'py>) -> &PyType,
pvalue: Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync>,
},
LazyValue {
ptype: Py<PyType>,
pvalue: Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync>,
},
Lazy(Box<PyErrStateLazyFn>),
FfiTuple {
ptype: PyObject,
pvalue: Option<PyObject>,
Expand All @@ -44,10 +45,14 @@ where
}
}

pub(crate) fn boxed_args(
args: impl PyErrArguments + 'static,
) -> Box<dyn for<'py> FnOnce(Python<'py>) -> PyObject + Send + Sync> {
Box::new(|py| args.arguments(py))
impl PyErrState {
pub(crate) fn lazy(ptype: &PyAny, args: impl PyErrArguments + 'static) -> Self {
let ptype = ptype.into();
PyErrState::Lazy(Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}))
}
}

impl PyErrState {
Expand All @@ -56,27 +61,12 @@ impl PyErrState {
py: Python<'_>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
match self {
PyErrState::LazyTypeAndValue { ptype, pvalue } => {
let ty = ptype(py);
if unsafe { ffi::PyExceptionClass_Check(ty.as_ptr()) } == 0 {
Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py)
} else {
(
ptype(py).into_ptr(),
pvalue(py).into_ptr(),
std::ptr::null_mut(),
)
}
}
PyErrState::LazyValue { ptype, pvalue } => {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py)
} else {
(
ptype.into_ptr(),
pvalue(py).into_ptr(),
std::ptr::null_mut(),
)
(ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut())
}
}
PyErrState::FfiTuple {
Expand All @@ -92,11 +82,10 @@ impl PyErrState {
}
}

#[inline]
pub(crate) fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self {
PyErrState::LazyValue {
ptype: py.get_type::<PyTypeError>().into(),
pvalue: boxed_args("exceptions must derive from BaseException"),
}
fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self {
PyErrState::lazy(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
)
}
}
38 changes: 14 additions & 24 deletions src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod err_state;
mod impls;

pub use err_state::PyErrArguments;
use err_state::{boxed_args, PyErrState, PyErrStateNormalized};
use err_state::{PyErrState, PyErrStateLazyFnOutput, PyErrStateNormalized};

/// Represents a Python exception.
///
Expand Down Expand Up @@ -119,10 +119,12 @@ impl PyErr {
T: PyTypeInfo,
A: PyErrArguments + Send + Sync + 'static,
{
PyErr::from_state(PyErrState::LazyTypeAndValue {
ptype: T::type_object,
pvalue: boxed_args(args),
})
PyErr::from_state(PyErrState::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
ptype: T::type_object(py).into(),
pvalue: args.arguments(py),
}
})))
}

/// Constructs a new PyErr from the given Python type and arguments.
Expand All @@ -139,10 +141,7 @@ impl PyErr {
where
A: PyErrArguments + Send + Sync + 'static,
{
PyErr::from_state(PyErrState::LazyValue {
ptype: ty.into(),
pvalue: boxed_args(args),
})
PyErr::from_state(PyErrState::lazy(ty, args))
}

/// Creates a new PyErr.
Expand Down Expand Up @@ -183,14 +182,10 @@ impl PyErr {
pvalue: obj.into(),
ptraceback: None,
})
} else if unsafe { ffi::PyExceptionClass_Check(obj.as_ptr()) } != 0 {
PyErrState::FfiTuple {
ptype: obj.into(),
pvalue: None,
ptraceback: None,
}
} else {
return exceptions_must_derive_from_base_exception(obj.py());
// Assume obj is Type[Exception]; let later normalization handle if this
// is not the case
PyErrState::lazy(obj, obj.py().None())
};

PyErr::from_state(state)
Expand Down Expand Up @@ -277,9 +272,9 @@ impl PyErr {
ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);

// Convert to Py immediately so that any references are freed by early return.
let ptype = Py::from_owned_ptr_or_opt(py, ptype);
let pvalue = Py::from_owned_ptr_or_opt(py, pvalue);
let ptraceback = Py::from_owned_ptr_or_opt(py, ptraceback);
let ptype = PyObject::from_owned_ptr_or_opt(py, ptype);
let pvalue = PyObject::from_owned_ptr_or_opt(py, pvalue);
let ptraceback = PyObject::from_owned_ptr_or_opt(py, ptraceback);

// A valid exception state should always have a non-null ptype, but the other two may be
// null.
Expand Down Expand Up @@ -786,11 +781,6 @@ impl_signed_integer!(i64);
impl_signed_integer!(i128);
impl_signed_integer!(isize);

#[inline]
fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> PyErr {
PyErr::from_state(PyErrState::exceptions_must_derive_from_base_exception(py))
}

#[cfg(test)]
mod tests {
use super::PyErrState;
Expand Down

0 comments on commit 421e13a

Please sign in to comment.