Skip to content

Commit

Permalink
PyAddToModule: Properly propagate initialization error
Browse files Browse the repository at this point in the history
Better than panics
  • Loading branch information
Tpt committed Mar 5, 2024
1 parent b08ee4b commit 59a6bd1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 10 deletions.
19 changes: 17 additions & 2 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,11 +881,12 @@ fn impl_complex_enum(
}
};

let pyclass_impls: TokenStream = vec![
let pyclass_impls: TokenStream = [
impl_builder.impl_pyclass(ctx),
impl_builder.impl_extractext(ctx),
enum_into_py_impl,
impl_builder.impl_pyclassimpl(ctx)?,
impl_builder.impl_add_to_module(ctx),
impl_builder.impl_freelist(ctx),
]
.into_iter()
Expand Down Expand Up @@ -1372,11 +1373,12 @@ impl<'a> PyClassImplsBuilder<'a> {
}

fn impl_all(&self, ctx: &Ctx) -> Result<TokenStream> {
let tokens = vec![
let tokens = [
self.impl_pyclass(ctx),
self.impl_extractext(ctx),
self.impl_into_py(ctx),
self.impl_pyclassimpl(ctx)?,
self.impl_add_to_module(ctx),
self.impl_freelist(ctx),
]
.into_iter()
Expand Down Expand Up @@ -1625,6 +1627,19 @@ impl<'a> PyClassImplsBuilder<'a> {
})
}

fn impl_add_to_module(&self, ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let cls = self.cls;
quote! {
impl #pyo3_path::impl_::pymodule::PyAddToModule for #cls {
fn add_to_module(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
use #pyo3_path::types::PyModuleMethods;
module.add_class::<Self>()
}
}
}
}

fn impl_freelist(&self, ctx: &Ctx) -> TokenStream {
let cls = self.cls;
let Ctx { pyo3_path } = ctx;
Expand Down
9 changes: 1 addition & 8 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use portable_atomic::{AtomicI64, Ordering};

#[cfg(not(PyPy))]
use crate::exceptions::PyImportError;
use crate::types::module::PyModuleMethods;
use crate::{ffi, sync::GILOnceCell, types::PyModule, Bound, Py, PyResult, PyTypeInfo, Python};
use crate::{ffi, sync::GILOnceCell, types::PyModule, Bound, Py, PyResult, Python};

/// `Sync` wrapper of `ffi::PyModuleDef`.
pub struct ModuleDef {
Expand Down Expand Up @@ -141,12 +140,6 @@ pub trait PyAddToModule {
fn add_to_module(module: &Bound<'_, PyModule>) -> PyResult<()>;
}

impl<T: PyTypeInfo> PyAddToModule for T {
fn add_to_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
module.add(Self::NAME, Self::type_object_bound(module.py()))
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};
Expand Down
12 changes: 12 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ macro_rules! pyobject_native_type_info(
}
)?
}

impl<$($generics,)*> $crate::impl_::pymodule::PyAddToModule for $name {
fn add_to_module(
module: &$crate::Bound<'_, $crate::types::PyModule>,
) -> $crate::PyResult<()> {
use $crate::types::PyModuleMethods;
module.add(
<Self as $crate::PyTypeInfo>::NAME,
<Self as $crate::PyTypeInfo>::type_object_bound(module.py()),
)
}
}
};
);

Expand Down
41 changes: 41 additions & 0 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
#[cfg(not(Py_LIMITED_API))]
use pyo3::types::PyBool;

#[path = "../src/tests/common.rs"]
mod common;
Expand Down Expand Up @@ -99,3 +101,42 @@ fn test_declarative_module() {
py_assert!(py, m, "hasattr(m, 'LocatedClass')");
})
}

#[cfg(not(Py_LIMITED_API))]
#[pyclass(extends = PyBool)]
struct ExtendsBool;

#[cfg(not(Py_LIMITED_API))]
#[pymodule]
mod class_initialization_module {
#[pymodule_export]
use super::ExtendsBool;
}

#[test]
#[cfg(not(Py_LIMITED_API))]
fn test_class_initialization_fails() {
Python::with_gil(|py| {
let err = class_initialization_module::DEF
.make_module(py)
.unwrap_err();
assert_eq!(
err.to_string(),
"RuntimeError: An error occurred while initializing class ExtendsBool"
);
})
}

#[pymodule]
mod r#type {
#[pymodule_export]
use super::double;
}

#[test]
fn test_raw_ident_module() {
Python::with_gil(|py| {
let m = pyo3::wrap_pymodule!(r#type)(py).into_bound(py);
py_assert!(py, m, "m.double(2) == 4");
})
}

0 comments on commit 59a6bd1

Please sign in to comment.