Skip to content

Commit

Permalink
context: add more sugars to #[operator]
Browse files Browse the repository at this point in the history
- Allow generate generics for outputs
  • Loading branch information
Nouzan committed Oct 3, 2023
1 parent ac7f901 commit 776f04c
Show file tree
Hide file tree
Showing 7 changed files with 399 additions and 160 deletions.
155 changes: 8 additions & 147 deletions crates/indicator_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro_crate::{crate_name, FoundCrate};
use quote::quote;
use syn::{
FnArg, GenericParam, Ident, ItemFn, Lifetime, LifetimeParam, PatType, ReturnType, Type,
Visibility,
};
use syn::Ident;

/// Create a `RefOperator` from a function.
#[proc_macro_attribute]
pub fn operator(args: TokenStream, input: TokenStream) -> TokenStream {
match generate_operator(args, input) {
Ok(output) => output.into(),
Err(err) => err.to_compile_error().into(),
}
}
mod operator;

fn indicator() -> TokenStream2 {
let found_crate = crate_name("indicator").expect("my-crate is present in `Cargo.toml`");
Expand All @@ -29,139 +18,11 @@ fn indicator() -> TokenStream2 {
}
}

fn generate_operator(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream2> {
let indicator = indicator();
let input_type = parse_input_type(args)?;
let ItemFn {
vis,
sig,
block,
attrs,
} = syn::parse::<ItemFn>(input)?;

// Documentations.
let docs = attrs
.iter()
.filter(|attr| attr.path().is_ident("doc"))
.cloned()
.collect::<Vec<_>>();

// Generate struct name.
let fn_name = &sig.ident;
let name = Ident::new(
&format!("{}Op", fn_name.to_string().to_case(Case::Pascal)),
Span::call_site(),
);

// Handle generics.
let struct_def = generate_struct_def(&vis, &sig.generics, &name, &docs)?;
let (orig_impl_generics, type_generics, where_clause) = sig.generics.split_for_impl();

// Add lifetime to generics.
let mut generics = sig.generics.clone();
generics
.params
.push(GenericParam::Lifetime(LifetimeParam::new(Lifetime::new(
"'value",
Span::call_site(),
))));
let (impl_generics, _, _) = generics.split_for_impl();

// Handle extractors.
let mut extractors = Vec::new();
for arg in sig.inputs.iter() {
let FnArg::Typed(arg) = arg else {
return Err(syn::Error::new_spanned(arg, "expected typed argument"));
};
extractors.push(parse_extractor(arg)?);
}

// Handle output.
let output = match &sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ty) => quote!(#ty),
};

let stmts = block.stmts;

// Expand.
Ok(quote! {
#struct_def

impl #impl_generics #indicator::context::RefOperator<'value, #input_type> for #name #type_generics #where_clause {
type Output = #output;

fn next(&mut self, __input: #indicator::context::ValueRef<'value, #input_type>) -> Self::Output {
#(#extractors)*
#(#stmts)*
}
}

#(#docs)*
#vis fn #fn_name #orig_impl_generics() -> #name #type_generics #where_clause {
#name::default()
}
})
}

fn parse_input_type(args: TokenStream) -> syn::Result<Type> {
syn::parse::<Type>(args)
}

fn parse_extractor(arg: &PatType) -> syn::Result<TokenStream2> {
let indicator = indicator();
let PatType { pat, ty, .. } = arg;
Ok(quote! {
let #pat: #ty = #indicator::context::extractor::FromValueRef::from_value_ref(&__input);
})
}

