Skip to content

Commit

Permalink
Move py fn wrapper argument expansion to associated function.
Browse files Browse the repository at this point in the history
Suggestion by @kngwyu.

Additionally replace some `expect` calls with error handling.
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 5bbca1a commit 1f017b6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
8 changes: 1 addition & 7 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,7 @@ pub fn add_fn_to_module(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyObject {
let arg = args.into();
let (py, maybe_module) = match arg {
pyo3::derive_utils::WrapPyFunctionArguments::Python(py) => (py, None),
pyo3::derive_utils::WrapPyFunctionArguments::PyModule(module) => {
let py = <pyo3::types::PyModule as pyo3::PyNativeType>::py(module);
(py, Some(module))
}
};
let (py, maybe_module) = arg.into_py_and_maybe_module();
#wrapper

let _def = pyo3::class::PyMethodDef {
Expand Down
12 changes: 12 additions & 0 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ pub enum WrapPyFunctionArguments<'a> {
PyModule(&'a PyModule),
}

impl<'a> WrapPyFunctionArguments<'a> {
pub fn into_py_and_maybe_module(self) -> (Python<'a>, Option<&'a PyModule>) {
match self {
WrapPyFunctionArguments::Python(py) => (py, None),
WrapPyFunctionArguments::PyModule(module) => {
let py = module.py();
(py, Some(module))
}
}
}
}

impl<'a> From<Python<'a>> for WrapPyFunctionArguments<'a> {
fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> {
WrapPyFunctionArguments::Python(py)
Expand Down
18 changes: 6 additions & 12 deletions src/types/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,8 @@ impl PyModule {
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}

/// Adds a (sub)module to a module.
Expand All @@ -214,10 +212,8 @@ impl PyModule {
/// ```
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}

/// Adds a function to a module, using the functions __name__ as name.
Expand All @@ -235,9 +231,7 @@ impl PyModule {
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}
}

0 comments on commit 1f017b6

Please sign in to comment.