Skip to content

Commit

Permalink
WIP: Possible to pass PyModule as first arg.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 3214249 commit 6f96ae8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
59 changes: 44 additions & 15 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
//! Code generation for the function that initializes a python module and adds classes and function.
use crate::method;
use crate::pyfunction;
use crate::pyfunction::PyFunctionAttr;
use crate::pymethod;
use crate::pymethod::get_arg_names;
Expand Down Expand Up @@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>>
/// Extracts the data from the #[pyfn(...)] attribute of a function
fn extract_pyfn_attrs(
attrs: &mut Vec<syn::Attribute>,
) -> syn::Result<Option<(syn::Path, Ident, Vec<pyfunction::Argument>)>> {
) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
let mut new_attrs = Vec::new();
let mut fnname = None;
let mut modname = None;
let mut fn_attrs = Vec::new();
let mut fn_attrs = PyFunctionAttr::default();

for attr in attrs.iter() {
match attr.parse_meta() {
Expand Down Expand Up @@ -115,9 +114,7 @@ fn extract_pyfn_attrs(
}
// Read additional arguments
if list.nested.len() >= 3 {
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])
.unwrap()
.arguments;
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
}
} else {
return Err(syn::Error::new_spanned(
Expand Down Expand Up @@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident {
pub fn add_fn_to_module(
func: &mut syn::ItemFn,
python_name: Ident,
pyfn_attrs: Vec<pyfunction::Argument>,
pyfn_attrs: PyFunctionAttr,
) -> syn::Result<TokenStream> {
let mut arguments = Vec::new();

for input in func.sig.inputs.iter() {
for (i, input) in func.sig.inputs.iter().enumerate() {
match input {
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new_spanned(
Expand All @@ -161,7 +158,27 @@ pub fn add_fn_to_module(
))
}
syn::FnArg::Typed(ref cap) => {
arguments.push(wrap_fn_argument(cap)?);
if pyfn_attrs.need_module && i == 0 {
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
if typath
.path
.segments
.last()
.map(|seg| seg.ident == "PyModule")
.unwrap_or(false)
{
continue;
}
}
}
return Err(syn::Error::new_spanned(
cap,
"Expected &PyModule as first argument with `need_module`.",
));
} else {
arguments.push(wrap_fn_argument(cap)?);
}
}
}
}
Expand All @@ -177,7 +194,7 @@ pub fn add_fn_to_module(
tp: method::FnType::FnStatic,
name: &function_wrapper_ident,
python_name,
attrs: pyfn_attrs,
attrs: pyfn_attrs.arguments,
args: arguments,
output: ty,
doc,
Expand All @@ -187,7 +204,7 @@ pub fn add_fn_to_module(

let python_name = &spec.python_name;

let wrapper = function_c_wrapper(&func.sig.ident, &spec);
let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module);

Ok(quote! {
fn #function_wrapper_ident<'a>(
Expand Down Expand Up @@ -230,12 +247,23 @@ pub fn add_fn_to_module(
}

/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream {
let names: Vec<Ident> = get_arg_names(&spec);
let cb = quote! {
#name(#(#names),*)
let cb;
let slf_module;
if need_module {
cb = quote! {
#name(_slf, #(#names),*)
};
slf_module = Some(quote! {
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
});
} else {
cb = quote! {
#name(#(#names),*)
};
slf_module = None;
};

let body = pymethod::impl_arg_params(spec, None, cb);

quote! {
Expand All @@ -246,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
{
const _LOCATION: &'static str = concat!(stringify!(#name), "()");
pyo3::callback_body!(_py, {
#slf_module
let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);

Expand Down
6 changes: 5 additions & 1 deletion pyo3-derive-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct PyFunctionAttr {
has_kw: bool,
has_varargs: bool,
has_kwargs: bool,
pub need_module: bool,
}

impl syn::parse::Parse for PyFunctionAttr {
Expand All @@ -45,6 +46,9 @@ impl PyFunctionAttr {

pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
match item {
NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => {
self.need_module = true;
}
NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
self.add_name_value(item, nv)?;
Expand Down Expand Up @@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Opti
pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
let python_name =
parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
add_fn_to_module(ast, python_name, args.arguments)
add_fn_to_module(ast, python_name, args)
}

#[cfg(test)]
Expand Down

0 comments on commit 6f96ae8

Please sign in to comment.