Fix #148 (UB in try_call_once) (#149)
* Fix UB in `try_call_once` and add regression test.

* Fix MSRV

* Clean up `try_call_once` impl

* Remove unused import
UnknownEclipse authored Apr 3, 2023
1 parent 907a550 commit 2a018b6
Showing 1 changed file with 123 additions and 97 deletions.
220 changes: 123 additions & 97 deletions src/once.rs
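For orientation before the diff, here is a minimal, hypothetical usage sketch of the error path this commit fixes and that the new regression test exercises. It is not part of the commit; it assumes the `spin` crate with default features and uses only the `Once::new`, `try_call_once`, and `get` APIs that appear in this file. The `main` function, values, and error strings are invented for illustration.

```rust
use spin::Once;

fn main() {
    let once: Once<u32> = Once::new();

    // A failed initialization attempt propagates the error and leaves the
    // Once uninitialized, so initialization can be retried later.
    let attempt: Result<&u32, &str> = once.try_call_once(|| Err("not ready yet"));
    assert!(attempt.is_err());
    assert!(once.get().is_none());

    // A later attempt may still succeed; the stored value is then returned
    // by reference on every subsequent call.
    let value = once.try_call_once(|| Ok::<u32, &str>(42)).unwrap();
    assert_eq!(*value, 42);
    assert_eq!(once.get(), Some(&42));
}
```

The point of the commit is that a transition back from "running" to "uninitialized" (the `Err` case above) is legal for `try_call_once`, so the old implementation's `unreachable_unchecked` on that path was reachable and therefore undefined behavior.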
@@ -130,8 +130,6 @@ mod status {
 }
 use self::status::{AtomicStatus, Status};
 
-use core::hint::unreachable_unchecked as unreachable;
-
 impl<T, R: RelaxStrategy> Once<T, R> {
     /// Performs an initialization routine once and only once. The given closure
     /// will be executed if this is the first time `call_once` has been called,
@@ -208,111 +206,92 @@ impl<T, R: RelaxStrategy> Once<T, R> {
     /// }
     /// ```
     pub fn try_call_once<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
-        // SAFETY: We perform an Acquire load because if this were to return COMPLETE, then we need
-        // the preceding stores done while initializing, to become visible after this load.
-        let mut status = self.status.load(Ordering::Acquire);
+        if let Some(value) = self.get() {
+            Ok(value)
+        } else {
+            self.try_call_once_slow(f)
+        }
+    }
 
-        if status == Status::Incomplete {
-            match self.status.compare_exchange(
+    #[cold]
+    fn try_call_once_slow<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
+        loop {
+            let xchg = self.status.compare_exchange(
                 Status::Incomplete,
                 Status::Running,
-                // SAFETY: Success ordering: We do not have to synchronize any data at all, as the
-                // value is at this point uninitialized, so Relaxed is technically sufficient. We
-                // will however have to do a Release store later. However, the success ordering
-                // must always be at least as strong as the failure ordering, so we choose Acquire
-                // here anyway.
                 Ordering::Acquire,
-                // SAFETY: Failure ordering: While we have already loaded the status initially, we
-                // know that if some other thread would have fully initialized this in between,
-                // then there will be new not-yet-synchronized accesses done during that
-                // initialization that would not have been synchronized by the earlier load. Thus
-                // we use Acquire to ensure when we later call force_get() in the last match
-                // statement, if the status was changed to COMPLETE, that those accesses will become
-                // visible to us.
                 Ordering::Acquire,
-            ) {
-                Ok(_must_be_state_incomplete) => {
-                    // The compare-exchange succeeded, so we shall initialize it.
-
-                    // We use a guard (Finish) to catch panics caused by builder
-                    let finish = Finish {
-                        status: &self.status,
-                    };
-                    let val = match f() {
-                        Ok(val) => val,
-                        Err(err) => {
-                            // If an error occurs, clean up everything and leave.
-                            core::mem::forget(finish);
-                            self.status.store(Status::Incomplete, Ordering::Release);
-                            return Err(err);
-                        }
-                    };
-                    unsafe {
-                        // SAFETY:
-                        // `UnsafeCell`/deref: currently the only accessor, mutably
-                        // and immutably by cas exclusion.
-                        // `write`: pointer comes from `MaybeUninit`.
-                        (*self.data.get()).as_mut_ptr().write(val);
-                    };
-                    // If there were to be a panic with unwind enabled, the code would
-                    // short-circuit and never reach the point where it writes the inner data.
-                    // The destructor for Finish will run, and poison the Once to ensure that other
-                    // threads accessing it do not exhibit unwanted behavior, if there were to be
-                    // any inconsistency in data structures caused by the panicking thread.
-                    //
-                    // However, f() is expected in the general case not to panic. In that case, we
-                    // simply forget the guard, bypassing its destructor. We could theoretically
-                    // clear a flag instead, but this eliminates the call to the destructor at
-                    // compile time, and unconditionally poisons during an eventual panic, if
-                    // unwinding is enabled.
-                    core::mem::forget(finish);
-
-                    // SAFETY: Release is required here, so that all memory accesses done in the
-                    // closure when initializing, become visible to other threads that perform Acquire
-                    // loads.
-                    //
-                    // And, we also know that the changes this thread has done will not magically
-                    // disappear from our cache, so it does not need to be AcqRel.
-                    self.status.store(Status::Complete, Ordering::Release);
-
-                    // This next line is mainly an optimization.
-                    return unsafe { Ok(self.force_get()) };
+            );
+
+            match xchg {
+                Ok(_must_be_state_incomplete) => {
+                    // Impl is defined after the match for readability
                 }
-                // The compare-exchange failed, so we know for a fact that the status cannot be
-                // INCOMPLETE, or it would have succeeded.
-                Err(other_status) => status = other_status,
+                Err(Status::Panicked) => panic!("Once panicked"),
+                Err(Status::Running) => match self.poll() {
+                    Some(v) => return Ok(v),
+                    None => continue,
+                },
+                Err(Status::Complete) => {
+                    return Ok(unsafe {
+                        // SAFETY: The status is Complete
+                        self.force_get()
+                    });
+                }
+                Err(Status::Incomplete) => {
+                    // The compare_exchange failed, so this shouldn't ever be reached,
+                    // however if we decide to switch to compare_exchange_weak it will
+                    // be safer to leave this here than hit an unreachable
+                    continue;
+                }
             }
-        }
 
-        Ok(match status {
-            // SAFETY: We have either checked with an Acquire load, that the status is COMPLETE, or
-            // initialized it ourselves, in which case no additional synchronization is needed.
-            Status::Complete => unsafe { self.force_get() },
-            Status::Panicked => panic!("Once panicked"),
-            Status::Running => self.poll().unwrap_or_else(|| {
-                if cfg!(debug_assertions) {
-                    unreachable!("Encountered INCOMPLETE when polling Once")
-                } else {
-                    // SAFETY: This poll is guaranteed never to fail because the API of poll
-                    // promises spinning if initialization is in progress. We've already
-                    // checked that initialisation is in progress, and initialisation is
-                    // monotonic: once done, it cannot be undone. We also fetched the status
-                    // with Acquire semantics, thereby guaranteeing that the later-executed
-                    // poll will also agree with us that initialization is in progress. Ergo,
-                    // this poll cannot fail.
-                    unsafe {
-                        unreachable();
-                    }
-                }
-            }),
-            // SAFETY: The only invariant possible in addition to the aforementioned ones at the
-            // moment, is INCOMPLETE. However, the only way for this match statement to be
-            // reached, is if we lost the CAS (otherwise we would have returned early), in
-            // which case we know for a fact that the state cannot be changed back to INCOMPLETE as
-            // `Once`s are monotonic.
-            Status::Incomplete => unsafe { unreachable() },
-        })
+            // The compare-exchange succeeded, so we shall initialize it.
+
+            // We use a guard (Finish) to catch panics caused by builder
+            let finish = Finish {
+                status: &self.status,
+            };
+            let val = match f() {
+                Ok(val) => val,
+                Err(err) => {
+                    // If an error occurs, clean up everything and leave.
+                    core::mem::forget(finish);
+                    self.status.store(Status::Incomplete, Ordering::Release);
+                    return Err(err);
+                }
+            };
+            unsafe {
+                // SAFETY:
+                // `UnsafeCell`/deref: currently the only accessor, mutably
+                // and immutably by cas exclusion.
+                // `write`: pointer comes from `MaybeUninit`.
+                (*self.data.get()).as_mut_ptr().write(val);
+            };
+            // If there were to be a panic with unwind enabled, the code would
+            // short-circuit and never reach the point where it writes the inner data.
+            // The destructor for Finish will run, and poison the Once to ensure that other
+            // threads accessing it do not exhibit unwanted behavior, if there were to be
+            // any inconsistency in data structures caused by the panicking thread.
+            //
+            // However, f() is expected in the general case not to panic. In that case, we
+            // simply forget the guard, bypassing its destructor. We could theoretically
+            // clear a flag instead, but this eliminates the call to the destructor at
+            // compile time, and unconditionally poisons during an eventual panic, if
+            // unwinding is enabled.
+            core::mem::forget(finish);
+
+            // SAFETY: Release is required here, so that all memory accesses done in the
+            // closure when initializing, become visible to other threads that perform Acquire
+            // loads.
+            //
+            // And, we also know that the changes this thread has done will not magically
+            // disappear from our cache, so it does not need to be AcqRel.
+            self.status.store(Status::Complete, Ordering::Release);
+
+            // This next line is mainly an optimization.
+            return unsafe { Ok(self.force_get()) };
+        }
     }
 
     /// Spins until the [`Once`] contains a value.
@@ -547,7 +526,9 @@ impl<'a> Drop for Finish<'a> {
 mod tests {
     use std::prelude::v1::*;
 
+    use std::sync::atomic::AtomicU32;
     use std::sync::mpsc::channel;
+    use std::sync::Arc;
     use std::thread;
 
     use super::*;
@@ -706,6 +687,51 @@ mod tests {
         }
     }
 
+    #[test]
+    fn try_call_once_err() {
+        let once = Once::<_, Spin>::new();
+        let shared = Arc::new((once, AtomicU32::new(0)));
+
+        let (tx, rx) = channel();
+
+        let t0 = {
+            let shared = shared.clone();
+            thread::spawn(move || {
+                let (once, called) = &*shared;
+
+                once.try_call_once(|| {
+                    called.fetch_add(1, Ordering::AcqRel);
+                    tx.send(()).unwrap();
+                    thread::sleep(std::time::Duration::from_millis(50));
+                    Err(())
+                })
+                .ok();
+            })
+        };
+
+        let t1 = {
+            let shared = shared.clone();
+            thread::spawn(move || {
+                rx.recv().unwrap();
+                let (once, called) = &*shared;
+                assert_eq!(
+                    called.load(Ordering::Acquire),
+                    1,
+                    "leader thread did not run first"
+                );
+
+                once.call_once(|| {
+                    called.fetch_add(1, Ordering::AcqRel);
+                });
+            })
+        };
+
+        t0.join().unwrap();
+        t1.join().unwrap();
+
+        assert_eq!(shared.1.load(Ordering::Acquire), 2);
+    }
+
     // This is sort of two test cases, but if we write them as separate test methods
     // they can be executed concurrently and then fail some small fraction of the
     // time.