fn generate_struct_def(
vis: &Visibility,
generics: &syn::Generics,
name: &syn::Ident,
docs: &[syn::Attribute],
) -> syn::Result<TokenStream2> {
if generics.params.is_empty() {
return Ok(quote! {
#[derive(Default)]
#[allow(non_camel_case_types)]
#vis struct #name;
});
/// Create a `RefOperator` from a function.
#[proc_macro_attribute]
pub fn operator(args: TokenStream, input: TokenStream) -> TokenStream {
match self::operator::generate_operator(args, input) {
Ok(output) => output.into(),
Err(err) => err.to_compile_error().into(),
}
let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
let phantom_data_type = generate_phantom_data_type(generics);
Ok(quote! {
#(#docs)*
#[allow(non_camel_case_types)]
#vis struct #name #impl_generics (core::marker::PhantomData<#phantom_data_type> ) #where_clause;

impl #impl_generics Default for #name #type_generics #where_clause {
fn default() -> Self {
Self(core::marker::PhantomData)
}
}
})
}

fn generate_phantom_data_type(generics: &syn::Generics) -> Type {
let params: Vec<_> = generics
.params
.iter()
.filter_map(|param| {
if let syn::GenericParam::Type(type_param) = param {
let ident = &type_param.ident;
Some(quote! { #ident })
} else {
None
}
})
.collect();

// Generate `fn() -> (generics)` type
let phantom_data_type = quote! { fn() -> (#(#params),*) };

// Parse as `fn() -> (generics)` type
let phantom_data_type: Type = syn::parse2(phantom_data_type).unwrap();
phantom_data_type
}
89 changes: 89 additions & 0 deletions crates/indicator_derive/src/operator/args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use proc_macro2::Span;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Expr, ExprAssign, ExprPath, Ident, Result, Token, Type, TypePath,
};

pub(crate) struct OperatorArgs {
pub(crate) input_type: Type,
pub(crate) generate_out: Option<GenerateOut>,
}

pub(crate) enum GenerateOut {
Out,
WithData,
Data,
}

impl Parse for OperatorArgs {
fn parse(input: ParseStream) -> Result<Self> {
let args = Punctuated::<Args, Token![,]>::parse_terminated(input)?;
let mut input_type = None;
let mut generate_out = None;
for arg in args {
match arg {
Args::GenerateOut(out) => generate_out = Some(out),
Args::InputType(ty) => input_type = Some(ty),
}
}
let input_type = input_type
.ok_or_else(|| syn::Error::new(Span::call_site(), "`input` argument is required"))?;
Ok(Self {
input_type,
generate_out,
})
}
}

enum Args {
InputType(Type),
GenerateOut(GenerateOut),
}

fn get_ident(expr: &Expr) -> Result<&Ident> {
let Expr::Path(ExprPath { path, .. }) = expr else {
return Err(syn::Error::new(
Span::call_site(),
"Expecting an identifier",
));
};
let ident = path
.get_ident()
.ok_or_else(|| syn::Error::new(Span::call_site(), "Expecting an identifier"))?;
Ok(ident)
}

impl Parse for Args {
fn parse(input: ParseStream) -> Result<Self> {
let expr: Expr = input.parse()?;
match expr {
Expr::Assign(ExprAssign { left, right, .. }) => {
let ident = get_ident(&left)?;
if ident == "input" {
let Expr::Path(ExprPath { path, .. }) = *right else {
return Err(syn::Error::new(Span::call_site(), "Expecting a type"));
};
Ok(Self::InputType(Type::Path(TypePath { qself: None, path })))
} else {
Err(syn::Error::new(
ident.span(),
format!("Unknown argument: `{ident}`, expecting `input = T`"),
))
}
}
expr => {
let ident = get_ident(&expr)?;
match ident.to_string().as_str() {
"generate_out" => Ok(Self::GenerateOut(GenerateOut::Out)),
"generate_data" => Ok(Self::GenerateOut(GenerateOut::Data)),
"generate_out_with_data" => Ok(Self::GenerateOut(GenerateOut::WithData)),
_ => Err(syn::Error::new(
ident.span(),
format!("Unknown argument: `{ident}`, expecting `generate_out`, `generate_out_data` or `generate_out_with_data`"),
)),
}
}
}
}
}
Loading

0 comments on commit 776f04c

Please sign in to comment.