diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index 61cdbb630c0..e07ff91c073 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -21,7 +21,9 @@ mod pyimpl; mod pymethod; pub use frompyobject::build_derive_from_pyobject; -pub use module::{process_functions_in_module, pymodule_impl, PyModuleOptions}; +pub use module::{ + process_functions_in_module, pymodule_function_impl, pymodule_module_impl, PyModuleOptions, +}; pub use pyclass::{build_py_class, build_py_enum, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionOptions}; pub use pyimpl::{build_py_methods, PyClassMethodsType}; diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 93f22bbe6cd..b34c10bd92e 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -2,6 +2,7 @@ use crate::{ attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute}, + get_doc, pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, utils::{get_pyo3_crate, PythonDoc}, }; @@ -56,9 +57,132 @@ impl PyModuleOptions { } } +pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { + let syn::ItemMod { + attrs, + vis, + ident, + mod_token, + content, + unsafety: _, + semi: _, + } = &mut module; + let items = match content { + Some((_, items)) => items, + None => bail_spanned!(module.span() => "`#[pymodule]` can only be used on inline modules"), + }; + let options = PyModuleOptions::from_attrs(attrs)?; + let doc = get_doc(attrs, None); + + let name = options.name.unwrap_or_else(|| ident.unraw()); + let krate = get_pyo3_crate(&options.krate); + let pyinit_symbol = format!("PyInit_{}", name); + + let mut module_items = Vec::new(); + let mut module_attrs = Vec::new(); + + fn extract_use_items( + source: &syn::UseTree, + cfg_attrs: &Vec, + names: &mut Vec, + attrs: &mut Vec>, + ) -> Result<()> { + match source { + syn::UseTree::Name(name) => { + names.push(name.ident.clone()); + attrs.push(cfg_attrs.clone()); + } + syn::UseTree::Path(path) => extract_use_items(&path.tree, cfg_attrs, names, attrs)?, + syn::UseTree::Group(group) => { + for tree in &group.items { + extract_use_items(tree, cfg_attrs, names, attrs)? + } + } + syn::UseTree::Glob(glob) => { + bail_spanned!(glob.span() => "#[pyo3] cannot import glob statements") + } + syn::UseTree::Rename(rename) => { + names.push(rename.ident.clone()); + attrs.push(cfg_attrs.clone()); + } + } + Ok(()) + } + + let mut pymodule_init = None; + + for item in items.iter_mut() { + match item { + syn::Item::Use(item_use) => { + let mut is_pyo3 = false; + item_use.attrs.retain(|attr| { + let found = attr.path().is_ident("pyo3"); + is_pyo3 |= found; + !found + }); + if is_pyo3 { + let cfg_attrs: Vec<_> = item_use + .attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .map(Clone::clone) + .collect(); + extract_use_items( + &item_use.tree, + &cfg_attrs, + &mut module_items, + &mut module_attrs, + )?; + } + } + syn::Item::Fn(item_fn) => { + let mut is_module_init = false; + item_fn.attrs.retain(|attr| { + let found = attr.path().is_ident("pymodule_init"); + is_module_init |= found; + !found + }); + if is_module_init { + ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified"); + let ident = &item_fn.sig.ident; + pymodule_init = Some(quote! { #ident(module)?; }); + } + } + _ => {} + } + } + + Ok(quote! { + #vis #mod_token #ident { + #(#items)* + + pub static DEF: #krate::impl_::pymodule::ModuleDef = unsafe { + use #krate::impl_::pymodule as impl_; + impl_::ModuleDef::new(concat!(stringify!(#name), "\0"), #doc, impl_::ModuleInitializer(__pyo3_pymodule)) + }; + + pub fn __pyo3_pymodule(_py: #krate::Python, module: &#krate::types::PyModule) -> #krate::PyResult<()> { + #( + #(#module_attrs)* + #module_items::DEF.add_to_module(module)?; + )* + #pymodule_init + Ok(()) + } + + /// This autogenerated function is called by the python interpreter when importing + /// the module. + #[export_name = #pyinit_symbol] + pub unsafe extern "C" fn __pyo3_init() -> *mut #krate::ffi::PyObject { + #krate::impl_::trampoline::module_init(|py| DEF.make_module(py)) + } + } + }) +} + /// Generates the function that is called by the python interpreter to initialize the native /// module -pub fn pymodule_impl( +pub fn pymodule_function_impl( fnname: &Ident, options: PyModuleOptions, doc: PythonDoc, diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 37c7e6e9b99..43d8b374cc0 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -8,8 +8,8 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use pyo3_macros_backend::{ build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, - get_doc, process_functions_in_module, pymodule_impl, PyClassArgs, PyClassMethodsType, - PyFunctionOptions, PyModuleOptions, + get_doc, process_functions_in_module, pymodule_function_impl, pymodule_module_impl, + PyClassArgs, PyClassMethodsType, PyFunctionOptions, PyModuleOptions, }; use quote::quote; use syn::{parse::Nothing, parse_macro_input}; @@ -39,25 +39,30 @@ use syn::{parse::Nothing, parse_macro_input}; pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream { parse_macro_input!(args as Nothing); - let mut ast = parse_macro_input!(input as syn::ItemFn); - let options = match PyModuleOptions::from_attrs(&mut ast.attrs) { - Ok(options) => options, - Err(e) => return e.into_compile_error().into(), - }; - - if let Err(err) = process_functions_in_module(&options, &mut ast) { - return err.into_compile_error().into(); - } - - let doc = get_doc(&ast.attrs, None); + if let Ok(module) = syn::parse(input.clone()) { + pymodule_module_impl(module) + .unwrap_or_compile_error() + .into() + } else { + let mut ast = parse_macro_input!(input as syn::ItemFn); + let options = match PyModuleOptions::from_attrs(&mut ast.attrs) { + Ok(options) => options, + Err(e) => return e.into_compile_error().into(), + }; + + if let Err(err) = process_functions_in_module(&options, &mut ast) { + return err.into_compile_error().into(); + } - let expanded = pymodule_impl(&ast.sig.ident, options, doc, &ast.vis); + let doc = get_doc(&ast.attrs, None); - quote!( - #ast - #expanded - ) - .into() + let expanded = pymodule_function_impl(&ast.sig.ident, options, doc, &ast.vis); + quote!( + #ast + #expanded + ) + .into() + } } #[proc_macro_attribute] diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index 8724bcaa928..1030a7de7f4 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -1,6 +1,4 @@ use pyo3::prelude::*; -use pyo3::types::PyDict; -use pyo3::wrap_pymodule; pub mod buf_and_str; pub mod comparisons; @@ -16,39 +14,41 @@ pub mod sequence; pub mod subclassing; #[pymodule] -fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> { - #[cfg(not(Py_LIMITED_API))] - m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?; - m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?; - #[cfg(not(Py_LIMITED_API))] - m.add_wrapped(wrap_pymodule!(datetime::datetime))?; - m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?; - m.add_wrapped(wrap_pymodule!(misc::misc))?; - m.add_wrapped(wrap_pymodule!(objstore::objstore))?; - m.add_wrapped(wrap_pymodule!(othermod::othermod))?; - m.add_wrapped(wrap_pymodule!(path::path))?; - m.add_wrapped(wrap_pymodule!(pyclasses::pyclasses))?; - m.add_wrapped(wrap_pymodule!(pyfunctions::pyfunctions))?; - m.add_wrapped(wrap_pymodule!(sequence::sequence))?; - m.add_wrapped(wrap_pymodule!(subclassing::subclassing))?; +mod pyo3_pytests { + use pyo3::types::{PyDict, PyModule}; + use pyo3::PyResult; - // Inserting to sys.modules allows importing submodules nicely from Python - // e.g. import pyo3_pytests.buf_and_str as bas + #[pyo3] + use { + crate::comparisons::comparisons, crate::dict_iter::dict_iter, crate::misc::misc, + crate::objstore::objstore, crate::othermod::othermod, crate::path::path, + crate::pyclasses::pyclasses, crate::pyfunctions::pyfunctions, crate::sequence::sequence, + crate::subclassing::subclassing, + }; + + #[pyo3] + #[cfg(not(Py_LIMITED_API))] + use {crate::buf_and_str::buf_and_str, crate::datetime::datetime}; - let sys = PyModule::import(py, "sys")?; - let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?; - sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?; - sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?; - sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?; - sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?; - sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?; - sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?; - sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?; - sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?; - sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?; - sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?; - sys_modules.set_item("pyo3_pytests.sequence", m.getattr("sequence")?)?; - sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?; + #[pymodule_init] + fn init(m: &PyModule) -> PyResult<()> { + let sys = PyModule::import(m.py(), "sys")?; + let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?; + #[cfg(not(Py_LIMITED_API))] + sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?; + sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?; + #[cfg(not(Py_LIMITED_API))] + sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?; + sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?; + sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?; + sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?; + sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?; + sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?; + sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?; + sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?; + sys_modules.set_item("pyo3_pytests.sequence", m.getattr("sequence")?)?; + sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?; - Ok(()) + Ok(()) + } } diff --git a/src/impl_/pymodule.rs b/src/impl_/pymodule.rs index 2572d431e2c..317318e67f6 100644 --- a/src/impl_/pymodule.rs +++ b/src/impl_/pymodule.rs @@ -82,6 +82,10 @@ impl ModuleDef { (self.initializer.0)(py, module.as_ref(py))?; Ok(module) } + + pub fn add_to_module(&'static self, module: &PyModule) -> PyResult<()> { + module.add_object(self.make_module(module.py())?.into()) + } } #[cfg(test)] diff --git a/src/types/module.rs b/src/types/module.rs index 02881a0e5ae..9a10e2a71c3 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -1,5 +1,5 @@ use crate::callback::IntoPyCallbackOutput; -use crate::err::{PyErr, PyResult}; +use crate::err::{self, PyErr, PyResult}; use crate::exceptions; use crate::ffi; use crate::pyclass::PyClass; @@ -248,6 +248,16 @@ impl PyModule { self.setattr(name, value.into_py(self.py())) } + pub(crate) fn add_object(&self, value: PyObject) -> PyResult<()> { + let py = self.py(); + let attr_name = value.getattr(py, "__name__")?; + + unsafe { + let ret = ffi::PyObject_SetAttr(self.as_ptr(), attr_name.as_ptr(), value.as_ptr()); + err::error_on_minusone(py, ret) + } + } + /// Adds a new class to the module. /// /// Notice that this method does not take an argument. diff --git a/tests/test_module.rs b/tests/test_module.rs index f3977299e17..af79ae83ffe 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -450,3 +450,35 @@ fn test_module_doc_hidden() { py_assert!(py, m, "m.__doc__ == ''"); }) } + +/// A module written using declarative syntax. +#[pymodule] +mod declarative_module { + + #[pyo3] + use super::module_with_functions; +} + +#[test] +fn test_declarative_module() { + Python::with_gil(|py| { + let m = pyo3::wrap_pymodule!(declarative_module)(py).into_ref(py); + py_assert!( + py, + m, + "m.__doc__ == 'A module written using declarative syntax.'" + ); + + let submodule = m.getattr("module_with_functions").unwrap(); + assert_eq!( + submodule + .getattr("no_parameters") + .unwrap() + .call0() + .unwrap() + .extract::() + .unwrap(), + 42 + ); + }) +}