diff --git a/README.md b/README.md
index e2bb7ef..0ad4d18 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,8 @@ function-cache wrapped in a mutex/rwlock, or externally synchronized in the case
 By default, the function-cache is **not** locked for the duration of the function's execution, so initial (on an empty cache) concurrent calls of long-running functions with the same arguments will each execute fully and each overwrite the memoized value as they complete. This mirrors the behavior of Python's `functools.lru_cache`. To synchronize the execution and caching
-of un-cached arguments, specify `#[cached(sync_writes = true)]` / `#[once(sync_writes = true)]` (not supported by `#[io_cached]`.
+of un-cached arguments, specify `#[cached(sync_writes = true)]` / `#[once(sync_writes = true)]` (not supported by `#[io_cached]`). To synchronize
+by cache_key use `#[cached(sync_writes_by_key = true)]` (not supported by `#[once]` / `#[io_cached]`).

 - See [`cached::stores` docs](https://docs.rs/cached/latest/cached/stores/index.html) cache stores available.
 - See [`proc_macro`](https://docs.rs/cached/latest/cached/proc_macro/index.html) for more procedural macro examples.
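Not part of the patch: a minimal usage sketch of the new attribute, assuming the behavior described in the README text above. The `greet` function and its `key`/`convert` settings are made up for illustration.

```rust
use cached::proc_macro::cached;
use std::thread;
use std::time::Duration;

// With `sync_writes_by_key = true`, concurrent callers that map to the same
// cache key are serialized (the later one waits, then reads the cached value),
// while callers with different keys still run in parallel.
#[cached(sync_writes_by_key = true, key = "String", convert = r#"{ name.clone() }"#)]
fn greet(name: String) -> String {
    thread::sleep(Duration::from_secs(1)); // stand-in for expensive work
    format!("hello, {}", name)
}

fn main() {
    // Different keys: neither call blocks the other.
    let a = thread::spawn(|| greet("alice".to_string()));
    let b = thread::spawn(|| greet("bob".to_string()));
    println!("{} / {}", a.join().unwrap(), b.join().unwrap());
}
```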
diff --git a/cached_proc_macro/src/cached.rs b/cached_proc_macro/src/cached.rs
index fc63555..0dda2f8 100644
--- a/cached_proc_macro/src/cached.rs
+++ b/cached_proc_macro/src/cached.rs
@@ -29,6 +29,8 @@ struct MacroArgs {
     #[darling(default)]
     sync_writes: bool,
     #[darling(default)]
+    sync_writes_by_key: bool,
+    #[darling(default)]
     with_cached_flag: bool,
     #[darling(default)]
     ty: Option<String>,
@@ -92,6 +94,13 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
         Some(ref name) => Ident::new(name, fn_ident.span()),
         None => Ident::new(&fn_ident.to_string().to_uppercase(), fn_ident.span()),
     };
+    let cache_ident_key = match args.name {
+        Some(ref name) => Ident::new(&format!("{}_key", name), fn_ident.span()),
+        None => Ident::new(
+            &format!("{}_key", fn_ident.to_string().to_uppercase()),
+            fn_ident.span(),
+        ),
+    };

     let (cache_key_ty, key_convert_block) =
         make_cache_key_type(&args.key, &args.convert, &args.ty, input_tys, &input_names);
@@ -194,6 +203,14 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
         panic!("the result_fallback and sync_writes attributes are mutually exclusive");
     }

+    if args.result_fallback && args.sync_writes_by_key {
+        panic!("the result_fallback and sync_writes_by_key attributes are mutually exclusive");
+    }
+
+    if args.sync_writes && args.sync_writes_by_key {
+        panic!("the sync_writes and sync_writes_by_key attributes are mutually exclusive");
+    }
+
     let set_cache_and_return = quote! {
         #set_cache_block
         result
@@ -202,6 +219,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
     let no_cache_fn_ident = Ident::new(&format!("{}_no_cache", &fn_ident), fn_ident.span());

     let lock;
+    let lock_key;
     let function_no_cache;
     let function_call;
     let ty;
@@ -210,6 +228,16 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
             let mut cache = #cache_ident.lock().await;
         };

+        lock_key = quote! {
+            let mut locks = #cache_ident_key.lock().await;
+            let lock = locks
+                .entry(key.clone())
+                .or_insert_with(|| std::sync::Arc::new(::cached::async_sync::Mutex::new(#cache_create)))
+                .clone();
+            drop(locks);
+            let mut cache = lock.lock().await;
+        };
+
         function_no_cache = quote! {
             async fn #no_cache_fn_ident #generics (#inputs) #output #body
         };
@@ -220,12 +248,20 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
         ty = quote! {
             #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create));
+            #visibility static #cache_ident_key: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<::cached::async_sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new()));
         };
     } else {
         lock = quote! {
             let mut cache = #cache_ident.lock().unwrap();
         };

+        lock_key = quote! {
+            let mut locks = #cache_ident_key.lock().unwrap();
+            let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone();
+            drop(locks);
+            let mut cache = lock.lock().unwrap();
+        };
+
         function_no_cache = quote! {
             fn #no_cache_fn_ident #generics (#inputs) #output #body
         };
@@ -236,6 +272,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
         ty = quote! {
             #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create));
+            #visibility static #cache_ident_key: ::cached::once_cell::sync::Lazy<std::sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<std::sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new()));
         };
     }
@@ -247,7 +284,16 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
         #set_cache_and_return
     };

-    let do_set_return_block = if args.sync_writes {
+    let do_set_return_block = if args.sync_writes_by_key {
+        quote! {
+            #lock_key
+            if let Some(result) = cache.cache_get(&key) {
+                #return_cache_block
+            }
+            #function_call
+            #set_cache_and_return
+        }
+    } else if args.sync_writes {
         quote! {
             #lock
             if let Some(result) = cache.cache_get(&key) {
diff --git a/cached_proc_macro/src/lib.rs b/cached_proc_macro/src/lib.rs
index 83d6d39..5985017 100644
--- a/cached_proc_macro/src/lib.rs
+++ b/cached_proc_macro/src/lib.rs
@@ -14,6 +14,7 @@ use proc_macro::TokenStream;
 /// - `time`: (optional, u64) specify a cache TTL in seconds, implies the cache type is a `TimedCache` or `TimedSizedCache`.
 /// - `time_refresh`: (optional, bool) specify whether to refresh the TTL on cache hits.
 /// - `sync_writes`: (optional, bool) specify whether to synchronize the execution of writing of uncached values.
+/// - `sync_writes_by_key`: (optional, bool) specify whether to synchronize the execution of writing of uncached values by key.
 /// - `ty`: (optional, string type) The cache store type to use. Defaults to `UnboundCache`. When `unbound` is
 /// specified, defaults to `UnboundCache`. When `size` is specified, defaults to `SizedCache`.
 /// When `time` is specified, defaults to `TimedCached`.
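Not part of the patch: a standalone sketch of the double-locking shape that the generated `lock_key` block corresponds to, written as plain Rust. The `with_key_lock` helper is hypothetical and exists only to illustrate the pattern; the generated code inlines the same shape using `#cache_ident_key` as the map.

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};

// An outer mutex guards a map of per-key locks. It is held only long enough to
// fetch or create the entry for `key`; the (potentially long) critical section
// then runs under the per-key lock, so callers with different keys do not
// block each other.
fn with_key_lock<K, V, R>(
    locks: &Mutex<HashMap<K, Arc<Mutex<V>>>>,
    key: K,
    init: impl FnOnce() -> V,
    work: impl FnOnce(&mut V) -> R,
) -> R
where
    K: Eq + Hash,
{
    let per_key = locks
        .lock()
        .unwrap()
        .entry(key)
        .or_insert_with(|| Arc::new(Mutex::new(init())))
        .clone(); // the outer map guard (a temporary) is released at the end of this statement
    let mut guard = per_key.lock().unwrap();
    work(&mut *guard)
}
```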
diff --git a/tests/cached.rs b/tests/cached.rs
index 8a161bd..a2a26bc 100644
--- a/tests/cached.rs
+++ b/tests/cached.rs
@@ -10,7 +10,7 @@ use cached::{
 };
 use serial_test::serial;
 use std::thread::{self, sleep};
-use std::time::Duration;
+use std::time::{Duration, Instant};

 cached! {
     UNBOUND_FIB;
@@ -898,6 +898,49 @@ async fn test_cached_sync_writes_a() {
     assert_eq!(a, c.await.unwrap());
 }

+#[cached(time = 2, sync_writes_by_key = true, key = "u32", convert = "{ 1 }")]
+fn cached_sync_writes_by_key(s: String) -> Vec<String> {
+    sleep(Duration::new(1, 0));
+    vec![s]
+}
+
+#[test]
+fn test_cached_sync_writes_by_key() {
+    let a = std::thread::spawn(|| cached_sync_writes_by_key("a".to_string()));
+    let b = std::thread::spawn(|| cached_sync_writes_by_key("b".to_string()));
+    let c = std::thread::spawn(|| cached_sync_writes_by_key("c".to_string()));
+    let start = Instant::now();
+    let a = a.join().unwrap();
+    let b = b.join().unwrap();
+    let c = c.join().unwrap();
+    assert!(start.elapsed() < Duration::from_secs(2));
+}
+
+#[cfg(feature = "async")]
+#[cached(
+    time = 5,
+    sync_writes_by_key = true,
+    key = "String",
+    convert = r#"{ format!("{}", s) }"#
+)]
+async fn cached_sync_writes_by_key_a(s: String) -> Vec<String> {
+    tokio::time::sleep(Duration::from_secs(1)).await;
+    vec![s]
+}
+
+#[cfg(feature = "async")]
+#[tokio::test]
+async fn test_cached_sync_writes_by_key_a() {
+    let a = tokio::spawn(cached_sync_writes_by_key_a("a".to_string()));
+    let b = tokio::spawn(cached_sync_writes_by_key_a("b".to_string()));
+    let c = tokio::spawn(cached_sync_writes_by_key_a("c".to_string()));
+    let start = Instant::now();
+    a.await.unwrap();
+    b.await.unwrap();
+    c.await.unwrap();
+    assert!(start.elapsed() < Duration::from_secs(2));
+}
+
 #[cfg(feature = "async")]
 #[once(sync_writes = true)]
 async fn once_sync_writes_a(s: &tokio::sync::Mutex<String>) -> String {
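Not part of the patch: the tests above check that callers with *different* keys are not serialized. A hypothetical companion test, sketched under the per-key behavior implemented above, would check the complementary guarantee: concurrent calls sharing one key execute the function body only once. Names, key, and timings are illustrative.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::time::Duration;

use cached::proc_macro::cached;

// Counts real executions of the body; cache hits do not increment it.
static CALLS: AtomicUsize = AtomicUsize::new(0);

#[cached(sync_writes_by_key = true, key = "u32", convert = "{ n }")]
fn slow_identity(n: u32) -> u32 {
    CALLS.fetch_add(1, Ordering::SeqCst);
    thread::sleep(Duration::from_millis(500));
    n
}

#[test]
fn same_key_executes_only_once() {
    let a = thread::spawn(|| slow_identity(7));
    let b = thread::spawn(|| slow_identity(7));
    assert_eq!(a.join().unwrap(), 7);
    assert_eq!(b.join().unwrap(), 7);
    // The second caller waited on the per-key lock and then read the cached
    // value, so the body ran exactly once.
    assert_eq!(CALLS.load(Ordering::SeqCst), 1);
}
```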