Fix #148 (UB in try_call_once) #149

Merged · 4 commits · Apr 3, 2023
src/once.rs — 220 changes: 123 additions & 97 deletions
@@ -130,8 +130,6 @@ mod status {
 }
 use self::status::{AtomicStatus, Status};
 
-use core::hint::unreachable_unchecked as unreachable;
-
 impl<T, R: RelaxStrategy> Once<T, R> {
     /// Performs an initialization routine once and only once. The given closure
     /// will be executed if this is the first time `call_once` has been called,
@@ -208,111 +206,92 @@ impl<T, R: RelaxStrategy> Once<T, R> {
     /// }
     /// ```
     pub fn try_call_once<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
-        // SAFETY: We perform an Acquire load because if this were to return COMPLETE, then we need
-        // the preceding stores done while initializing, to become visible after this load.
-        let mut status = self.status.load(Ordering::Acquire);
+        if let Some(value) = self.get() {
+            Ok(value)
+        } else {
+            self.try_call_once_slow(f)
+        }
+    }
 
-        if status == Status::Incomplete {
-            match self.status.compare_exchange(
+    #[cold]
+    fn try_call_once_slow<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
+        loop {
+            let xchg = self.status.compare_exchange(
                 Status::Incomplete,
                 Status::Running,
                 // SAFETY: Success ordering: We do not have to synchronize any data at all, as the
                 // value is at this point uninitialized, so Relaxed is technically sufficient. We
                 // will however have to do a Release store later. However, the success ordering
                 // must always be at least as strong as the failure ordering, so we choose Acquire
                 // here anyway.
                 Ordering::Acquire,
                 // SAFETY: Failure ordering: While we have already loaded the status initially, we
                 // know that if some other thread would have fully initialized this in between,
                 // then there will be new not-yet-synchronized accesses done during that
                 // initialization that would not have been synchronized by the earlier load. Thus
                 // we use Acquire to ensure when we later call force_get() in the last match
                 // statement, if the status was changed to COMPLETE, that those accesses will become
                 // visible to us.
                 Ordering::Acquire,
-            ) {
-                Ok(_must_be_state_incomplete) => {
-                    // The compare-exchange succeeded, so we shall initialize it.
-
-                    // We use a guard (Finish) to catch panics caused by builder
-                    let finish = Finish {
-                        status: &self.status,
-                    };
-                    let val = match f() {
-                        Ok(val) => val,
-                        Err(err) => {
-                            // If an error occurs, clean up everything and leave.
-                            core::mem::forget(finish);
-                            self.status.store(Status::Incomplete, Ordering::Release);
-                            return Err(err);
-                        }
-                    };
-                    unsafe {
-                        // SAFETY:
-                        // `UnsafeCell`/deref: currently the only accessor, mutably
-                        // and immutably by cas exclusion.
-                        // `write`: pointer comes from `MaybeUninit`.
-                        (*self.data.get()).as_mut_ptr().write(val);
-                    };
-                    // If there were to be a panic with unwind enabled, the code would
-                    // short-circuit and never reach the point where it writes the inner data.
-                    // The destructor for Finish will run, and poison the Once to ensure that other
-                    // threads accessing it do not exhibit unwanted behavior, if there were to be
-                    // any inconsistency in data structures caused by the panicking thread.
-                    //
-                    // However, f() is expected in the general case not to panic. In that case, we
-                    // simply forget the guard, bypassing its destructor. We could theoretically
-                    // clear a flag instead, but this eliminates the call to the destructor at
-                    // compile time, and unconditionally poisons during an eventual panic, if
-                    // unwinding is enabled.
-                    core::mem::forget(finish);
-
-                    // SAFETY: Release is required here, so that all memory accesses done in the
-                    // closure when initializing, become visible to other threads that perform Acquire
-                    // loads.
-                    //
-                    // And, we also know that the changes this thread has done will not magically
-                    // disappear from our cache, so it does not need to be AcqRel.
-                    self.status.store(Status::Complete, Ordering::Release);
-
-                    // This next line is mainly an optimization.
-                    return unsafe { Ok(self.force_get()) };
-                }
-                // The compare-exchange failed, so we know for a fact that the status cannot be
-                // INCOMPLETE, or it would have succeeded.
-                Err(other_status) => status = other_status,
-            }
-        }
-
-        Ok(match status {
-            // SAFETY: We have either checked with an Acquire load, that the status is COMPLETE, or
-            // initialized it ourselves, in which case no additional synchronization is needed.
-            Status::Complete => unsafe { self.force_get() },
-            Status::Panicked => panic!("Once panicked"),
-            Status::Running => self.poll().unwrap_or_else(|| {
-                if cfg!(debug_assertions) {
-                    unreachable!("Encountered INCOMPLETE when polling Once")
-                } else {
-                    // SAFETY: This poll is guaranteed never to fail because the API of poll
-                    // promises spinning if initialization is in progress. We've already
-                    // checked that initialisation is in progress, and initialisation is
-                    // monotonic: once done, it cannot be undone. We also fetched the status
-                    // with Acquire semantics, thereby guaranteeing that the later-executed
-                    // poll will also agree with us that initialization is in progress. Ergo,
-                    // this poll cannot fail.
-                    unsafe {
-                        unreachable();
-                    }
-                }
-            }),
-            // SAFETY: The only invariant possible in addition to the aforementioned ones at the
-            // moment, is INCOMPLETE. However, the only way for this match statement to be
-            // reached, is if we lost the CAS (otherwise we would have returned early), in
-            // which case we know for a fact that the state cannot be changed back to INCOMPLETE as
-            // `Once`s are monotonic.
-            Status::Incomplete => unsafe { unreachable() },
-        })
+            );
+
+            match xchg {
+                Ok(_must_be_state_incomplete) => {
+                    // Impl is defined after the match for readability
+                }
+                Err(Status::Panicked) => panic!("Once panicked"),
+                Err(Status::Running) => match self.poll() {
+                    Some(v) => return Ok(v),
+                    None => continue,
+                },
+                Err(Status::Complete) => {
+                    return Ok(unsafe {
+                        // SAFETY: The status is Complete
+                        self.force_get()
+                    });
+                }
+                Err(Status::Incomplete) => {
+                    // The compare_exchange failed, so this shouldn't ever be reached,
+                    // however if we decide to switch to compare_exchange_weak it will
+                    // be safer to leave this here than hit an unreachable
+                    continue;
+                }
+            }
+
+            // The compare-exchange succeeded, so we shall initialize it.
+
+            // We use a guard (Finish) to catch panics caused by builder
+            let finish = Finish {
+                status: &self.status,
+            };
+            let val = match f() {
+                Ok(val) => val,
+                Err(err) => {
+                    // If an error occurs, clean up everything and leave.
+                    core::mem::forget(finish);
+                    self.status.store(Status::Incomplete, Ordering::Release);
+                    return Err(err);
+                }
+            };
+            unsafe {
+                // SAFETY:
+                // `UnsafeCell`/deref: currently the only accessor, mutably
+                // and immutably by cas exclusion.
+                // `write`: pointer comes from `MaybeUninit`.
+                (*self.data.get()).as_mut_ptr().write(val);
+            };
+            // If there were to be a panic with unwind enabled, the code would
+            // short-circuit and never reach the point where it writes the inner data.
+            // The destructor for Finish will run, and poison the Once to ensure that other
+            // threads accessing it do not exhibit unwanted behavior, if there were to be
+            // any inconsistency in data structures caused by the panicking thread.
+            //
+            // However, f() is expected in the general case not to panic. In that case, we
+            // simply forget the guard, bypassing its destructor. We could theoretically
+            // clear a flag instead, but this eliminates the call to the destructor at
+            // compile time, and unconditionally poisons during an eventual panic, if
+            // unwinding is enabled.
+            core::mem::forget(finish);
+
+            // SAFETY: Release is required here, so that all memory accesses done in the
+            // closure when initializing, become visible to other threads that perform Acquire
+            // loads.
+            //
+            // And, we also know that the changes this thread has done will not magically
+            // disappear from our cache, so it does not need to be AcqRel.
+            self.status.store(Status::Complete, Ordering::Release);
+
+            // This next line is mainly an optimization.
+            return unsafe { Ok(self.force_get()) };
+        }
     }

