Skip to content

Commit

Permalink
Merge #718
Browse files Browse the repository at this point in the history
718: added get_or_insert_with function r=taiki-e a=snow01

get_or_insert_with function allows lazy creation of default value. Default values may be heavy objects, lazy creation doesn't create un-necessary objects when key is already there in the data structure.

Co-authored-by: Shailendra Sharma <shailendra.sharma@gmail.com>
  • Loading branch information
bors[bot] and snow01 authored Jul 22, 2021
2 parents 311124c + 403e899 commit 800ca61
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 4 deletions.
29 changes: 25 additions & 4 deletions crossbeam-skiplist/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,21 @@ where

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist.
pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, || value, false, guard)
}

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
/// where value is calculated with a function.
///
///
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
/// discarded. If closure is modifying some other state (such as shared counters or shared
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
/// result of closure inserted
pub fn get_or_insert_with<F>(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V>
where
F: FnOnce() -> V,
{
self.insert_internal(key, value, false, guard)
}

Expand Down Expand Up @@ -831,13 +846,16 @@ where
/// Inserts an entry with the specified `key` and `value`.
///
/// If `replace` is `true`, then any existing entry with this key will first be removed.
fn insert_internal(
fn insert_internal<F>(
&self,
key: K,
value: V,
value: F,
replace: bool,
guard: &Guard,
) -> RefEntry<'_, K, V> {
) -> RefEntry<'_, K, V>
where
F: FnOnce() -> V,
{
self.check_guard(guard);

unsafe {
Expand Down Expand Up @@ -876,6 +894,9 @@ where
}
}

// create value before creating node, so extra allocation doesn't happen if value() function panics
let value = value();

// Create a new node.
let height = self.random_height();
let (node, n) = {
Expand Down Expand Up @@ -1061,7 +1082,7 @@ where
/// If there is an existing entry with this key, it will be removed before inserting the new
/// one.
pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, value, true, guard)
self.insert_internal(key, || value, true, guard)
}

/// Removes an entry with the specified `key` from the map and returns it.
Expand Down
33 changes: 33 additions & 0 deletions crossbeam-skiplist/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,39 @@ where
Entry::new(self.inner.get_or_insert(key, value, guard))
}

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
/// where value is calculated with a function.
///
///
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
/// discarded. If closure is modifying some other state (such as shared counters or shared
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
/// result of closure inserted
////
/// This function returns an [`Entry`] which
/// can be used to access the key's associated value.
///
///
/// # Example
/// ```
/// use crossbeam_skiplist::SkipMap;
///
/// let ages = SkipMap::new();
/// let gates_age = ages.get_or_insert_with("Bill Gates", || 64);
/// assert_eq!(*gates_age.value(), 64);
///
/// ages.insert("Steve Jobs", 65);
/// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1);
/// assert_eq!(*jobs_age.value(), 65);
/// ```
pub fn get_or_insert_with<F>(&self, key: K, value_fn: F) -> Entry<'_, K, V>
where
F: FnOnce() -> V,
{
let guard = &epoch::pin();
Entry::new(self.inner.get_or_insert_with(key, value_fn, guard))
}

/// Returns an iterator over all entries in the map,
/// sorted by key.
///
Expand Down
70 changes: 70 additions & 0 deletions crossbeam-skiplist/tests/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,76 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600);
}

#[test]
fn get_or_insert_with() {
let guard = &epoch::pin();
let s = SkipList::new(epoch::default_collector().clone());
s.insert(3, 3, guard);
s.insert(5, 5, guard);
s.insert(1, 1, guard);
s.insert(4, 4, guard);
s.insert(2, 2, guard);

assert_eq!(*s.get(&4, guard).unwrap().value(), 4);
assert_eq!(*s.insert(4, 40, guard).value(), 40);
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);

assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40);
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600);
}

#[test]
fn get_or_insert_with_panic() {
use std::panic;

let s = SkipList::new(epoch::default_collector().clone());
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let guard = &epoch::pin();
s.get_or_insert_with(4, || panic!(), guard);
}));
assert!(res.is_err());
assert!(s.is_empty());
let guard = &epoch::pin();
assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40);
assert_eq!(s.len(), 1);
}

#[test]
fn get_or_insert_with_parallel_run() {
use std::sync::{Arc, Mutex};

let s = Arc::new(SkipList::new(epoch::default_collector().clone()));
let s2 = s.clone();
let called = Arc::new(Mutex::new(false));
let called2 = called.clone();
let handle = std::thread::spawn(move || {
let guard = &epoch::pin();
assert_eq!(
*s2.get_or_insert_with(
7,
|| {
*called2.lock().unwrap() = true;

// allow main thread to run before we return result
std::thread::sleep(std::time::Duration::from_secs(4));
70
},
guard,
)
.value(),
700
);
});
std::thread::sleep(std::time::Duration::from_secs(2));
let guard = &epoch::pin();

// main thread writes the value first
assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700);
handle.join().unwrap();
assert!(*called.lock().unwrap());
}

#[test]
fn get_next_prev() {
let guard = &epoch::pin();
Expand Down
61 changes: 61 additions & 0 deletions crossbeam-skiplist/tests/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,67 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600).value(), 600);
}

#[test]
fn get_or_insert_with() {
let s = SkipMap::new();
s.insert(3, 3);
s.insert(5, 5);
s.insert(1, 1);
s.insert(4, 4);
s.insert(2, 2);

assert_eq!(*s.get(&4).unwrap().value(), 4);
assert_eq!(*s.insert(4, 40).value(), 40);
assert_eq!(*s.get(&4).unwrap().value(), 40);

assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40);
assert_eq!(*s.get(&4).unwrap().value(), 40);
assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600);
}

#[test]
fn get_or_insert_with_panic() {
use std::panic;

let s = SkipMap::new();
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
s.get_or_insert_with(4, || panic!());
}));
assert!(res.is_err());
assert!(s.is_empty());
assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40);
assert_eq!(s.len(), 1);
}

#[test]
fn get_or_insert_with_parallel_run() {
use std::sync::{Arc, Mutex};

let s = Arc::new(SkipMap::new());
let s2 = s.clone();
let called = Arc::new(Mutex::new(false));
let called2 = called.clone();
let handle = std::thread::spawn(move || {
assert_eq!(
*s2.get_or_insert_with(7, || {
*called2.lock().unwrap() = true;

// allow main thread to run before we return result
std::thread::sleep(std::time::Duration::from_secs(4));
70
})
.value(),
700
);
});
std::thread::sleep(std::time::Duration::from_secs(2));

// main thread writes the value first
assert_eq!(*s.get_or_insert(7, 700).value(), 700);
handle.join().unwrap();
assert!(*called.lock().unwrap());
}

#[test]
fn get_next_prev() {
let s = SkipMap::new();
Expand Down

0 comments on commit 800ca61

Please sign in to comment.