diff --git a/cached_proc_macro/src/cached.rs b/cached_proc_macro/src/cached.rs index 21fbe2b..b4025d5 100644 --- a/cached_proc_macro/src/cached.rs +++ b/cached_proc_macro/src/cached.rs @@ -3,9 +3,18 @@ use darling::ast::NestedMeta; use darling::FromMeta; use proc_macro::TokenStream; use quote::quote; +use std::cmp::PartialEq; use syn::spanned::Spanned; use syn::{parse_macro_input, parse_str, Block, Ident, ItemFn, ReturnType, Type}; +#[derive(Debug, Default, FromMeta, Eq, PartialEq)] +enum SyncWriteMode { + #[default] + Disabled, + Default, + ByKey, +} + #[derive(FromMeta)] struct MacroArgs { #[darling(default)] @@ -27,9 +36,7 @@ struct MacroArgs { #[darling(default)] option: bool, #[darling(default)] - sync_writes: bool, - #[darling(default)] - sync_writes_by_key: bool, + sync_writes: SyncWriteMode, #[darling(default)] with_cached_flag: bool, #[darling(default)] @@ -192,16 +199,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { _ => panic!("the result and option attributes are mutually exclusive"), }; - if args.result_fallback && args.sync_writes { - panic!("the result_fallback and sync_writes attributes are mutually exclusive"); - } - - if args.result_fallback && args.sync_writes_by_key { - panic!("the result_fallback and sync_writes_by_key attributes are mutually exclusive"); - } - - if args.sync_writes && args.sync_writes_by_key { - panic!("the sync_writes and sync_writes_by_key attributes are mutually exclusive"); + if args.result_fallback && args.sync_writes != SyncWriteMode::Disabled { + panic!("result_fallback and sync_writes are mutually exclusive"); } let set_cache_and_return = quote! { @@ -216,8 +215,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let function_call; let ty; if asyncness.is_some() { - lock = if args.sync_writes_by_key { - quote! { + lock = match args.sync_writes { + SyncWriteMode::ByKey => quote! { let mut locks = #cache_ident.lock().await; let lock = locks .entry(key.clone()) @@ -225,11 +224,10 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { .clone(); drop(locks); let mut cache = lock.lock().await; - } - } else { - quote! { + }, + _ => quote! { let mut cache = #cache_ident.lock().await; - } + }, }; function_no_cache = quote! { @@ -240,27 +238,25 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*).await; }; - ty = if args.sync_writes_by_key { - quote! { + ty = match args.sync_writes { + SyncWriteMode::ByKey => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new())); - } - } else { - quote! { + }, + _ => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create)); - } + }, }; } else { - lock = if args.sync_writes_by_key { - quote! { + lock = match args.sync_writes { + SyncWriteMode::ByKey => quote! { let mut locks = #cache_ident.lock().unwrap(); let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone(); drop(locks); let mut cache = lock.lock().unwrap(); - } - } else { - quote! { + }, + _ => quote! { let mut cache = #cache_ident.lock().unwrap(); - } + }, }; function_no_cache = quote! { @@ -271,14 +267,13 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*); }; - ty = if args.sync_writes_by_key { - quote! { + ty = match args.sync_writes { + SyncWriteMode::ByKey => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new())); - } - } else { - quote! { + }, + _ => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create)); - } + }, } } @@ -290,7 +285,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { #set_cache_and_return }; - let do_set_return_block = if args.sync_writes_by_key || args.sync_writes { + let do_set_return_block = if args.sync_writes != SyncWriteMode::Disabled { quote! { #lock if let Some(result) = cache.cache_get(&key) { diff --git a/cached_proc_macro/src/once.rs b/cached_proc_macro/src/once.rs index 70d5617..fd2ac07 100644 --- a/cached_proc_macro/src/once.rs +++ b/cached_proc_macro/src/once.rs @@ -6,6 +6,13 @@ use quote::quote; use syn::spanned::Spanned; use syn::{parse_macro_input, Ident, ItemFn, ReturnType}; +#[derive(Debug, Default, FromMeta)] +enum SyncWriteMode { + #[default] + Disabled, + Default, +} + #[derive(FromMeta)] struct OnceMacroArgs { #[darling(default)] @@ -13,7 +20,7 @@ struct OnceMacroArgs { #[darling(default)] time: Option, #[darling(default)] - sync_writes: bool, + sync_writes: SyncWriteMode, #[darling(default)] result: bool, #[darling(default)] @@ -220,8 +227,8 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } }; - let do_set_return_block = if args.sync_writes { - quote! { + let do_set_return_block = match args.sync_writes { + SyncWriteMode::Default => quote! { #r_lock_return_cache_block #w_lock if let Some(result) = &*cached { @@ -229,14 +236,13 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } #function_call #set_cache_and_return - } - } else { - quote! { + }, + SyncWriteMode::Disabled => quote! { #r_lock_return_cache_block #function_call #w_lock #set_cache_and_return - } + }, }; let signature_no_muts = get_mut_signature(signature);