/// Spins until the [`Once`] contains a value.
@@ -547,7 +526,9 @@ impl<'a> Drop for Finish<'a> {
 mod tests {
     use std::prelude::v1::*;
 
+    use std::sync::atomic::AtomicU32;
     use std::sync::mpsc::channel;
+    use std::sync::Arc;
     use std::thread;
 
     use super::*;
@@ -706,6 +687,51 @@ mod tests {
         }
     }
 
+    #[test]
+    fn try_call_once_err() {
+        let once = Once::<_, Spin>::new();
+        let shared = Arc::new((once, AtomicU32::new(0)));
+
+        let (tx, rx) = channel();
+
+        let t0 = {
+            let shared = shared.clone();
+            thread::spawn(move || {
+                let (once, called) = &*shared;
+
+                once.try_call_once(|| {
+                    called.fetch_add(1, Ordering::AcqRel);
+                    tx.send(()).unwrap();
+                    thread::sleep(std::time::Duration::from_millis(50));
+                    Err(())
+                })
+                .ok();
+            })
+        };
+
+        let t1 = {
+            let shared = shared.clone();
+            thread::spawn(move || {
+                rx.recv().unwrap();
+                let (once, called) = &*shared;
+                assert_eq!(
+                    called.load(Ordering::Acquire),
+                    1,
+                    "leader thread did not run first"
+                );
+
+                once.call_once(|| {
+                    called.fetch_add(1, Ordering::AcqRel);
+                });
+            })
+        };
+
+        t0.join().unwrap();
+        t1.join().unwrap();
+
+        assert_eq!(shared.1.load(Ordering::Acquire), 2);
+    }
+
     // This is sort of two test cases, but if we write them as separate test methods
     // they can be executed concurrently and then fail some small fraction of the
     // time.