diff --git a/src/libstd/sync/mutex.rs b/src/libstd/sync/mutex.rs index 87c2318a9377c..81e9c5deafd00 100644 --- a/src/libstd/sync/mutex.rs +++ b/src/libstd/sync/mutex.rs @@ -4,7 +4,7 @@ use crate::mem; use crate::ops::{Deref, DerefMut}; use crate::ptr; use crate::sys_common::mutex as sys; -use crate::sys_common::poison::{self, TryLockError, TryLockResult, LockResult}; +use crate::sys_common::poison::{self, LockResult, TryLockError, TryLockResult}; /// A mutual exclusion primitive useful for protecting shared data /// @@ -358,6 +358,35 @@ impl Mutex { let data = unsafe { &mut *self.data.get() }; poison::map_result(self.poison.borrow(), |_| data ) } + + /// Invoke a function under a mutex. + /// + /// This takes the mutex and once held passes a mutable reference to the + /// contained value to a `FnOnce` closure. Once the closure returns, the + /// mutex is released. + /// + /// # Panics + /// + /// If another user of this mutex panicked while holding the mutex, then + /// this call will panic. + /// + /// # Example + /// ``` + /// # #![feature(mutex_with)] + /// use std::sync::Mutex; + /// + /// let mutex = Mutex::new(0); + /// + /// // Atomically fetch the old value and increment it. + /// let old = mutex.with(|v| { let old = *v; *v += 1; old }); + /// + /// assert_eq!(old, 0); + /// assert_eq!(*mutex.lock().unwrap(), 1); + /// ``` + #[unstable(feature = "mutex_with", issue = "61974")] + pub fn with U>(&self, func: F) -> U { + self.lock().map(|mut v| func(&mut *v)).expect("Lock poisoned") + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -703,4 +732,14 @@ mod tests { let comp: &[i32] = &[4, 2, 5]; assert_eq!(&*mutex.lock().unwrap(), comp); } + + #[test] + fn test_mutex_with() { + let mx = Mutex::new(123); + + let () = mx.with(|v| *v += 1); + + let lk = mx.lock().unwrap(); + assert_eq!(*lk, 124); + } } diff --git a/src/libstd/sync/rwlock.rs b/src/libstd/sync/rwlock.rs index b1b56f321fc6b..5cb42ce206990 100644 --- a/src/libstd/sync/rwlock.rs +++ b/src/libstd/sync/rwlock.rs @@ -409,6 +409,60 @@ impl RwLock { let data = unsafe { &mut *self.data.get() }; poison::map_result(self.poison.borrow(), |_| data) } + + /// Call function while holding write lock + /// + /// This attempts to take the write lock, and if successful passes a mutable + /// reference to the value to the callback. + /// + /// # Panics + /// + /// This function will panic if the RwLock is poisoned. An RwLock + /// is poisoned whenever a writer panics while holding an exclusive lock. + /// + /// # Example + /// + /// ``` + /// # #![feature(mutex_with)] + /// use std::sync::RwLock; + /// + /// let rw = RwLock::new(String::new()); + /// + /// let prev = rw.with_write(|mut s| { + /// let prev = s.clone(); + /// *s += "foo"; + /// prev + /// }); + /// ``` + #[unstable(feature = "mutex_with", issue = "61974")] + pub fn with_write U>(&self, func: F) -> U { + self.write().map(|mut v| func(&mut *v)).expect("RwLock poisoned") + } + + /// Call function while holding read lock + /// + /// This attempts to take the read lock, and if successful passes a + /// reference to the value to the callback. + /// + /// # Panics + /// + /// This function will panic if the RwLock is poisoned. An RwLock + /// is poisoned whenever a writer panics while holding an exclusive lock. + /// + /// # Example + /// + /// ``` + /// # #![feature(mutex_with)] + /// use std::sync::RwLock; + /// + /// let rw = RwLock::new("hello world".to_string()); + /// + /// let val = rw.with_read(|s| s.clone()); + /// ``` + #[unstable(feature = "mutex_with", issue = "61974")] + pub fn with_read U>(&self, func: F) -> U { + self.read().map(|v| func(&*v)).expect("RwLock poisoned") + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -797,4 +851,24 @@ mod tests { Ok(x) => panic!("get_mut of poisoned RwLock is Ok: {:?}", x), } } + + #[test] + fn test_with_read() { + let m = RwLock::new(10); + + let v = m.with_read(|v| *v); + + assert_eq!(v, 10); + } + + #[test] + fn test_with_write() { + let m = RwLock::new(10); + + let old = m.with_write(|v| {let old = *v; *v += 1; old}); + let now = m.with_read(|v| *v); + + assert_eq!(old, 10); + assert_eq!(now, 11); + } }