Skip to content

Commit

Permalink
Make python function wrapper creation fallible.
Browse files Browse the repository at this point in the history
Wrapping a function can fail if we can't get the module name.

Based on suggestion by @kngwyu
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 1f017b6 commit 3214249
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
/// A Python module implemented in Rust.
#[pymodule]
fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
m.add_function(wrap_pyfunction!(sum_as_string))?;

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion examples/word-count/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn count_line(line: &str, needle: &str) -> usize {

#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

Expand Down
2 changes: 1 addition & 1 deletion guide/src/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// A good place to install the Rust -> Python logger.
pyo3_log::init();

m.add_wrapped(wrap_pyfunction!(log_something))?;
m.add_function(wrap_pyfunction!(log_something))?;
Ok(())
}
```
Expand Down
2 changes: 1 addition & 1 deletion guide/src/trait_bounds.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ pub struct UserModel {
#[pymodule]
fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<UserModel>()?;
m.add_wrapped(wrap_pyfunction!(solve_wrapper))?;
m.add_function(wrap_pyfunction!(solve_wrapper))?;
Ok(())
}

Expand Down
12 changes: 4 additions & 8 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ pub fn add_fn_to_module(
Ok(quote! {
fn #function_wrapper_ident<'a>(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyObject {
) -> pyo3::PyResult<pyo3::PyObject> {
let arg = args.into();
let (py, maybe_module) = arg.into_py_and_maybe_module();
#wrapper
Expand All @@ -206,12 +206,8 @@ pub fn add_fn_to_module(

let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = match m.name() {
Ok(name) => <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py),
Err(err) => {
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
}
};
let name = m.name()?;
let name = <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py);
(mod_ptr, <PyObject as pyo3::AsPyPointer>::as_ptr(&name))
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
Expand All @@ -228,7 +224,7 @@ pub fn add_fn_to_module(
)
};

function
Ok(function)
}
})
}
Expand Down
21 changes: 14 additions & 7 deletions src/types/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython

use crate::callback::IntoPyCallbackOutput;
use crate::err::{PyErr, PyResult};
use crate::exceptions;
use crate::ffi;
Expand Down Expand Up @@ -197,8 +198,11 @@ impl PyModule {
///
/// **This function will be deprecated in the next release. Please use the specific
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
pub fn add_wrapped<'a, T>(&'a self, wrapper: &impl Fn(Python<'a>) -> T) -> PyResult<()>
where
T: IntoPyCallbackOutput<PyObject>,
{
let function = wrapper(self.py()).convert(self.py())?;
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}
Expand All @@ -211,9 +215,9 @@ impl PyModule {
/// m.add_module(wrap_pymodule!(utils));
/// ```
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
let module = wrapper(self.py());
let name = module.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, module)
}

/// Adds a function to a module, using the functions __name__ as name.
Expand All @@ -229,8 +233,11 @@ impl PyModule {
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
pub fn add_function<'a>(
&'a self,
wrapper: &impl Fn(&'a Self) -> PyResult<PyObject>,
) -> PyResult<()> {
let function = wrapper(self)?;
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}
Expand Down
6 changes: 3 additions & 3 deletions tests/test_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn test_pybytes_bytes_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();

let f = wrap_pyfunction!(bytes_pybytes_conversion)(py);
let f = wrap_pyfunction!(bytes_pybytes_conversion)(py).unwrap();
py_assert!(py, f, "f(b'Hello World') == b'Hello World'");
}

Expand All @@ -28,7 +28,7 @@ fn test_pybytes_vec_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();

let f = wrap_pyfunction!(bytes_vec_conversion)(py);
let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap();
py_assert!(py, f, "f(b'Hello World') == b'Hello World'");
}

Expand All @@ -37,6 +37,6 @@ fn test_bytearray_vec_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();

let f = wrap_pyfunction!(bytes_vec_conversion)(py);
let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap();
py_assert!(py, f, "f(bytearray(b'Hello World')) == b'Hello World'");
}
4 changes: 2 additions & 2 deletions tests/test_exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn fail_to_open_file() -> PyResult<()> {
fn test_filenotfounderror() {
let gil = Python::acquire_gil();
let py = gil.python();
let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py);
let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py).unwrap();

py_run!(
py,
Expand Down Expand Up @@ -64,7 +64,7 @@ fn call_fail_with_custom_error() -> PyResult<()> {
fn test_custom_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py);
let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py).unwrap();

py_run!(
py,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
m.add("foo", "bar").unwrap();

m.add_function(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)?).unwrap();

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn test_optional_bool() {
// Regression test for issue #932
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(optional_bool)(py);
let f = wrap_pyfunction!(optional_bool)(py).unwrap();

py_assert!(py, f, "f() == 'Some(true)'");
py_assert!(py, f, "f(True) == 'Some(true)'");
Expand All @@ -36,7 +36,7 @@ fn buffer_inplace_add(py: Python, x: PyBuffer<i32>, y: PyBuffer<i32>) {
fn test_buffer_add() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(buffer_inplace_add)(py);
let f = wrap_pyfunction!(buffer_inplace_add)(py).unwrap();

py_expect_exception!(
py,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn test_unicode_encode_error() {
let gil = Python::acquire_gil();
let py = gil.python();

let take_str = wrap_pyfunction!(take_str)(py);
let take_str = wrap_pyfunction!(take_str)(py).unwrap();
py_run!(
py,
take_str,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_text_signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ fn test_function() {

let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(my_function)(py);
let f = wrap_pyfunction!(my_function)(py).unwrap();

py_assert!(py, f, "f.__text_signature__ == '(a, b=None, *, c=42)'");
}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_various.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn return_custom_class() {
assert_eq!(get_zero().unwrap().value, 0);

// Using from python
let get_zero = wrap_pyfunction!(get_zero)(py);
let get_zero = wrap_pyfunction!(get_zero)(py).unwrap();
py_assert!(py, get_zero, "get_zero().value == 0");
}

Expand Down Expand Up @@ -206,5 +206,5 @@ fn result_conversion_function() -> Result<(), MyError> {
fn test_result_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();
wrap_pyfunction!(result_conversion_function)(py);
wrap_pyfunction!(result_conversion_function)(py).unwrap();
}

0 comments on commit 3214249

Please sign in to comment.