From fa8c93cfd1271175df501f43ad8594e310066e2c Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sun, 13 Dec 2020 22:51:19 +0000 Subject: [PATCH] pyclass #[new]: allow using custom error type --- CHANGELOG.md | 1 + pyo3-derive-backend/src/pymethod.rs | 3 ++- src/pyclass_init.rs | 16 +++++++-------- tests/test_class_new.rs | 30 +++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d695b901a28..f0a7a6aedab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Fix FFI definitions for `PyObject_Vectorcall` and `PyVectorcall_Call`. [#1287](https://github.com/PyO3/pyo3/pull/1285) - Fix building with Anaconda python inside a virtualenv. [#1290](https://github.com/PyO3/pyo3/pull/1290) - Fix definition of opaque FFI types. [#1312](https://github.com/PyO3/pyo3/pull/1312) +- Fix using custom error type in pyclass `#[new]` methods. [#1319](https://github.com/PyO3/pyo3/pull/1319) ## [0.12.4] - 2020-11-28 ### Fixed diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index ddfa4a1831b..225fbc06eb8 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -176,6 +176,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { use pyo3::type_object::PyTypeInfo; + use pyo3::callback::IntoPyCallbackOutput; use std::convert::TryFrom; const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); @@ -183,7 +184,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); - let initializer = pyo3::PyClassInitializer::try_from(#body)?; + let initializer: pyo3::PyClassInitializer::<#cls> = #body.convert(_py)?; let cell = initializer.create_cell_from_subtype(_py, subtype)?; Ok(cell as *mut pyo3::ffi::PyObject) }) diff --git a/src/pyclass_init.rs b/src/pyclass_init.rs index 1d73f960cfc..2ae441af56d 100644 --- a/src/pyclass_init.rs +++ b/src/pyclass_init.rs @@ -1,7 +1,7 @@ //! Initialization utilities for `#[pyclass]`. +use crate::callback::IntoPyCallbackOutput; use crate::type_object::{PyBorrowFlagLayout, PyLayout, PySizedLayout, PyTypeInfo}; -use crate::{PyCell, PyClass, PyErr, PyResult, Python}; -use std::convert::TryFrom; +use crate::{PyCell, PyClass, PyResult, Python}; use std::marker::PhantomData; /// Initializer for Python types. @@ -182,16 +182,14 @@ where } } -// Implementation which propagates the error from input PyResult. Useful in proc macro -// code where `#[new]` may or may not return PyResult. -impl TryFrom> for PyClassInitializer +// Implementation used by proc macros to allow anything convertible to PyClassInitializer to be +// the return value of pyclass #[new] method (optionally wrapped in `Result`). +impl IntoPyCallbackOutput> for U where T: PyClass, U: Into>, { - type Error = PyErr; - - fn try_from(result: PyResult) -> PyResult { - result.map(Into::into) + fn convert(self, _py: Python) -> PyResult> { + Ok(self.into()) } } diff --git a/tests/test_class_new.rs b/tests/test_class_new.rs index 8d4d6068f32..b98f29b7999 100644 --- a/tests/test_class_new.rs +++ b/tests/test_class_new.rs @@ -1,3 +1,4 @@ +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; #[pyclass] @@ -120,3 +121,32 @@ assert c.from_rust is False .map_err(|e| e.print(py)) .unwrap(); } + +#[pyclass] +#[derive(Debug)] +struct NewWithCustomError {} + +struct CustomError; + +impl From for PyErr { + fn from(_error: CustomError) -> PyErr { + PyValueError::new_err("custom error") + } +} + +#[pymethods] +impl NewWithCustomError { + #[new] + fn new() -> Result { + Err(CustomError) + } +} + +#[test] +fn new_with_custom_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let typeobj = py.get_type::(); + let err = typeobj.call0().unwrap_err(); + assert_eq!(err.to_string(), "ValueError: custom error"); +}