Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread pyo3's path through the builder functions #3907

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pyo3-macros-backend/src/deprecations.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::utils::Ctx;
use proc_macro2::{Span, TokenStream};
use quote::{quote_spanned, ToTokens};

Expand All @@ -14,28 +15,30 @@ impl Deprecation {
}
}

#[derive(Default)]
pub struct Deprecations(Vec<(Deprecation, Span)>);
pub struct Deprecations<'ctx>(Vec<(Deprecation, Span)>, &'ctx Ctx);

impl Deprecations {
pub fn new() -> Self {
Deprecations(Vec::new())
impl<'ctx> Deprecations<'ctx> {
pub fn new(ctx: &'ctx Ctx) -> Self {
Deprecations(Vec::new(), ctx)
}

pub fn push(&mut self, deprecation: Deprecation, span: Span) {
self.0.push((deprecation, span))
}
}

impl ToTokens for Deprecations {
impl<'ctx> ToTokens for Deprecations<'ctx> {
fn to_tokens(&self, tokens: &mut TokenStream) {
for (deprecation, span) in &self.0 {
let Self(deprecations, Ctx { pyo3_path }) = self;

for (deprecation, span) in deprecations {
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
let ident = deprecation.ident(*span);
quote_spanned!(
*span =>
#[allow(clippy::let_unit_value)]
{
let _ = _pyo3::impl_::deprecations::#ident;
let _ = #pyo3_path::impl_::deprecations::#ident;
}
)
.to_tokens(tokens)
Expand Down
72 changes: 38 additions & 34 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::{
attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute},
utils::get_pyo3_crate,
};
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
Expand Down Expand Up @@ -46,14 +44,15 @@ impl<'a> Enum<'a> {
}

/// Build derivation body for enums.
fn build(&self) -> TokenStream {
fn build(&self, ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let mut var_extracts = Vec::new();
let mut variant_names = Vec::new();
let mut error_names = Vec::new();
for var in &self.variants {
let struct_derive = var.build();
let struct_derive = var.build(ctx);
let ext = quote!({
let maybe_ret = || -> _pyo3::PyResult<Self> {
let maybe_ret = || -> #pyo3_path::PyResult<Self> {
#struct_derive
}();

Expand All @@ -73,7 +72,7 @@ impl<'a> Enum<'a> {
#(#var_extracts),*
];
::std::result::Result::Err(
_pyo3::impl_::frompyobject::failed_to_extract_enum(
#pyo3_path::impl_::frompyobject::failed_to_extract_enum(
obj.py(),
#ty_name,
&[#(#variant_names),*],
Expand Down Expand Up @@ -239,57 +238,60 @@ impl<'a> Container<'a> {
}

/// Build derivation body for a struct.
fn build(&self) -> TokenStream {
fn build(&self, ctx: &Ctx) -> TokenStream {
match &self.ty {
ContainerType::StructNewtype(ident, from_py_with) => {
self.build_newtype_struct(Some(ident), from_py_with)
self.build_newtype_struct(Some(ident), from_py_with, ctx)
}
ContainerType::TupleNewtype(from_py_with) => {
self.build_newtype_struct(None, from_py_with)
self.build_newtype_struct(None, from_py_with, ctx)
}
ContainerType::Tuple(tups) => self.build_tuple_struct(tups),
ContainerType::Struct(tups) => self.build_struct(tups),
ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
ContainerType::Struct(tups) => self.build_struct(tups, ctx),
}
}

fn build_newtype_struct(
&self,
field_ident: Option<&Ident>,
from_py_with: &Option<FromPyWithAttribute>,
ctx: &Ctx,
) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let self_ty = &self.path;
let struct_name = self.name();
if let Some(ident) = field_ident {
let field_name = ident.to_string();
match from_py_with {
None => quote! {
Ok(#self_ty {
#ident: _pyo3::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
})
},
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! {
Ok(#self_ty {
#ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
})
},
}
} else {
match from_py_with {
None => quote!(
_pyo3::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
),
}
}
}

fn build_tuple_struct(&self, struct_fields: &[TupleStructField]) -> TokenStream {
fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let self_ty = &self.path;
let struct_name = &self.name();
let field_idents: Vec<_> = (0..struct_fields.len())
Expand All @@ -298,12 +300,12 @@ impl<'a> Container<'a> {
let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
match &field.from_py_with {
None => quote!(
_pyo3::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
),
}
});
Expand All @@ -315,7 +317,8 @@ impl<'a> Container<'a> {
)
}

fn build_struct(&self, struct_fields: &[NamedStructField<'_>]) -> TokenStream {
fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let self_ty = &self.path;
let struct_name = &self.name();
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
Expand All @@ -324,27 +327,27 @@ impl<'a> Container<'a> {
let field_name = ident.to_string();
let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
FieldGetter::GetAttr(Some(name)) => {
quote!(getattr(_pyo3::intern!(obj.py(), #name)))
quote!(getattr(#pyo3_path::intern!(obj.py(), #name)))
}
FieldGetter::GetAttr(None) => {
quote!(getattr(_pyo3::intern!(obj.py(), #field_name)))
quote!(getattr(#pyo3_path::intern!(obj.py(), #field_name)))
}
FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
quote!(get_item(_pyo3::intern!(obj.py(), #key)))
quote!(get_item(#pyo3_path::intern!(obj.py(), #key)))
}
FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
FieldGetter::GetItem(None) => {
quote!(get_item(_pyo3::intern!(obj.py(), #field_name)))
quote!(get_item(#pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(_pyo3::impl_::frompyobject::extract_struct_field(&obj.#getter?, #struct_name, #field_name)?)
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&obj.#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &obj.#getter?, #struct_name, #field_name)?)
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &obj.#getter?, #struct_name, #field_name)?)
}
};

Expand Down Expand Up @@ -579,23 +582,25 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
}
let options = ContainerOptions::from_attrs(&tokens.attrs)?;
let krate = get_pyo3_crate(&options.krate);
let ctx = &Ctx::new(&options.krate);
let Ctx { pyo3_path } = &ctx;

let derives = match &tokens.data {
syn::Data::Enum(en) => {
if options.transparent || options.annotation.is_some() {
bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
at top level for enums");
}
let en = Enum::new(en, &tokens.ident)?;
en.build()
en.build(ctx)
}
syn::Data::Struct(st) => {
if let Some(lit_str) = &options.annotation {
bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
}
let ident = &tokens.ident;
let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
st.build()
st.build(ctx)
}
syn::Data::Union(_) => bail_spanned!(
tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
Expand All @@ -607,12 +612,11 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
// FIXME https://github.com/PyO3/pyo3/issues/3903
#[allow(unknown_lints, non_local_definitions)]
const _: () = {
use #krate as _pyo3;
use _pyo3::prelude::PyAnyMethods;
use #pyo3_path::prelude::PyAnyMethods;

#[automatically_derived]
impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause {
fn extract_bound(obj: &_pyo3::Bound<#lt_param, _pyo3::PyAny>) -> _pyo3::PyResult<Self> {
impl #trait_generics #pyo3_path::FromPyObject<#lt_param> for #ident #generics #where_clause {
fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
#derives
}
}
Expand Down
17 changes: 9 additions & 8 deletions pyo3-macros-backend/src/konst.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;

use crate::utils::Ctx;
use crate::{
attributes::{self, get_pyo3_options, take_attributes, NameAttribute},
deprecations::Deprecations,
Expand All @@ -13,12 +14,12 @@ use syn::{
Result,
};

pub struct ConstSpec {
pub struct ConstSpec<'ctx> {
pub rust_ident: syn::Ident,
pub attributes: ConstAttributes,
pub attributes: ConstAttributes<'ctx>,
}

impl ConstSpec {
impl ConstSpec<'_> {
pub fn python_name(&self) -> Cow<'_, Ident> {
if let Some(name) = &self.attributes.name {
Cow::Borrowed(&name.value.0)
Expand All @@ -34,10 +35,10 @@ impl ConstSpec {
}
}

pub struct ConstAttributes {
pub struct ConstAttributes<'ctx> {
pub is_class_attr: bool,
pub name: Option<NameAttribute>,
pub deprecations: Deprecations,
pub deprecations: Deprecations<'ctx>,
}

pub enum PyO3ConstAttribute {
Expand All @@ -55,12 +56,12 @@ impl Parse for PyO3ConstAttribute {
}
}

impl ConstAttributes {
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
impl<'ctx> ConstAttributes<'ctx> {
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>, ctx: &'ctx Ctx) -> syn::Result<Self> {
let mut attributes = ConstAttributes {
is_class_attr: false,
name: None,
deprecations: Deprecations::new(),
deprecations: Deprecations::new(ctx),
};

take_attributes(attrs, |attr| {
Expand Down
Loading
Loading