Skip to content

Commit

Permalink
Improve type mapping and use call_method0()/1() where appropriate
Browse files Browse the repository at this point in the history
Signed-off-by: Andrej Orsula <orsula.andrej@gmail.com>
  • Loading branch information
AndrejOrsula committed Jan 23, 2024
1 parent a096f39 commit a2a630c
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 120 deletions.
7 changes: 6 additions & 1 deletion pyo3_bindgen_engine/src/bindgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ pub use class::bind_class;
pub use function::bind_function;
pub use module::{bind_module, bind_reexport};

// TODO: Ensure there are no duplicate entries in the generated code
// TODO: Refactor everything into a large configurable struct that keeps track of all the
// important information needed to properly generate the bindings
// - Use builder pattern for the configuration of the struct
// - Keep track of all the types/classes that have been generated
// - Keep track of all imports to understand where each type is coming from
// - Keep track of all the external types that are used as parameters/return types and consider generating bindings for them as well

// TODO: Ensure there are no duplicate entries in the generated code

/// Generate Rust bindings to a Python module specified by its name. Generating bindings to
/// submodules such as `os.path` is also supported as long as the module can be directly imported
Expand Down
2 changes: 1 addition & 1 deletion pyo3_bindgen_engine/src/bindgen/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::types::Type;

