diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 50a17fda474..abf1817cd57 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -2,7 +2,7 @@ use crate::{ exceptions::{PyBaseException, PyTypeError}, ffi, types::{PyTraceback, PyType}, - AsPyPointer, IntoPy, IntoPyPointer, Py, PyObject, Python, + AsPyPointer, IntoPy, IntoPyPointer, Py, PyAny, PyObject, PyTypeInfo, Python, }; #[derive(Clone)] @@ -12,15 +12,16 @@ pub(crate) struct PyErrStateNormalized { pub ptraceback: Option>, } +pub(crate) struct PyErrStateLazyFnOutput { + pub(crate) ptype: PyObject, + pub(crate) pvalue: PyObject, +} + +pub(crate) type PyErrStateLazyFn = + dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync; + pub(crate) enum PyErrState { - LazyTypeAndValue { - ptype: for<'py> fn(Python<'py>) -> &PyType, - pvalue: Box FnOnce(Python<'py>) -> PyObject + Send + Sync>, - }, - LazyValue { - ptype: Py, - pvalue: Box FnOnce(Python<'py>) -> PyObject + Send + Sync>, - }, + Lazy(Box), FfiTuple { ptype: PyObject, pvalue: Option, @@ -44,10 +45,14 @@ where } } -pub(crate) fn boxed_args( - args: impl PyErrArguments + 'static, -) -> Box FnOnce(Python<'py>) -> PyObject + Send + Sync> { - Box::new(|py| args.arguments(py)) +impl PyErrState { + pub(crate) fn lazy(ptype: &PyAny, args: impl PyErrArguments + 'static) -> Self { + let ptype = ptype.into(); + PyErrState::Lazy(Box::new(move |py| PyErrStateLazyFnOutput { + ptype, + pvalue: args.arguments(py), + })) + } } impl PyErrState { @@ -56,27 +61,12 @@ impl PyErrState { py: Python<'_>, ) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) { match self { - PyErrState::LazyTypeAndValue { ptype, pvalue } => { - let ty = ptype(py); - if unsafe { ffi::PyExceptionClass_Check(ty.as_ptr()) } == 0 { - Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py) - } else { - ( - ptype(py).into_ptr(), - pvalue(py).into_ptr(), - std::ptr::null_mut(), - ) - } - } - PyErrState::LazyValue { ptype, pvalue } => { + PyErrState::Lazy(lazy) => { + let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py); if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 { Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py) } else { - ( - ptype.into_ptr(), - pvalue(py).into_ptr(), - std::ptr::null_mut(), - ) + (ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut()) } } PyErrState::FfiTuple { @@ -92,11 +82,10 @@ impl PyErrState { } } - #[inline] - pub(crate) fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self { - PyErrState::LazyValue { - ptype: py.get_type::().into(), - pvalue: boxed_args("exceptions must derive from BaseException"), - } + fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> Self { + PyErrState::lazy( + PyTypeError::type_object(py), + "exceptions must derive from BaseException", + ) } } diff --git a/src/err/mod.rs b/src/err/mod.rs index 4d3331b07d8..cdfb64b1249 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -14,7 +14,7 @@ mod err_state; mod impls; pub use err_state::PyErrArguments; -use err_state::{boxed_args, PyErrState, PyErrStateNormalized}; +use err_state::{PyErrState, PyErrStateLazyFnOutput, PyErrStateNormalized}; /// Represents a Python exception. /// @@ -119,10 +119,12 @@ impl PyErr { T: PyTypeInfo, A: PyErrArguments + Send + Sync + 'static, { - PyErr::from_state(PyErrState::LazyTypeAndValue { - ptype: T::type_object, - pvalue: boxed_args(args), - }) + PyErr::from_state(PyErrState::Lazy(Box::new(move |py| { + PyErrStateLazyFnOutput { + ptype: T::type_object(py).into(), + pvalue: args.arguments(py), + } + }))) } /// Constructs a new PyErr from the given Python type and arguments. @@ -139,10 +141,7 @@ impl PyErr { where A: PyErrArguments + Send + Sync + 'static, { - PyErr::from_state(PyErrState::LazyValue { - ptype: ty.into(), - pvalue: boxed_args(args), - }) + PyErr::from_state(PyErrState::lazy(ty, args)) } /// Creates a new PyErr. @@ -183,14 +182,10 @@ impl PyErr { pvalue: obj.into(), ptraceback: None, }) - } else if unsafe { ffi::PyExceptionClass_Check(obj.as_ptr()) } != 0 { - PyErrState::FfiTuple { - ptype: obj.into(), - pvalue: None, - ptraceback: None, - } } else { - return exceptions_must_derive_from_base_exception(obj.py()); + // Assume obj is Type[Exception]; let later normalization handle if this + // is not the case + PyErrState::lazy(obj, obj.py().None()) }; PyErr::from_state(state) @@ -277,9 +272,9 @@ impl PyErr { ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback); // Convert to Py immediately so that any references are freed by early return. - let ptype = Py::from_owned_ptr_or_opt(py, ptype); - let pvalue = Py::from_owned_ptr_or_opt(py, pvalue); - let ptraceback = Py::from_owned_ptr_or_opt(py, ptraceback); + let ptype = PyObject::from_owned_ptr_or_opt(py, ptype); + let pvalue = PyObject::from_owned_ptr_or_opt(py, pvalue); + let ptraceback = PyObject::from_owned_ptr_or_opt(py, ptraceback); // A valid exception state should always have a non-null ptype, but the other two may be // null. @@ -786,11 +781,6 @@ impl_signed_integer!(i64); impl_signed_integer!(i128); impl_signed_integer!(isize); -#[inline] -fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> PyErr { - PyErr::from_state(PyErrState::exceptions_must_derive_from_base_exception(py)) -} - #[cfg(test)] mod tests { use super::PyErrState;