diff --git a/CHANGELOG.md b/CHANGELOG.md index e3886d0e1da..962ba01e0de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/) - Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106). - Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065) +- Add macro attribute to `#[pyfn]` and `#[pyfunction]` to pass the module of a Python function to the function + body. [#1143](https://github.com/PyO3/pyo3/pull/1143) +- Add `add_function()` and `add_submodule()` functions to `PyModule` [#1143](https://github.com/PyO3/pyo3/pull/1143) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) @@ -30,6 +33,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Implement `Send + Sync` for `PyErr`. `PyErr::new`, `PyErr::from_type`, `PyException::py_err` and `PyException::into` have had these bounds added to their arguments. [#1067](https://github.com/PyO3/pyo3/pull/1067) - Change `#[pyproto]` to return NotImplemented for operators for which Python can try a reversed operation. #[1072](https://github.com/PyO3/pyo3/pull/1072) - `PyModule::add` now uses `IntoPy` instead of `ToPyObject`. #[1124](https://github.com/PyO3/pyo3/pull/1124) +- Add nested modules as `&PyModule` instead of using the wrapper generated by `#[pymodule]`. [#1143](https://github.com/PyO3/pyo3/pull/1143) ### Removed - Remove `PyString::as_bytes`. [#1023](https://github.com/PyO3/pyo3/pull/1023) @@ -50,6 +54,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Link against libpython on android with `extension-module` set. [#1095](https://github.com/PyO3/pyo3/pull/1095) - Fix support for both `__add__` and `__radd__` in the `+` operator when both are defined in `PyNumberProtocol` (and similar for all other reversible operators). [#1107](https://github.com/PyO3/pyo3/pull/1107) +- Associate Python functions with their module by passing the Module and Module name [#1143](https://github.com/PyO3/pyo3/pull/1143) ## [0.11.1] - 2020-06-30 ### Added diff --git a/README.md b/README.md index bf06df9e06e..ff041bf60b7 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// A Python module implemented in Rust. #[pymodule] fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(sum_as_string))?; + m.add_function(wrap_pyfunction!(sum_as_string))?; Ok(()) } diff --git a/examples/rustapi_module/src/datetime.rs b/examples/rustapi_module/src/datetime.rs index 3181ae79d97..3ccb7c697f6 100644 --- a/examples/rustapi_module/src/datetime.rs +++ b/examples/rustapi_module/src/datetime.rs @@ -215,29 +215,29 @@ impl TzClass { #[pymodule] fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(make_date))?; - m.add_wrapped(wrap_pyfunction!(get_date_tuple))?; - m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?; - m.add_wrapped(wrap_pyfunction!(make_time))?; - m.add_wrapped(wrap_pyfunction!(get_time_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_delta))?; - m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_datetime))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?; - m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_date))?; + m.add_function(wrap_pyfunction!(get_date_tuple))?; + m.add_function(wrap_pyfunction!(date_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_time))?; + m.add_function(wrap_pyfunction!(get_time_tuple))?; + m.add_function(wrap_pyfunction!(make_delta))?; + m.add_function(wrap_pyfunction!(get_delta_tuple))?; + m.add_function(wrap_pyfunction!(make_datetime))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple))?; + m.add_function(wrap_pyfunction!(datetime_from_timestamp))?; // Python 3.6+ functions #[cfg(Py_3_6)] { - m.add_wrapped(wrap_pyfunction!(time_with_fold))?; + m.add_function(wrap_pyfunction!(time_with_fold))?; #[cfg(not(PyPy))] { - m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_time_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?; } } - m.add_wrapped(wrap_pyfunction!(issue_219))?; + m.add_function(wrap_pyfunction!(issue_219))?; m.add_class::()?; Ok(()) diff --git a/examples/rustapi_module/src/othermod.rs b/examples/rustapi_module/src/othermod.rs index 20745b29fb6..b9955806186 100644 --- a/examples/rustapi_module/src/othermod.rs +++ b/examples/rustapi_module/src/othermod.rs @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 { #[pymodule] fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double))?; + m.add_function(wrap_pyfunction!(double))?; m.add_class::()?; diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index 06d696e895f..50a0078026b 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -56,8 +56,8 @@ fn count_line(line: &str, needle: &str) -> usize { #[pymodule] fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(search))?; - m.add_wrapped(wrap_pyfunction!(search_sequential))?; - m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?; + m.add_function(wrap_pyfunction!(search_sequential))?; + m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?; Ok(()) } diff --git a/guide/src/function.md b/guide/src/function.md index 1a12d8ec6f1..b33221c9d41 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -36,7 +36,7 @@ fn double(x: usize) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); Ok(()) } @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap(); + m.add_function(wrap_pyfunction!(num_kwds)).unwrap(); Ok(()) } @@ -189,3 +189,47 @@ If you have a static function, you can expose it with `#[pyfunction]` and use [` [`PyAny::call1`]: https://docs.rs/pyo3/latest/pyo3/struct.PyAny.html#tymethod.call1 [`PyObject`]: https://docs.rs/pyo3/latest/pyo3/type.PyObject.html [`wrap_pyfunction!`]: https://docs.rs/pyo3/latest/pyo3/macro.wrap_pyfunction.html + +### Accessing the module of a function + +It is possible to access the module of a `#[pyfunction]` and `#[pyfn]` in the +function body by passing the `pass_module` argument to the attribute: + +```rust +use pyo3::wrap_pyfunction; +use pyo3::prelude::*; + +#[pyfunction(pass_module)] +fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { + module.name() +} + +#[pymodule] +fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(pyfunction_with_module)) +} + +# fn main() {} +``` + +If `pass_module` is set, the first argument **must** be the `&PyModule`. It is then possible to use the module +in the function body. + +The same works for `#[pyfn]`: + +```rust +use pyo3::wrap_pyfunction; +use pyo3::prelude::*; + +#[pymodule] +fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { + + #[pyfn(m, "module_name", pass_module)] + fn module_name(module: &PyModule) -> PyResult<&str> { + module.name() + } + Ok(()) +} + +# fn main() {} +``` diff --git a/guide/src/logging.md b/guide/src/logging.md index a9078b5d01a..500a8804709 100644 --- a/guide/src/logging.md +++ b/guide/src/logging.md @@ -35,7 +35,7 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // A good place to install the Rust -> Python logger. pyo3_log::init(); - m.add_wrapped(wrap_pyfunction!(log_something))?; + m.add_function(wrap_pyfunction!(log_something))?; Ok(()) } ``` diff --git a/guide/src/module.md b/guide/src/module.md index 4dea21b1b9b..6b1d4581eec 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -32,16 +32,22 @@ fn sum_as_string(a: i64, b: i64) -> String { # fn main() {} ``` -The `#[pymodule]` procedural macro attribute takes care of exporting the initialization function of your module to Python. It can take as an argument the name of your module, which must be the name of the `.so` or `.pyd` file; the default is the Rust function's name. +The `#[pymodule]` procedural macro attribute takes care of exporting the initialization function of your +module to Python. It can take as an argument the name of your module, which must be the name of the `.so` +or `.pyd` file; the default is the Rust function's name. -If the name of the module (the default being the function name) does not match the name of the `.so` or `.pyd` file, you will get an import error in Python with the following message: +If the name of the module (the default being the function name) does not match the name of the `.so` or +`.pyd` file, you will get an import error in Python with the following message: `ImportError: dynamic module does not define module export function (PyInit_name_of_your_module)` -To import the module, either copy the shared library as described in [the README](https://github.com/PyO3/pyo3) or use a tool, e.g. `maturin develop` with [maturin](https://github.com/PyO3/maturin) or `python setup.py develop` with [setuptools-rust](https://github.com/PyO3/setuptools-rust). +To import the module, either copy the shared library as described in [the README](https://github.com/PyO3/pyo3) +or use a tool, e.g. `maturin develop` with [maturin](https://github.com/PyO3/maturin) or +`python setup.py develop` with [setuptools-rust](https://github.com/PyO3/setuptools-rust). ## Documentation -The [Rust doc comments](https://doc.rust-lang.org/stable/book/first-edition/comments.html) of the module initialization function will be applied automatically as the Python docstring of your module. +The [Rust doc comments](https://doc.rust-lang.org/stable/book/first-edition/comments.html) of the module +initialization function will be applied automatically as the Python docstring of your module. ```python import rust2py @@ -53,7 +59,8 @@ Which means that the above Python code will print `This module is implemented in ## Modules as objects -In Python, modules are first class objects. This means that you can store them as values or add them to dicts or other modules: +In Python, modules are first class objects. This means that you can store them as values or add them to +dicts or other modules: ```rust use pyo3::prelude::*; @@ -65,15 +72,16 @@ fn subfunction() -> String { "Subfunction".to_string() } -#[pymodule] -fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pyfunction!(subfunction))?; +fn init_submodule(module: &PyModule) -> PyResult<()> { + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } #[pymodule] -fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pymodule!(submodule))?; +fn supermodule(py: Python, module: &PyModule) -> PyResult<()> { + let submod = PyModule::new(py, "submodule")?; + init_submodule(submod)?; + module.add_submodule(submod)?; Ok(()) } @@ -86,3 +94,5 @@ fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { ``` This way, you can create a module hierarchy within a single extension module. + +It is not necessary to add `#[pymodule]` on nested modules, this is only required on the top-level module. \ No newline at end of file diff --git a/guide/src/trait_bounds.md b/guide/src/trait_bounds.md index c11d585ced9..65e173cd40d 100644 --- a/guide/src/trait_bounds.md +++ b/guide/src/trait_bounds.md @@ -488,7 +488,7 @@ pub struct UserModel { #[pymodule] fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_wrapped(wrap_pyfunction!(solve_wrapper))?; + m.add_function(wrap_pyfunction!(solve_wrapper))?; Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index bd6e4182793..a706100ecd2 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -2,7 +2,6 @@ //! Code generation for the function that initializes a python module and adds classes and function. use crate::method; -use crate::pyfunction; use crate::pyfunction::PyFunctionAttr; use crate::pymethod; use crate::pymethod::get_arg_names; @@ -45,7 +44,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { #function_to_python - #module_name.add_wrapped(&#function_wrapper_ident)?; + #module_name.add_function(&#function_wrapper_ident)?; } }; stmts.extend(item.block.stmts.into_iter()); @@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result> /// Extracts the data from the #[pyfn(...)] attribute of a function fn extract_pyfn_attrs( attrs: &mut Vec, -) -> syn::Result)>> { +) -> syn::Result> { let mut new_attrs = Vec::new(); let mut fnname = None; let mut modname = None; - let mut fn_attrs = Vec::new(); + let mut fn_attrs = PyFunctionAttr::default(); for attr in attrs.iter() { match attr.parse_meta() { @@ -115,9 +114,7 @@ fn extract_pyfn_attrs( } // Read additional arguments if list.nested.len() >= 3 { - fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()]) - .unwrap() - .arguments; + fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?; } } else { return Err(syn::Error::new_spanned( @@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident { pub fn add_fn_to_module( func: &mut syn::ItemFn, python_name: Ident, - pyfn_attrs: Vec, + pyfn_attrs: PyFunctionAttr, ) -> syn::Result { let mut arguments = Vec::new(); - for input in func.sig.inputs.iter() { + for (i, input) in func.sig.inputs.iter().enumerate() { match input { syn::FnArg::Receiver(_) => { return Err(syn::Error::new_spanned( @@ -161,7 +158,27 @@ pub fn add_fn_to_module( )) } syn::FnArg::Typed(ref cap) => { - arguments.push(wrap_fn_argument(cap)?); + if pyfn_attrs.pass_module && i == 0 { + if let syn::Type::Reference(tyref) = cap.ty.as_ref() { + if let syn::Type::Path(typath) = tyref.elem.as_ref() { + if typath + .path + .segments + .last() + .map(|seg| seg.ident == "PyModule") + .unwrap_or(false) + { + continue; + } + } + } + return Err(syn::Error::new_spanned( + cap, + "Expected &PyModule as first argument with `pass_module`.", + )); + } else { + arguments.push(wrap_fn_argument(cap)?); + } } } } @@ -177,7 +194,7 @@ pub fn add_fn_to_module( tp: method::FnType::FnStatic, name: &function_wrapper_ident, python_name, - attrs: pyfn_attrs, + attrs: pyfn_attrs.arguments, args: arguments, output: ty, doc, @@ -187,10 +204,14 @@ pub fn add_fn_to_module( let python_name = &spec.python_name; - let wrapper = function_c_wrapper(&func.sig.ident, &spec); + let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.pass_module); Ok(quote! { - fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject { + fn #function_wrapper_ident<'a>( + args: impl Into> + ) -> pyo3::PyResult { + let arg = args.into(); + let (py, maybe_module) = arg.into_py_and_maybe_module(); #wrapper let _def = pyo3::class::PyMethodDef { @@ -200,28 +221,49 @@ pub fn add_fn_to_module( ml_doc: #doc, }; + let (mod_ptr, name) = if let Some(m) = maybe_module { + let mod_ptr = ::as_ptr(m); + let name = m.name()?; + let name = <&str as pyo3::conversion::IntoPy>::into_py(name, py); + (mod_ptr, ::as_ptr(&name)) + } else { + (std::ptr::null_mut(), std::ptr::null_mut()) + }; + let function = unsafe { pyo3::PyObject::from_owned_ptr( py, - pyo3::ffi::PyCFunction_New( + pyo3::ffi::PyCFunction_NewEx( Box::into_raw(Box::new(_def.as_method_def())), - ::std::ptr::null_mut() + mod_ptr, + name ) ) }; - function + Ok(function) } }) } /// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords) -fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream { +fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool) -> TokenStream { let names: Vec = get_arg_names(&spec); - let cb = quote! { - #name(#(#names),*) + let cb; + let slf_module; + if pass_module { + cb = quote! { + #name(_slf, #(#names),*) + }; + slf_module = Some(quote! { + let _slf = _py.from_borrowed_ptr::(_slf); + }); + } else { + cb = quote! { + #name(#(#names),*) + }; + slf_module = None; }; - let body = pymethod::impl_arg_params(spec, None, cb); quote! { @@ -232,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream { { const _LOCATION: &'static str = concat!(stringify!(#name), "()"); pyo3::callback_body!(_py, { + #slf_module let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index 96ef584a9ca..80ac1cf35f0 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -24,6 +24,7 @@ pub struct PyFunctionAttr { has_kw: bool, has_varargs: bool, has_kwargs: bool, + pub pass_module: bool, } impl syn::parse::Parse for PyFunctionAttr { @@ -45,6 +46,9 @@ impl PyFunctionAttr { pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> { match item { + NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("pass_module") => { + self.pass_module = true; + } NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?, NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => { self.add_name_value(item, nv)?; @@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec) -> syn::Result syn::Result { let python_name = parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw()); - add_fn_to_module(ast, python_name, args.arguments) + add_fn_to_module(ast, python_name, args) } #[cfg(test)] diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 2a736d7ebcd..cef90fd1248 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -207,3 +207,34 @@ where >>::try_from(cell) } } + +/// Enum to abstract over the arguments of Python function wrappers. +#[doc(hidden)] +pub enum WrapPyFunctionArguments<'a> { + Python(Python<'a>), + PyModule(&'a PyModule), +} + +impl<'a> WrapPyFunctionArguments<'a> { + pub fn into_py_and_maybe_module(self) -> (Python<'a>, Option<&'a PyModule>) { + match self { + WrapPyFunctionArguments::Python(py) => (py, None), + WrapPyFunctionArguments::PyModule(module) => { + let py = module.py(); + (py, Some(module)) + } + } + } +} + +impl<'a> From> for WrapPyFunctionArguments<'a> { + fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> { + WrapPyFunctionArguments::Python(py) + } +} + +impl<'a> From<&'a PyModule> for WrapPyFunctionArguments<'a> { + fn from(module: &'a PyModule) -> WrapPyFunctionArguments<'a> { + WrapPyFunctionArguments::PyModule(module) + } +} diff --git a/src/ffi/methodobject.rs b/src/ffi/methodobject.rs index 7c3872c5dc8..921cca847c1 100644 --- a/src/ffi/methodobject.rs +++ b/src/ffi/methodobject.rs @@ -1,6 +1,6 @@ use crate::ffi::object::{PyObject, PyTypeObject, Py_TYPE}; +use std::mem; use std::os::raw::{c_char, c_int}; -use std::{mem, ptr}; #[cfg_attr(windows, link(name = "pythonXY"))] extern "C" { @@ -96,19 +96,16 @@ impl Default for PyMethodDef { } } -#[inline] -pub unsafe fn PyCFunction_New(ml: *mut PyMethodDef, slf: *mut PyObject) -> *mut PyObject { - #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] - PyCFunction_NewEx(ml, slf, ptr::null_mut()) -} - extern "C" { #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] pub fn PyCFunction_NewEx( - arg1: *mut PyMethodDef, - arg2: *mut PyObject, - arg3: *mut PyObject, + ml: *mut PyMethodDef, + slf: *mut PyObject, + module: *mut PyObject, ) -> *mut PyObject; + + #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] + pub fn PyCFunction_New(ml: *mut PyMethodDef, slf: *mut PyObject) -> *mut PyObject; } /* Flag passed to newmethodobject */ diff --git a/src/lib.rs b/src/lib.rs index 10f3e768f8e..4c2313e3a47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ //! #[pymodule] //! /// A Python module implemented in Rust. //! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { -//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?; +//! m.add_function(wrap_pyfunction!(sum_as_string))?; //! //! Ok(()) //! } diff --git a/src/python.rs b/src/python.rs index 901a426bcf0..db4abfe6b1a 100644 --- a/src/python.rs +++ b/src/python.rs @@ -134,7 +134,7 @@ impl<'p> Python<'p> { /// let gil = Python::acquire_gil(); /// let py = gil.python(); /// let m = PyModule::new(py, "pcount").unwrap(); - /// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap(); + /// m.add_function(wrap_pyfunction!(parallel_count)).unwrap(); /// let locals = [("pcount", m)].into_py_dict(py); /// py.run(r#" /// s = ["Flow", "my", "tears", "the", "Policeman", "Said"] diff --git a/src/types/module.rs b/src/types/module.rs index b345fcab563..0e480010822 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -2,6 +2,7 @@ // // based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython +use crate::callback::IntoPyCallbackOutput; use crate::err::{PyErr, PyResult}; use crate::exceptions; use crate::ffi; @@ -184,21 +185,102 @@ impl PyModule { /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!] or `#[pymodule]` and /// [wrap_pymodule!]. /// - /// ```rust,ignore - /// m.add_wrapped(wrap_pyfunction!(double)); - /// m.add_wrapped(wrap_pymodule!(utils)); + /// ```rust + /// use pyo3::prelude::*; + /// #[pymodule] + /// fn utils(_py: Python, _module: &PyModule) -> PyResult<()> { + /// Ok(()) + /// } + /// + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn top_level(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add_wrapped(pyo3::wrap_pymodule!(utils))?; + /// module.add_wrapped(pyo3::wrap_pyfunction!(double)) + /// } /// ``` /// /// You can also add a function with a custom name using [add](PyModule::add): /// /// ```rust,ignore - /// m.add("also_double", wrap_pyfunction!(double)(py)); + /// m.add("also_double", wrap_pyfunction!(double)(m)?)?; + /// ``` + /// + /// **This function will be deprecated in the next release. Please use the specific + /// [add_function] and [add_submodule] functions instead.** + pub fn add_wrapped<'a, T>(&'a self, wrapper: &impl Fn(Python<'a>) -> T) -> PyResult<()> + where + T: IntoPyCallbackOutput, + { + let py = self.py(); + let function = wrapper(py).convert(py)?; + let name = function.getattr(py, "__name__")?; + let name = name.extract(py)?; + self.add(name, function) + } + + /// Add a submodule to a module. + /// + /// Use this together with `#[pymodule]` and [wrap_pymodule!]. + /// + /// ```rust + /// use pyo3::prelude::*; + /// + /// fn init_utils(module: &PyModule) -> PyResult<()> { + /// module.add("super_useful_constant", "important") + /// } + /// #[pymodule] + /// fn top_level(py: Python, module: &PyModule) -> PyResult<()> { + /// let utils = PyModule::new(py, "utils")?; + /// init_utils(utils)?; + /// module.add_submodule(utils) + /// } + /// ``` + pub fn add_submodule(&self, module: &PyModule) -> PyResult<()> { + let name = module.name()?; + self.add(name, module) + } + + /// Add a function to a module. + /// + /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!]. + /// + /// ```rust + /// use pyo3::prelude::*; + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add_function(pyo3::wrap_pyfunction!(double)) + /// } + /// ``` + /// + /// You can also add a function with a custom name using [add](PyModule::add): + /// + /// ```rust + /// use pyo3::prelude::*; + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add("also_double", pyo3::wrap_pyfunction!(double)(module)?) + /// } /// ``` - pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> { - let function = wrapper(self.py()); - let name = function - .getattr(self.py(), "__name__") - .expect("A function or module must have a __name__"); - self.add(name.extract(self.py()).unwrap(), function) + pub fn add_function<'a>( + &'a self, + wrapper: &impl Fn(&'a Self) -> PyResult, + ) -> PyResult<()> { + let py = self.py(); + let function = wrapper(self)?; + let name = function.getattr(py, "__name__")?; + let name = name.extract(py)?; + self.add(name, function) } } diff --git a/tests/test_bytes.rs b/tests/test_bytes.rs index a48c05fc014..e458e35b553 100644 --- a/tests/test_bytes.rs +++ b/tests/test_bytes.rs @@ -14,7 +14,7 @@ fn test_pybytes_bytes_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_pybytes_conversion)(py); + let f = wrap_pyfunction!(bytes_pybytes_conversion)(py).unwrap(); py_assert!(py, f, "f(b'Hello World') == b'Hello World'"); } @@ -28,7 +28,7 @@ fn test_pybytes_vec_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_vec_conversion)(py); + let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap(); py_assert!(py, f, "f(b'Hello World') == b'Hello World'"); } @@ -37,6 +37,6 @@ fn test_bytearray_vec_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_vec_conversion)(py); + let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap(); py_assert!(py, f, "f(bytearray(b'Hello World')) == b'Hello World'"); } diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 653a80d56b5..5d02b8e8991 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -4,6 +4,7 @@ fn test_compile_errors() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/invalid_frompy_derive.rs"); t.compile_fail("tests/ui/invalid_macro_args.rs"); + t.compile_fail("tests/ui/invalid_need_module_arg_position.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); diff --git a/tests/test_exceptions.rs b/tests/test_exceptions.rs index 3726dfb7e39..d232f29d1c6 100644 --- a/tests/test_exceptions.rs +++ b/tests/test_exceptions.rs @@ -19,7 +19,7 @@ fn fail_to_open_file() -> PyResult<()> { fn test_filenotfounderror() { let gil = Python::acquire_gil(); let py = gil.python(); - let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py); + let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py).unwrap(); py_run!( py, @@ -64,7 +64,7 @@ fn call_fail_with_custom_error() -> PyResult<()> { fn test_custom_error() { let gil = Python::acquire_gil(); let py = gil.python(); - let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py); + let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py).unwrap(); py_run!( py, diff --git a/tests/test_module.rs b/tests/test_module.rs index 0746fb8f868..7c278bdcd5e 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -1,6 +1,6 @@ use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyTuple}; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; mod common; @@ -35,7 +35,7 @@ fn double(x: usize) -> usize { /// This module is implemented in Rust. #[pymodule] -fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { +fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; #[pyfn(m, "sum_as_string")] @@ -49,6 +49,11 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { Ok(42) } + #[pyfn(m, "with_module", pass_module)] + fn with_module(module: &PyModule) -> PyResult<&str> { + module.name() + } + #[pyfn(m, "double_value")] fn double_value(v: &ValueClass) -> usize { v.value * 2 @@ -60,8 +65,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar").unwrap(); - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); - m.add("also_double", wrap_pyfunction!(double)(py)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); + m.add("also_double", wrap_pyfunction!(double)(m)?).unwrap(); Ok(()) } @@ -97,6 +102,7 @@ fn test_module_with_functions() { run("assert module_with_functions.also_double(3) == 6"); run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'"); run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2"); + run("assert module_with_functions.with_module() == 'module_with_functions'"); } #[pymodule(other_name)] @@ -157,7 +163,7 @@ fn r#move() -> usize { fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(r#move)) + module.add_function(wrap_pyfunction!(r#move)) } #[test] @@ -182,7 +188,7 @@ fn custom_named_fn() -> usize { fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - m.add_wrapped(wrap_pyfunction!(custom_named_fn))?; + m.add_function(wrap_pyfunction!(custom_named_fn))?; m.dict().set_item("yay", "me")?; Ok(()) } @@ -212,11 +218,18 @@ fn subfunction() -> String { "Subfunction".to_string() } +fn submodule(module: &PyModule) -> PyResult<()> { + use pyo3::wrap_pyfunction; + + module.add_function(wrap_pyfunction!(subfunction))?; + Ok(()) +} + #[pymodule] -fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { +fn submodule_with_init_fn(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } @@ -226,11 +239,16 @@ fn superfunction() -> String { } #[pymodule] -fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - use pyo3::{wrap_pyfunction, wrap_pymodule}; +fn supermodule(py: Python, module: &PyModule) -> PyResult<()> { + use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(superfunction))?; - module.add_wrapped(wrap_pymodule!(submodule))?; + module.add_function(wrap_pyfunction!(superfunction))?; + let module_to_add = PyModule::new(py, "submodule")?; + submodule(module_to_add)?; + module.add_submodule(module_to_add)?; + let module_to_add = PyModule::new(py, "submodule_with_init_fn")?; + submodule_with_init_fn(py, module_to_add)?; + module.add_submodule(module_to_add)?; Ok(()) } @@ -252,6 +270,11 @@ fn test_module_nesting() { supermodule, "supermodule.submodule.subfunction() == 'Subfunction'" ); + py_assert!( + py, + supermodule, + "supermodule.submodule_with_init_fn.subfunction() == 'Subfunction'" + ); } // Test that argument parsing specification works for pyfunctions @@ -268,7 +291,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> { ext_vararg_fn(py, a, vararg) } - m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn)) + m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn)) .unwrap(); Ok(()) } @@ -305,3 +328,82 @@ fn test_module_with_constant() { py_assert!(py, m, "isinstance(m.ANON, m.AnonClass)"); }); } + +#[pyfunction(pass_module)] +fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { + module.name() +} + +#[pyfunction(pass_module)] +fn pyfunction_with_module_and_py<'a>( + module: &'a PyModule, + _python: Python<'a>, +) -> PyResult<&'a str> { + module.name() +} + +#[pyfunction(pass_module)] +fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> { + module.name().map(|s| (s, string)) +} + +#[pyfunction(pass_module, string = "\"foo\"")] +fn pyfunction_with_module_and_default_arg<'a>( + module: &'a PyModule, + string: &str, +) -> PyResult<(&'a str, String)> { + module.name().map(|s| (s, string.into())) +} + +#[pyfunction(pass_module, args = "*", kwargs = "**")] +fn pyfunction_with_module_and_args_kwargs<'a>( + module: &'a PyModule, + args: &PyTuple, + kwargs: Option<&PyDict>, +) -> PyResult<(&'a str, usize, Option)> { + module + .name() + .map(|s| (s, args.len(), kwargs.map(|d| d.len()))) +} + +#[pymodule] +fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module))?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py))?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg))?; + m.add_function(pyo3::wrap_pyfunction!( + pyfunction_with_module_and_default_arg + ))?; + m.add_function(pyo3::wrap_pyfunction!( + pyfunction_with_module_and_args_kwargs + )) +} + +#[test] +fn test_module_functions_with_module() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let m = pyo3::wrap_pymodule!(module_with_functions_with_module)(py); + py_assert!( + py, + m, + "m.pyfunction_with_module() == 'module_with_functions_with_module'" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_py() == 'module_with_functions_with_module'" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_default_arg() \ + == ('module_with_functions_with_module', 'foo')" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \ + == ('module_with_functions_with_module', 1, 2)" + ); +} diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index e8e95bdf332..0d8500a36fe 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -14,7 +14,7 @@ fn test_optional_bool() { // Regression test for issue #932 let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(optional_bool)(py); + let f = wrap_pyfunction!(optional_bool)(py).unwrap(); py_assert!(py, f, "f() == 'Some(true)'"); py_assert!(py, f, "f(True) == 'Some(true)'"); @@ -36,7 +36,7 @@ fn buffer_inplace_add(py: Python, x: PyBuffer, y: PyBuffer) { fn test_buffer_add() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(buffer_inplace_add)(py); + let f = wrap_pyfunction!(buffer_inplace_add)(py).unwrap(); py_expect_exception!( py, diff --git a/tests/test_string.rs b/tests/test_string.rs index 6236484a578..38d375b5418 100644 --- a/tests/test_string.rs +++ b/tests/test_string.rs @@ -14,7 +14,7 @@ fn test_unicode_encode_error() { let gil = Python::acquire_gil(); let py = gil.python(); - let take_str = wrap_pyfunction!(take_str)(py); + let take_str = wrap_pyfunction!(take_str)(py).unwrap(); py_run!( py, take_str, diff --git a/tests/test_text_signature.rs b/tests/test_text_signature.rs index 85211a34f10..e81260811e5 100644 --- a/tests/test_text_signature.rs +++ b/tests/test_text_signature.rs @@ -104,7 +104,7 @@ fn test_function() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(my_function)(py); + let f = wrap_pyfunction!(my_function)(py).unwrap(); py_assert!(py, f, "f.__text_signature__ == '(a, b=None, *, c=42)'"); } diff --git a/tests/test_various.rs b/tests/test_various.rs index 87270c39186..b2de718b22e 100644 --- a/tests/test_various.rs +++ b/tests/test_various.rs @@ -59,7 +59,7 @@ fn return_custom_class() { assert_eq!(get_zero().unwrap().value, 0); // Using from python - let get_zero = wrap_pyfunction!(get_zero)(py); + let get_zero = wrap_pyfunction!(get_zero)(py).unwrap(); py_assert!(py, get_zero, "get_zero().value == 0"); } @@ -206,5 +206,5 @@ fn result_conversion_function() -> Result<(), MyError> { fn test_result_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - wrap_pyfunction!(result_conversion_function)(py); + wrap_pyfunction!(result_conversion_function)(py).unwrap(); } diff --git a/tests/ui/invalid_need_module_arg_position.rs b/tests/ui/invalid_need_module_arg_position.rs new file mode 100644 index 00000000000..607b21273f6 --- /dev/null +++ b/tests/ui/invalid_need_module_arg_position.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; + +#[pymodule] +fn module(_py: Python, m: &PyModule) -> PyResult<()> { + #[pyfn(m, "with_module", pass_module)] + fn fail(string: &str, module: &PyModule) -> PyResult<&str> { + module.name() + } + Ok(()) +} + +fn main(){} \ No newline at end of file diff --git a/tests/ui/invalid_need_module_arg_position.stderr b/tests/ui/invalid_need_module_arg_position.stderr new file mode 100644 index 00000000000..0fd00964c53 --- /dev/null +++ b/tests/ui/invalid_need_module_arg_position.stderr @@ -0,0 +1,5 @@ +error: Expected &PyModule as first argument with `pass_module`. + --> $DIR/invalid_need_module_arg_position.rs:6:13 + | +6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> { + | ^^^^^^^^^^^^