Skip to content

Commit

Permalink
Merge pull request #195 from moka-rs/prevent-eval-init-more-than-once
Browse files Browse the repository at this point in the history
Prevent race condition in `get_with` method to avoid evaluating `init` closure/future multiple times
  • Loading branch information
tatsuya6502 authored Nov 5, 2022
2 parents 0976f84 + c233aff commit 52febeb
Show file tree
Hide file tree
Showing 10 changed files with 1,033 additions and 277 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ anyhow = "1.0.19"
async-std = { version = "1.11", features = ["attributes"] }
env_logger = "0.9"
getrandom = "0.2"
paste = "1.0.9"
reqwest = "0.11.11"
skeptic = "0.13"
tokio = { version = "1.19", features = ["fs", "macros", "rt-multi-thread", "sync", "time" ] }
Expand Down
87 changes: 48 additions & 39 deletions src/future/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use crate::common::concurrent::debug_counters::CacheDebugStats;

use crossbeam_channel::{Sender, TrySendError};
use std::{
any::TypeId,
borrow::Borrow,
collections::hash_map::RandomState,
fmt,
Expand All @@ -31,8 +30,6 @@ use std::{
time::Duration,
};

use super::OptionallyNone;

/// A thread-safe, futures-aware concurrent in-memory cache.
///
/// `Cache` supports full concurrency of retrievals and a high expected concurrency
Expand Down Expand Up @@ -935,7 +932,7 @@ where
/// // This async function tries to get HTML from the given URI.
/// async fn get_html(task_id: u8, uri: &str) -> Result<String, reqwest::Error> {
/// println!("get_html() called by task {}.", task_id);
/// Ok(reqwest::get(uri).await?.text().await?)
/// reqwest::get(uri).await?.text().await
/// }
///
/// #[tokio::main]
Expand Down Expand Up @@ -1376,17 +1373,15 @@ where
init: impl Future<Output = V>,
mut replace_if: Option<impl FnMut(&V) -> bool>,
) -> V {
match (self.base.get_with_hash(&key, hash), &mut replace_if) {
(Some(v), None) => return v,
(Some(v), Some(cond)) => {
if !cond(&v) {
return v;
};
}
_ => (),
let maybe_v = self
.base
.get_with_hash_but_ignore_if(&key, hash, replace_if.as_mut());
if let Some(v) = maybe_v {
v
} else {
self.insert_with_hash_and_fun(key, hash, init, replace_if)
.await
}

self.insert_with_hash_and_fun(key, hash, init).await
}

async fn get_or_insert_with_hash_by_ref_and_fun<Q>(
Expand All @@ -1400,35 +1395,39 @@ where
K: Borrow<Q>,
Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
{
match (self.base.get_with_hash(key, hash), &mut replace_if) {
(Some(v), None) => return v,
(Some(v), Some(cond)) => {
if !cond(&v) {
return v;
};
}
_ => (),
let maybe_v = self
.base
.get_with_hash_but_ignore_if(key, hash, replace_if.as_mut());
if let Some(v) = maybe_v {
v
} else {
let key = Arc::new(key.to_owned());
self.insert_with_hash_and_fun(key, hash, init, replace_if)
.await
}
let key = Arc::new(key.to_owned());
self.insert_with_hash_and_fun(key, hash, init).await
}

async fn insert_with_hash_and_fun(
&self,
key: Arc<K>,
hash: u64,
init: impl Future<Output = V>,
mut replace_if: Option<impl FnMut(&V) -> bool>,
) -> V {
use futures_util::FutureExt;

let get = || {
self.base
.get_with_hash_but_no_recording(&key, hash, replace_if.as_mut())
};
let insert = |v| self.insert_with_hash(key.clone(), hash, v).boxed();

match self
.value_initializer
.init_or_read(Arc::clone(&key), init)
.init_or_read(Arc::clone(&key), get, init, insert)
.await
{
InitResult::Initialized(v) => {
self.insert_with_hash(Arc::clone(&key), hash, v.clone())
.await;
self.value_initializer
.remove_waiter(&key, TypeId::of::<()>());
crossbeam_epoch::pin().flush();
v
}
Expand Down Expand Up @@ -1483,16 +1482,21 @@ where
F: Future<Output = Result<V, E>>,
E: Send + Sync + 'static,
{
use futures_util::FutureExt;

let get = || {
let ignore_if = None as Option<&mut fn(&V) -> bool>;
self.base
.get_with_hash_but_no_recording(&key, hash, ignore_if)
};
let insert = |v| self.insert_with_hash(key.clone(), hash, v).boxed();

match self
.value_initializer
.try_init_or_read(Arc::clone(&key), init)
.try_init_or_read(Arc::clone(&key), get, init, insert)
.await
{
InitResult::Initialized(v) => {
self.insert_with_hash(Arc::clone(&key), hash, v.clone())
.await;
self.value_initializer
.remove_waiter(&key, TypeId::of::<E>());
crossbeam_epoch::pin().flush();
Ok(v)
}
Expand Down Expand Up @@ -1553,16 +1557,21 @@ where
where
F: Future<Output = Option<V>>,
{
use futures_util::FutureExt;

let get = || {
let ignore_if = None as Option<&mut fn(&V) -> bool>;
self.base
.get_with_hash_but_no_recording(&key, hash, ignore_if)
};
let insert = |v| self.insert_with_hash(key.clone(), hash, v).boxed();

match self
.value_initializer
.optionally_init_or_read(Arc::clone(&key), init)
.optionally_init_or_read(Arc::clone(&key), get, init, insert)
.await
{
InitResult::Initialized(v) => {
self.insert_with_hash(Arc::clone(&key), hash, v.clone())
.await;
self.value_initializer
.remove_waiter(&key, TypeId::of::<OptionallyNone>());
crossbeam_epoch::pin().flush();
Some(v)
}
Expand Down
Loading

0 comments on commit 52febeb

Please sign in to comment.