Skip to content

Commit

Permalink
context: use OperatorFn to re-impl the macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Nouzan committed Oct 4, 2023
1 parent ef772c6 commit 44f568d
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 580 deletions.
1 change: 1 addition & 0 deletions crates/indicator_derive/src/operator/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub(crate) struct OperatorArgs {
pub(crate) generate_out: Option<GenerateOut>,
}

#[derive(Clone, Copy)]
pub(crate) enum GenerateOut {
Out,
WithData,
Expand Down
78 changes: 63 additions & 15 deletions crates/indicator_derive/src/operator/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use convert_case::{Case, Casing};
use proc_macro2::Ident;
use proc_macro2::{Ident, TokenStream};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
Expand Down Expand Up @@ -99,8 +99,56 @@ type Optional = bool;

pub(super) enum Extractor {
Plain(Box<Type>),
Borrow(FnArg, Optional),
AsRef(FnArg, Optional),
Borrow(Ident, FnArg, Optional),
AsRef(Ident, FnArg, Optional),
}

impl Extractor {
pub(super) fn expand(&self) -> TokenStream {
let indicator = indicator();
match self {
Self::Plain(ty) => {
quote! {
{
let __a: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
__a
}
}
}
Self::Borrow(name, arg, false) => {
quote! {
{
let #arg = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
core::borrow::Borrow::borrow(#name)
}
}
}
Self::Borrow(name, arg, true) => {
quote! {
{
let #arg = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
#name.map(core::borrow::Borrow::borrow)
}
}
}
Self::AsRef(name, arg, false) => {
quote! {
{
let #arg = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
core::convert::AsRef::as_ref(#name)
}
}
}
Self::AsRef(name, arg, true) => {
quote! {
{
let #arg = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
#name.map(core::convert::AsRef::as_ref)
}
}
}
}
}
}

struct ExtractorWithGenerics {
Expand All @@ -124,10 +172,10 @@ impl ExtractorWithGenerics {
if way.is_borrow() {
generics
.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty>))?);
Extractor::Borrow(pat, false)
Extractor::Borrow(name, pat, false)
} else {
generics.push(syn::parse2(quote!(#ty: AsRef<#target_ty>))?);
Extractor::AsRef(pat, false)
Extractor::AsRef(name, pat, false)
}
}
Attr::Env(way, false) => {
Expand All @@ -138,12 +186,12 @@ impl ExtractorWithGenerics {
let pat = parse_quote!(#indicator::context::Env(#name): #indicator::context::Env<&#ty>);
if way.is_borrow() {
generics.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static))?);
Extractor::Borrow(pat, false)
Extractor::Borrow(name, pat, false)
} else {
generics.push(syn::parse2(
quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static),
)?);
Extractor::AsRef(pat, false)
Extractor::AsRef(name, pat, false)
}
}
Attr::Env(way, true) => {
Expand All @@ -155,12 +203,12 @@ impl ExtractorWithGenerics {
let pat = parse_quote!(#indicator::context::Env(#name): #indicator::context::Env<Option<&#ty>>);
if way.is_borrow() {
generics.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static))?);
Extractor::Borrow(pat, true)
Extractor::Borrow(name, pat, true)
} else {
generics.push(syn::parse2(
quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static),
)?);
Extractor::AsRef(pat, true)
Extractor::AsRef(name, pat, true)
}
}
Attr::Data(way, false) => {
Expand All @@ -171,12 +219,12 @@ impl ExtractorWithGenerics {
let pat = parse_quote!(#indicator::context::Data(#name): #indicator::context::Data<&#ty>);
if way.is_borrow() {
generics.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static))?);
Extractor::Borrow(pat, false)
Extractor::Borrow(name, pat, false)
} else {
generics.push(syn::parse2(
quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static),
)?);
Extractor::AsRef(pat, false)
Extractor::AsRef(name, pat, false)
}
}
Attr::Data(way, true) => {
Expand All @@ -188,12 +236,12 @@ impl ExtractorWithGenerics {
let pat = parse_quote!(#indicator::context::Data(#name): #indicator::context::Data<Option<&#ty>>);
if way.is_borrow() {
generics.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static))?);
Extractor::Borrow(pat, true)
Extractor::Borrow(name, pat, true)
} else {
generics.push(syn::parse2(
quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static),
)?);
Extractor::AsRef(pat, true)
Extractor::AsRef(name, pat, true)
}
}
Attr::Prev(way) => {
Expand All @@ -205,12 +253,12 @@ impl ExtractorWithGenerics {
let pat = parse_quote!(#indicator::context::Prev(#name): #indicator::context::Prev<&#ty>);
if way.is_borrow() {
generics.push(syn::parse2(quote!(#ty: core::borrow::Borrow<#target_ty> + Send + Sync + 'static))?);
Extractor::Borrow(pat, true)
Extractor::Borrow(name, pat, true)
} else {
generics.push(syn::parse2(
quote!(#ty: AsRef<#target_ty> + Send + Sync + 'static),
)?);
Extractor::AsRef(pat, true)
Extractor::AsRef(name, pat, true)
}
}
}
Expand Down
209 changes: 3 additions & 206 deletions crates/indicator_derive/src/operator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{
punctuated::Punctuated, FnArg, GenericParam, Ident, ItemFn, Lifetime, LifetimeParam, Meta,
PatType, ReturnType, Token, Type, Visibility,
};
use proc_macro2::TokenStream as TokenStream2;

use self::{
args::{GenerateOut, OperatorArgs},
operator_fn::OperatorFn,
};
use self::{args::OperatorArgs, operator_fn::OperatorFn};
use super::indicator;

/// Arguments for generating operator.
mod args;

/// Expand the signature.
mod signature;

/// Operator Fn.
mod operator_fn;

Expand All @@ -32,197 +20,6 @@ pub(super) fn generate_operator(
input: TokenStream,
) -> syn::Result<TokenStream2> {
let args = syn::parse::<OperatorArgs>(args)?;
let op_fn = OperatorFn::parse_with(input.clone().into(), &args)?;
let mut next = syn::parse::<ItemFn>(input)?;

let unattributed = signature::expand(&mut next, &args)?;
let next_fn = generate_next_fn(&unattributed)?;

// Documentations.
let docs = next
.attrs
.iter()
.filter(|attr| attr.path().is_ident("doc"))
.cloned()
.collect::<Vec<_>>();

// Generate struct name.
let fn_name = &next.sig.ident;
let name = Ident::new(
&format!("{}Op", fn_name.to_string().to_case(Case::Pascal)),
Span::call_site(),
);

// Handle generics.
let struct_def = generate_struct_def(&next.vis, &next.sig.generics, &name, &docs)?;
let (orig_impl_generics, type_generics, where_clause) = next.sig.generics.split_for_impl();

// Add lifetime to generics.
let mut generics = next.sig.generics.clone();
generics
.params
.push(GenericParam::Lifetime(LifetimeParam::new(Lifetime::new(
"'value",
Span::call_site(),
))));
let (impl_generics, _, _) = generics.split_for_impl();

// Handle extractors.
let mut extractors = Punctuated::<_, Token![,]>::new();
for arg in next.sig.inputs.iter() {
let FnArg::Typed(arg) = arg else {
return Err(syn::Error::new_spanned(arg, "expected typed argument"));
};
extractors.push(parse_extractor(arg)?);
}

// Handle output.
let output = match &next.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ty) => quote!(#ty),
};

let indicator = indicator();
let vis = &next.vis;
let input_type = &args.input_type;
let unattributed_type_generics = if unattributed.sig.generics.params.is_empty() {
quote!()
} else {
let (_, type_generics, _) = unattributed.sig.generics.split_for_impl();
quote!(::#type_generics)
};
let return_stmt = match args.generate_out {
Some(GenerateOut::Out) => {
quote! {
__next #unattributed_type_generics (#extractors).into()
}
}
Some(GenerateOut::Data) => {
quote! {
__next #unattributed_type_generics (#extractors).map(Into::into)
}
}
Some(GenerateOut::WithData) => {
quote! {
{
let (__out, __data) = __next #unattributed_type_generics (#extractors);
(__out.into(), __data.map(Into::into))
}
}
}
None => {
quote! {
__next #unattributed_type_generics (#extractors)
}
}
};

// Expand.
Ok(quote! {
#struct_def

impl #impl_generics #indicator::context::RefOperator<'value, #input_type> for #name #type_generics #where_clause {
type Output = #output;

fn next(&mut self, __input: #indicator::context::ValueRef<'value, #input_type>) -> Self::Output {
#next_fn
#return_stmt
}
}

#(#docs)*
#vis fn #fn_name #orig_impl_generics() -> #name #type_generics #where_clause {
#name::default()
}
})
}

fn generate_next_fn(next: &ItemFn) -> syn::Result<TokenStream2> {
let mut next = next.clone();
next.vis = Visibility::Inherited;
next.sig.ident = Ident::new("__next", next.sig.ident.span());
Ok(quote! {
#next
})
}

fn parse_extractor(arg: &PatType) -> syn::Result<TokenStream2> {
let indicator = indicator();
let PatType { ty, attrs, pat, .. } = arg;
let expanded = if let Some(attr) = attrs.first() {
let Meta::List(attr) = &attr.meta else {
unreachable!("must be meta list");
};
let kind = attr.path.get_ident().unwrap();
let name: Ident = syn::parse2(attr.tokens.clone()).unwrap();
let rt = match kind.to_string().as_str() {
"borrow" => quote!(core::borrow::Borrow::borrow(#name)),
"as_ref" => quote!(core::convert::AsRef::as_ref(#name)),
_ => unreachable!(),
};
quote! {
{
let #pat: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
#rt
}
}
} else {
quote! {
{
let __a: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
__a
}
}
};
let expanded = OperatorFn::parse_with(input.clone().into(), &args)?.expand();

Check warning on line 23 in crates/indicator_derive/src/operator/mod.rs

View workflow job for this annotation

GitHub Actions / Test on Rust 1.67.0

redundant clone
Ok(expanded)
}

fn generate_struct_def(
vis: &Visibility,
generics: &syn::Generics,
name: &syn::Ident,
docs: &[syn::Attribute],
) -> syn::Result<TokenStream2> {
if generics.params.is_empty() {
return Ok(quote! {
#[derive(Default)]
#[allow(non_camel_case_types)]
#vis struct #name;
});
}
let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
let phantom_data_type = generate_phantom_data_type(generics);
Ok(quote! {
#(#docs)*
#[allow(non_camel_case_types)]
#vis struct #name #impl_generics (core::marker::PhantomData<#phantom_data_type> ) #where_clause;

impl #impl_generics Default for #name #type_generics #where_clause {
fn default() -> Self {
Self(core::marker::PhantomData)
}
}
})
}

fn generate_phantom_data_type(generics: &syn::Generics) -> Type {
let params: Vec<_> = generics
.params
.iter()
.filter_map(|param| {
if let syn::GenericParam::Type(type_param) = param {
let ident = &type_param.ident;
Some(quote! { #ident })
} else {
None
}
})
.collect();

// Generate `fn() -> (generics)` type
let phantom_data_type = quote! { fn() -> (#(#params),*) };

// Parse as `fn() -> (generics)` type
let phantom_data_type: Type = syn::parse2(phantom_data_type).unwrap();
phantom_data_type
}
Loading

0 comments on commit 44f568d

Please sign in to comment.