From b2ea35913336e6b43d5954c366f052db91be79f3 Mon Sep 17 00:00:00 2001 From: Nouzan Date: Wed, 4 Oct 2023 12:25:41 +0800 Subject: [PATCH] context: add more attrs --- crates/indicator/src/context/mod.rs | 1 + crates/indicator_derive/src/operator/mod.rs | 54 +++- .../src/operator/signature.rs | 256 +++++++++++++++++- examples/context/context.rs | 12 +- 4 files changed, 303 insertions(+), 20 deletions(-) diff --git a/crates/indicator/src/context/mod.rs b/crates/indicator/src/context/mod.rs index 51c5b49..ec8a8a7 100644 --- a/crates/indicator/src/context/mod.rs +++ b/crates/indicator/src/context/mod.rs @@ -26,6 +26,7 @@ use self::{ pub use self::{ anymap::Context, + extractor::{Data, Env, In, Prev}, layer::{ cache::Cache, data::AddData, diff --git a/crates/indicator_derive/src/operator/mod.rs b/crates/indicator_derive/src/operator/mod.rs index 46ae7e5..666da7f 100644 --- a/crates/indicator_derive/src/operator/mod.rs +++ b/crates/indicator_derive/src/operator/mod.rs @@ -3,8 +3,8 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::quote; use syn::{ - punctuated::Punctuated, FnArg, GenericParam, Ident, ItemFn, Lifetime, LifetimeParam, PatType, - ReturnType, Token, Type, Visibility, + punctuated::Punctuated, FnArg, GenericParam, Ident, ItemFn, Lifetime, LifetimeParam, Meta, + PatType, ReturnType, Token, Type, Visibility, }; use self::args::{GenerateOut, OperatorArgs}; @@ -23,8 +23,8 @@ pub(super) fn generate_operator( let args = syn::parse::(args)?; let mut next = syn::parse::(input)?; - let next_fn = generate_next_fn(&next)?; - signature::expand(&mut next.sig, &args)?; + let unattributed = signature::expand(&mut next, &args)?; + let next_fn = generate_next_fn(&unattributed)?; // Documentations. let docs = next @@ -73,28 +73,34 @@ pub(super) fn generate_operator( let indicator = indicator(); let vis = &next.vis; let input_type = &args.input_type; + let unattributed_type_generics = if unattributed.sig.generics.params.is_empty() { + quote!() + } else { + let (_, type_generics, _) = unattributed.sig.generics.split_for_impl(); + quote!(::#type_generics) + }; let return_stmt = match args.generate_out { Some(GenerateOut::Out) => { quote! { - __next(#extractors).into() + __next #unattributed_type_generics (#extractors).into() } } Some(GenerateOut::Data) => { quote! { - __next(#extractors).map(Into::into) + __next #unattributed_type_generics (#extractors).map(Into::into) } } Some(GenerateOut::WithData) => { quote! { { - let (__out, __data) = __next(#extractors); + let (__out, __data) = __next #unattributed_type_generics (#extractors); (__out.into(), __data.map(Into::into)) } } } None => { quote! { - __next(#extractors) + __next #unattributed_type_generics (#extractors) } } }; @@ -130,13 +136,33 @@ fn generate_next_fn(next: &ItemFn) -> syn::Result { fn parse_extractor(arg: &PatType) -> syn::Result { let indicator = indicator(); - let PatType { ty, .. } = arg; - Ok(quote! { - { - let __a: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input); - __a + let PatType { ty, attrs, pat, .. } = arg; + let expanded = if let Some(attr) = attrs.first() { + let Meta::List(attr) = &attr.meta else { + unreachable!("must be meta list"); + }; + let kind = attr.path.get_ident().unwrap(); + let name: Ident = syn::parse2(attr.tokens.clone()).unwrap(); + let rt = match kind.to_string().as_str() { + "borrow" => quote!(core::borrow::Borrow::borrow(#name)), + "as_ref" => quote!(core::convert::AsRef::as_ref(#name)), + _ => unreachable!(), + }; + quote! { + { + let #pat: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input); + #rt + } } - }) + } else { + quote! { + { + let __a: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input); + __a + } + } + }; + Ok(expanded) } fn generate_struct_def( diff --git a/crates/indicator_derive/src/operator/signature.rs b/crates/indicator_derive/src/operator/signature.rs index 79d6d54..257d6ae 100644 --- a/crates/indicator_derive/src/operator/signature.rs +++ b/crates/indicator_derive/src/operator/signature.rs @@ -1,10 +1,27 @@ +use convert_case::{Case, Casing}; +use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{Result, ReturnType, Signature, Type, TypeTuple}; +use syn::{ + FnArg, ItemFn, Meta, Pat, Result, ReturnType, Signature, Type, TypeReference, TypeTuple, +}; use super::args::{GenerateOut, OperatorArgs}; -pub(super) fn expand(sig: &mut Signature, args: &OperatorArgs) -> Result<()> { - expand_generics(sig, args)?; +pub(super) fn expand(input: &mut ItemFn, args: &OperatorArgs) -> Result { + let mut unattributed = input.clone(); + remove_input_attributes(&mut unattributed)?; + expand_generics(&mut input.sig, args)?; + expand_inputs(&mut input.sig, args)?; + Ok(unattributed) +} + +fn remove_input_attributes(item_fn: &mut ItemFn) -> Result<()> { + for arg in item_fn.sig.inputs.iter_mut() { + let FnArg::Typed(arg) = arg else { + return Err(syn::Error::new_spanned(arg, "expected typed argument")); + }; + arg.attrs.clear(); + } Ok(()) } @@ -54,6 +71,239 @@ fn expand_generics(sig: &mut Signature, args: &OperatorArgs) -> Result<()> { Ok(()) } +fn expand_inputs(sig: &mut Signature, args: &OperatorArgs) -> Result<()> { + let mut generics = vec![]; + let mut ctx = Ctx::default(); + for arg in sig.inputs.iter_mut() { + if let Some(generic) = expand_input(&mut ctx, arg, &args.input_type)? { + generics.push(generic); + } + } + if !generics.is_empty() { + sig.generics.params.extend(generics); + } + Ok(()) +} + +#[derive(Default)] +struct Ctx { + input: usize, + env: usize, + data: usize, +} + +enum Way { + Borrow, + AsRef, +} + +impl Way { + fn is_borrow(&self) -> bool { + matches!(self, Way::Borrow) + } +} + +enum ArgKind { + Input(Way), + Env(Way), + Data(Way), + Prev(Way), +} + +impl<'a> TryFrom<&'a Meta> for ArgKind { + type Error = syn::Error; + + fn try_from(value: &'a Meta) -> std::result::Result { + match value { + Meta::Path(path) => { + let ident = path.get_ident().ok_or_else(|| { + syn::Error::new_spanned( + path, + "unsupported attribute, expected `input`, `env` or `data`", + ) + })?; + match ident.to_string().as_str() { + "input" => Ok(ArgKind::Input(Way::Borrow)), + "env" => Ok(ArgKind::Env(Way::Borrow)), + "data" => Ok(ArgKind::Data(Way::Borrow)), + "prev" => Ok(ArgKind::Prev(Way::Borrow)), + _ => Err(syn::Error::new_spanned( + path, + "unsupported attribute, expected `input`, `env` or `data`", + )), + } + } + Meta::List(list) => { + let way: Ident = syn::parse2(list.tokens.clone())?; + let way = match way.to_string().as_str() { + "borrow" => Way::Borrow, + "as_ref" => Way::AsRef, + _ => { + return Err(syn::Error::new_spanned( + way, + "unsupported value, expected `borrow` or `as_ref`", + )) + } + }; + let ident = list.path.get_ident().ok_or_else(|| { + syn::Error::new_spanned( + list, + "unsupported attribute, expected `input`, `env` or `data`", + ) + })?; + match ident.to_string().as_str() { + "input" => Ok(ArgKind::Input(way)), + "env" => Ok(ArgKind::Env(way)), + "data" => Ok(ArgKind::Data(way)), + "prev" => Ok(ArgKind::Prev(way)), + _ => Err(syn::Error::new_spanned( + ident, + "unsupported attribute, expected `input`, `env` or `data`", + )), + } + } + _ => Err(syn::Error::new_spanned( + value, + "unsupported attribute, expected `input`, `env` or `data`", + )), + } + } +} + +fn expand_input( + ctx: &mut Ctx, + fn_arg: &mut FnArg, + input_ty: &Type, +) -> Result> { + let indicator = super::indicator(); + + let FnArg::Typed(arg) = fn_arg else { + return Err(syn::Error::new_spanned(fn_arg, "expected typed argument")); + }; + let attr = arg.attrs.pop(); + if !arg.attrs.is_empty() { + return Err(syn::Error::new_spanned( + arg, + "expected at most one attribute", + )); + } + let Some(attr) = attr else { + return Ok(None); + }; + let Type::Reference(TypeReference { + elem: target_ty, + lifetime: None, + mutability: None, + .. + }) = &*arg.ty + else { + return Err(syn::Error::new_spanned( + arg, + "expected reference type without lifetime and mutability, e.g. `&T`", + )); + }; + let kind: ArgKind = ArgKind::try_from(&attr.meta)?; + let name = get_variable_name(&arg.pat).map(|n| n.to_string()); + let generic = match kind { + ArgKind::Input(way) => { + let name = Ident::new( + &name.unwrap_or_else(|| format!("input{}", ctx.input)), + Span::call_site(), + ); + ctx.input += 1; + let pat = quote!(#indicator::context::In(#name)); + let ty = input_ty.clone(); + let (generic, attr) = if way.is_borrow() { + ( + syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty>))?, + quote!(#[borrow(#name)]), + ) + } else { + ( + syn::parse2(quote!(#ty: AsRef<#target_ty>))?, + quote!(#[as_ref(#name)]), + ) + }; + *fn_arg = syn::parse2(quote!(#attr #pat: #indicator::context::In<&#ty>))?; + generic + } + ArgKind::Env(way) => { + let name_string = name.unwrap_or_else(|| format!("env{}", ctx.env)); + let name = Ident::new(&name_string, Span::call_site()); + ctx.env += 1; + let pat = quote!(#indicator::context::Env(#name)); + let ty = Ident::new(&name_string.to_case(Case::Pascal), Span::call_site()); + let (generic, attr) = if way.is_borrow() { + ( + syn::parse2( + quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static), + )?, + quote!(#[borrow(#name)]), + ) + } else { + ( + syn::parse2(quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static))?, + quote!(#[as_ref(#name)]), + ) + }; + *fn_arg = syn::parse2(quote!(#attr #pat: #indicator::context::Env<&#ty>))?; + generic + } + ArgKind::Data(way) => { + let name_string = name.unwrap_or_else(|| format!("data{}", ctx.env)); + let name = Ident::new(&name_string, Span::call_site()); + ctx.data += 1; + let pat = quote!(#indicator::context::Data(#name)); + let ty = Ident::new(&name_string.to_case(Case::Pascal), Span::call_site()); + let (generic, attr) = if way.is_borrow() { + ( + syn::parse2( + quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static), + )?, + quote!(#[borrow(#name)]), + ) + } else { + ( + syn::parse2(quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static))?, + quote!(#[as_ref(#name)]), + ) + }; + *fn_arg = syn::parse2(quote!(#attr #pat: #indicator::context::Data<&#ty>))?; + generic + } + ArgKind::Prev(way) => { + let name_string = name.unwrap_or_else(|| format!("prev{}", ctx.env)); + let name = Ident::new(&name_string, Span::call_site()); + ctx.data += 1; + let pat = quote!(#indicator::context::Prev(#name)); + let ty = Ident::new(&name_string.to_case(Case::Pascal), Span::call_site()); + let (generic, attr) = if way.is_borrow() { + ( + syn::parse2( + quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static), + )?, + quote!(#[borrow(#name)]), + ) + } else { + ( + syn::parse2(quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static))?, + quote!(#[as_ref(#name)]), + ) + }; + *fn_arg = syn::parse2(quote!(#attr #pat: #indicator::context::Prev<&#ty>))?; + generic + } + }; + Ok(Some(generic)) +} + +fn get_variable_name(pat: &Pat) -> Option<&Ident> { + match pat { + Pat::Ident(pat) => Some(&pat.ident), + _ => None, + } +} + fn get_return_type(sig: &Signature) -> Type { match &sig.output { ReturnType::Default => Type::Tuple(TypeTuple { diff --git a/examples/context/context.rs b/examples/context/context.rs index 0d8ff88..c86418c 100644 --- a/examples/context/context.rs +++ b/examples/context/context.rs @@ -20,10 +20,16 @@ where #[derive(Clone)] struct Ma(T); +impl core::borrow::Borrow for Ma { + fn borrow(&self) -> &T { + &self.0 + } +} + /// An operator that does the following: /// `x => (x + prev(x)) / 2` -#[operator(input = T)] -fn ma(Env(AddTwo(x)): Env<&AddTwo>, Prev(prev): Prev<&Ma>) -> Ma +#[operator(input = I)] +fn ma(#[env] x: &T, Prev(prev): Prev<&Ma>) -> Ma where T: Num + Clone, T: Send + Sync + 'static, @@ -34,7 +40,7 @@ where } fn main() -> anyhow::Result<()> { - let op = output_with(ma) + let op = output_with(ma::>) .inspect(|value| { println!("input: {}", value.value()); if let Some(AddTwo(x)) = value.context().env().get::>() {