/// Generate Rust bindings to a Python attribute. The attribute can be a standalone
/// attribute or a property of a class.
pub fn bind_attribute<S: ::std::hash::BuildHasher>(
pub fn bind_attribute<S: ::std::hash::BuildHasher + Default>(
py: pyo3::Python,
module_name: &str,
is_class: bool,
Expand Down
3 changes: 2 additions & 1 deletion pyo3_bindgen_engine/src/bindgen/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::bindgen::{bind_attribute, bind_function};

/// Generate Rust bindings to a Python class with all its methods and attributes (properties).
/// This function will call itself recursively to generate bindings to all nested classes.
pub fn bind_class<S: ::std::hash::BuildHasher>(
pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
py: pyo3::Python,
root_module: &pyo3::types::PyModule,
class: &pyo3::types::PyType,
Expand Down Expand Up @@ -165,6 +165,7 @@ pub fn bind_class<S: ::std::hash::BuildHasher>(
});

// Add new and call aliases (currently a reimplemented versions of the function)
// TODO: Call the Rust `self.__init__()` and `self.__call__()` functions directly instead of reimplementing it
if fn_names.contains(&"__init__".to_string()) && !fn_names.contains(&"new".to_string()) {
impl_token_stream.extend(bind_function(
py,
Expand Down
141 changes: 65 additions & 76 deletions pyo3_bindgen_engine/src/bindgen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::types::Type;

/// Generate Rust bindings to a Python function. The function can be a standalone function or a
/// method of a class.
pub fn bind_function<S: ::std::hash::BuildHasher>(
pub fn bind_function<S: ::std::hash::BuildHasher + Default>(
py: pyo3::Python,
module_name: &str,
name: &str,
Expand Down Expand Up @@ -168,90 +168,79 @@ pub fn bind_function<S: ::std::hash::BuildHasher>(
doc = String::new();
};

// TODO: Use `call_method0` and `call_method1`` where appropriate
Ok(if has_self_param {
if let Some(var_keyword_ident) = var_keyword_ident {
let (maybe_ref_self, callable_object) = if has_self_param {
(quote::quote! { &'py self, }, quote::quote! { self })
} else {
(
quote::quote! {},
quote::quote! { py.import(::pyo3::intern!(py, #module_name))? },
)
};

let has_positional_args = !positional_args_idents.is_empty();
let set_args = match (
positional_args_idents.len() > 1,
var_positional_ident.is_some(),
) {
(true, _) => {
quote::quote! {
#[doc = #doc]
pub fn #function_ident<'py>(
&'py self,
py: ::pyo3::marker::Python<'py>,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_annotation> {
#[allow(unused_imports)]
use ::pyo3::IntoPy;
let __internal_args = (
#({
let #positional_args_idents: &'py ::pyo3::PyAny = #positional_args_idents.into();
#positional_args_idents
},)*
);
let __internal_kwargs = #var_keyword_ident;
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
self.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?.extract()
}
let __internal_args = ::pyo3::types::PyTuple::new(
py,
[#(::pyo3::IntoPy::<::pyo3::PyObject>::into_py(#positional_args_idents.to_owned(), py).as_ref(py),)*]
);

Check warning on line 190 in pyo3_bindgen_engine/src/bindgen/function.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/function.rs#L187-L190

Added lines #L187 - L190 were not covered by tests
}
} else {
}
(false, true) => {
let var_positional_ident = var_positional_ident.unwrap();
quote::quote! {
#[doc = #doc]
pub fn #function_ident<'py>(
&'py self,
py: ::pyo3::marker::Python<'py>,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_annotation> {
#[allow(unused_imports)]
use ::pyo3::IntoPy;
let __internal_args = (
#({
let #positional_args_idents: &'py ::pyo3::PyAny = #positional_args_idents.into();
#positional_args_idents
},)*
);
let __internal_kwargs = ::pyo3::types::PyDict::new(py);
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
self.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?.extract()
}
let __internal_args = #var_positional_ident;
}
}
} else if let Some(var_keyword_ident) = var_keyword_ident {
quote::quote! {
#[doc = #doc]
pub fn #function_ident<'py>(
py: ::pyo3::marker::Python<'py>,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_annotation> {
#[allow(unused_imports)]
use ::pyo3::IntoPy;
let __internal_args = (
#({
let #positional_args_idents: &'py ::pyo3::PyAny = #positional_args_idents.into();
#positional_args_idents
},)*
);
let __internal_kwargs = #var_keyword_ident;
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
py.import(::pyo3::intern!(py, #module_name))?.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?.extract()
}
(false, false) => {
quote::quote! { let __internal_args = (); }
}
};

let has_kwargs = !keyword_args_idents.is_empty();
let kwargs_initial = if let Some(var_keyword_ident) = var_keyword_ident {
quote::quote! { #var_keyword_ident }
} else {
quote::quote! {
quote::quote! { ::pyo3::types::PyDict::new(py) }
};
let set_kwargs = quote::quote! {
let __internal_kwargs = #kwargs_initial;
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
};

let call_method = match (has_positional_args, has_kwargs) {
(_, true) => {
quote::quote! {
#set_args
#set_kwargs
#callable_object.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?
}
}
(true, false) => {
quote::quote! {
#set_args
#callable_object.call_method1(::pyo3::intern!(py, #function_name), __internal_args)?
}

Check warning on line 227 in pyo3_bindgen_engine/src/bindgen/function.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/function.rs#L224-L227

Added lines #L224 - L227 were not covered by tests
}
(false, false) => {
quote::quote! {
#callable_object.call_method0(::pyo3::intern!(py, #function_name))?
}

Check warning on line 232 in pyo3_bindgen_engine/src/bindgen/function.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/function.rs#L230-L232

Added lines #L230 - L232 were not covered by tests
}
};

Ok(quote::quote! {
#[doc = #doc]
pub fn #function_ident<'py>(
py: ::pyo3::marker::Python<'py>,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_annotation> {
#[allow(unused_imports)]
use ::pyo3::IntoPy;
let __internal_args = (
#({
let #positional_args_idents: &'py ::pyo3::PyAny = #positional_args_idents.into();
#positional_args_idents
},)*
);
let __internal_kwargs = ::pyo3::types::PyDict::new(py);
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
py.import(::pyo3::intern!(py, #module_name))?.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?.extract()
}
#maybe_ref_self
py: ::pyo3::marker::Python<'py>,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_annotation> {
#call_method.extract()
}
})
}
2 changes: 1 addition & 1 deletion pyo3_bindgen_engine/src/bindgen/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::bindgen::{bind_attribute, bind_class, bind_function};
/// attributes of the Python module. During the first call, the `root_module` argument should be
/// the same as the `module` argument and the `processed_modules` argument should be an empty
/// `HashSet`.
pub fn bind_module<S: ::std::hash::BuildHasher>(
pub fn bind_module<S: ::std::hash::BuildHasher + Default>(
py: pyo3::Python,
root_module: &pyo3::types::PyModule,
module: &pyo3::types::PyModule,
Expand Down
87 changes: 56 additions & 31 deletions pyo3_bindgen_engine/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pub enum Type {
PyString,

// Enums
// TODO: Optional causes issues when passed as a position-only argument to a function. Fix!
Optional(Box<Type>),
Union(Vec<Type>),
PyNone,
Expand Down Expand Up @@ -513,7 +512,7 @@ impl Type {
}

#[must_use]
pub fn into_rs<S: ::std::hash::BuildHasher>(
pub fn into_rs<S: ::std::hash::BuildHasher + Default>(
self,
owned: bool,
module_name: &str,
Expand All @@ -527,7 +526,7 @@ impl Type {
}

Check warning on line 526 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L526

Added line #L526 was not covered by tests

#[must_use]
pub fn into_rs_owned<S: ::std::hash::BuildHasher>(
pub fn into_rs_owned<S: ::std::hash::BuildHasher + Default>(
self,
module_name: &str,
all_types: &std::collections::HashSet<String, S>,
Expand Down Expand Up @@ -704,7 +703,7 @@ impl Type {
}

#[must_use]
pub fn into_rs_borrowed<S: ::std::hash::BuildHasher>(
pub fn into_rs_borrowed<S: ::std::hash::BuildHasher + Default>(
self,
module_name: &str,
all_types: &std::collections::HashSet<String, S>,
Expand Down Expand Up @@ -737,7 +736,7 @@ impl Type {

// Enums
Self::Optional(t) => {
let inner = t.into_rs_borrowed(module_name, all_types);
let inner = t.into_rs_owned(module_name, all_types);
quote::quote! {
::std::option::Option<#inner>
}
Expand Down Expand Up @@ -879,14 +878,13 @@ impl Type {
}
}

fn try_into_module_path<S: ::std::hash::BuildHasher>(
fn try_into_module_path<S: ::std::hash::BuildHasher + Default>(
self,
module_name: &str,
all_types: &std::collections::HashSet<String, S>,
) -> proc_macro2::TokenStream {
let value = match self {
Self::Unhandled(value) => value,
_ => unreachable!(),
let Self::Unhandled(value) = self else {
unreachable!()

Check warning on line 887 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L887

Added line #L887 was not covered by tests
};
let module_root = if module_name.contains('.') {
module_name.split('.').next().unwrap()

Check warning on line 890 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L890

Added line #L890 was not covered by tests
Expand Down Expand Up @@ -938,36 +936,63 @@ impl Type {
)
.join("::");

// dbg!(all_types);

// The path contains both ident and "::", combine into something that can be quoted
let reexport_path = syn::parse_str::<syn::Path>(&reexport_path).unwrap();
quote::quote! {
&'py #reexport_path
}

Check warning on line 943 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L935-L943

Added lines #L935 - L943 were not covered by tests
}
_ => {
// TODO: Make this more robust (possibly parsing all local reexports to figure out where the type is coming from)
// TODO: Fix this! The matching is wrong in many cases
let module_member_end_match = value
.split_once('[')
.unwrap_or((&value, ""))
.0
.split('.')
.last()
.unwrap();
if let Some(module_member_full) = all_types
.iter()
.find(|x| x.ends_with(module_member_end_match))
{
Self::Unhandled(module_member_full.to_owned())
.try_into_module_path(module_name, all_types)
} else {
// Unsupported
// TODO: Support more types
// dbg!(value);
quote::quote! {&'py ::pyo3::types::PyAny}
let value_without_brackets = value.split_once('[').unwrap_or((&value, "")).0;
let module_scopes = value_without_brackets.split('.');
let n_module_scopes = module_scopes.clone().count();

// Approach: Find types without a module scope (no dot) and check if the type is local (or imported in the current module)
if !value_without_brackets.contains('.') {
if let Some(member) = all_types
.iter()
.filter(|member| {
member
.split('.')
.take(member.split('.').count() - 1)
.join(".")
== module_name
})
.find(|&member| {
member.trim_start_matches(&format!("{module_name}."))
== value_without_brackets
})

Check warning on line 964 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L952-L964

Added lines #L952 - L964 were not covered by tests
{
return Self::Unhandled(member.to_owned())
.try_into_module_path(module_name, all_types);

Check warning on line 967 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L966-L967

Added lines #L966 - L967 were not covered by tests
}
}

// Approach: Find the shallowest match that contains the value
// TODO: Fix this! The matching might be wrong in many cases
let mut possible_matches = std::collections::HashSet::<String, S>::default();
for i in 0..n_module_scopes {
let module_member_scopes_end = module_scopes.clone().skip(i).join(".");
all_types
.iter()
.filter(|member| member.ends_with(&module_member_scopes_end))
.for_each(|member| {
possible_matches.insert(member.to_owned());

Check warning on line 980 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L980

Added line #L980 was not covered by tests
});
if !possible_matches.is_empty() {
let shallowest_match = possible_matches
.iter()
.min_by(|m1, m2| m1.split('.').count().cmp(&m2.split('.').count()))
.unwrap();
return Self::Unhandled(shallowest_match.to_owned())
.try_into_module_path(module_name, all_types);

Check warning on line 988 in pyo3_bindgen_engine/src/types.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/types.rs#L983-L988

Added lines #L983 - L988 were not covered by tests
}
}

// Unsupported
// TODO: Support more types
// dbg!(value);
quote::quote! {&'py ::pyo3::types::PyAny}
}
}
}
Expand Down
Loading

0 comments on commit a2a630c

Please sign in to comment.