Skip to content

Commit

Permalink
Set the module of #[pyfunction]s.
Browse files Browse the repository at this point in the history
Previously neither the module nor the name of the module of
pyfunctions were registered. This commit passes the module and
its name when creating a new pyfunction.

PyModule::add_function and PyModule::add_module have been added and are
set to replace `add_wrapped` in a future release. `add_wrapped` is kept
for compatibility reasons during the transition.

Depending on whether a `PyModule` or `Python` is the argument for the
Python function-wrapper, the module will be registered with the function.
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 21ad52a commit e296c2b
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 38 deletions.
28 changes: 14 additions & 14 deletions examples/rustapi_module/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,29 +215,29 @@ impl TzClass {

#[pymodule]
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(make_date))?;
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
m.add_wrapped(wrap_pyfunction!(make_time))?;
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_delta))?;
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_date))?;
m.add_function(wrap_pyfunction!(get_date_tuple))?;
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_time))?;
m.add_function(wrap_pyfunction!(get_time_tuple))?;
m.add_function(wrap_pyfunction!(make_delta))?;
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
m.add_function(wrap_pyfunction!(make_datetime))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;

// Python 3.6+ functions
#[cfg(Py_3_6)]
{
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
m.add_function(wrap_pyfunction!(time_with_fold))?;
#[cfg(not(PyPy))]
{
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
}
}

m.add_wrapped(wrap_pyfunction!(issue_219))?;
m.add_function(wrap_pyfunction!(issue_219))?;
m.add_class::<TzClass>()?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/rustapi_module/src/othermod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {

#[pymodule]
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_function(wrap_pyfunction!(double))?;

m.add_class::<ModClass>()?;

Expand Down
6 changes: 3 additions & 3 deletions examples/word-count/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize {

#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
m.add_function(wrap_pyfunction!(search))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

Ok(())
}
4 changes: 2 additions & 2 deletions guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn double(x: usize) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ fn subfunction() -> String {

#[pymodule]
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

#[pymodule]
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down
24 changes: 20 additions & 4 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
let item: syn::ItemFn = syn::parse_quote! {
fn block_wrapper() {
#function_to_python
#module_name.add_wrapped(&#function_wrapper_ident)?;
#module_name.add_function(&#function_wrapper_ident)?;
}
};
stmts.extend(item.block.stmts.into_iter());
Expand Down Expand Up @@ -193,7 +193,10 @@ pub fn add_fn_to_module(
let wrapper = function_c_wrapper(&func.sig.ident, &spec);

Ok(quote! {
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
fn #function_wrapper_ident<'a, 'b>(
args: impl pyo3::derive_utils::WrapPyFunctionArguments<'a, 'b>
) -> pyo3::PyObject {
let (py, maybe_module) = args.arguments();
#wrapper

let _def = pyo3::class::PyMethodDef {
Expand All @@ -203,12 +206,25 @@ pub fn add_fn_to_module(
ml_doc: #doc,
};

let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = unsafe { pyo3::ffi::PyModule_GetNameObject(mod_ptr) };
if name.is_null() {
let err = PyErr::fetch(py);
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
}
(mod_ptr, name)
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

let function = unsafe {
pyo3::PyObject::from_owned_ptr(
py,
pyo3::ffi::PyCFunction_New(
pyo3::ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(_def.as_method_def())),
::std::ptr::null_mut()
mod_ptr,
name
)
)
};
Expand Down
21 changes: 21 additions & 0 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,24 @@ where
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
}
}

/// Trait to abstract over the arguments of Python function wrappers.
#[doc(hidden)]
pub trait WrapPyFunctionArguments<'a, 'b> {
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>);
}

impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for Python<'b> {
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
(self, None)
}
}

impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for &'a PyModule
where
'a: 'b,
{
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
(self.py(), Some(self))
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
//! #[pymodule]
//! /// A Python module implemented in Rust.
//! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
//! m.add_function(wrap_pyfunction!(sum_as_string))?;
//!
//! Ok(())
//! }
Expand Down
2 changes: 1 addition & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<'p> Python<'p> {
/// let gil = Python::acquire_gil();
/// let py = gil.python();
/// let m = PyModule::new(py, "pcount").unwrap();
/// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap();
/// m.add_function(wrap_pyfunction!(parallel_count)).unwrap();
/// let locals = [("pcount", m)].into_py_dict(py);
/// py.run(r#"
/// s = ["Flow", "my", "tears", "the", "Policeman", "Said"]
Expand Down
41 changes: 40 additions & 1 deletion src/types/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,50 @@ impl PyModule {
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py));
/// ```
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
///
/// **This function will be deprecated in the next release. Please use the specific
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a (sub)module to a module.
///
/// Use this together with `#[pymodule]` and [wrap_pymodule!].
///
/// ```rust,ignore
/// m.add_module(wrap_pymodule!(utils));
/// ```
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a function to a module, using the functions __name__ as name.
///
/// Use this together with the`#[pyfunction]` and [wrap_pyfunction!].
///
/// ```rust,ignore
/// m.add_function(wrap_pyfunction!(double));
/// ```
///
/// You can also add a function with a custom name using [add](PyModule::add):
///
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}
}
18 changes: 9 additions & 9 deletions tests/test_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn double(x: usize) -> usize {

/// This module is implemented in Rust.
#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

#[pyfn(m, "sum_as_string")]
Expand All @@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {

m.add("foo", "bar").unwrap();

m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(py)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -157,7 +157,7 @@ fn r#move() -> usize {
fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(r#move))
module.add_function(wrap_pyfunction!(r#move))
}

#[test]
Expand All @@ -182,7 +182,7 @@ fn custom_named_fn() -> usize {
fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

m.add_wrapped(wrap_pyfunction!(custom_named_fn))?;
m.add_function(wrap_pyfunction!(custom_named_fn))?;
m.dict().set_item("yay", "me")?;
Ok(())
}
Expand Down Expand Up @@ -216,7 +216,7 @@ fn subfunction() -> String {
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

Expand All @@ -229,8 +229,8 @@ fn superfunction() -> String {
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::{wrap_pyfunction, wrap_pymodule};

module.add_wrapped(wrap_pyfunction!(superfunction))?;
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_function(wrap_pyfunction!(superfunction))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down Expand Up @@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> {
ext_vararg_fn(py, a, vararg)
}

m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn))
m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn))
.unwrap();
Ok(())
}
Expand Down

0 comments on commit e296c2b

Please sign in to comment.