Commit

runtime: mitigate ABA with 32-bit queue indices when possible (tokio-rs#5042)

When 64-bit atomics are supported, use 32-bit queue indices. This
greatly improves resilience to ABA and has no impact on performance on
64-bit platforms.

Fixes: tokio-rs#5041
sbarral authored and dbischof90 committed Oct 1, 2022
1 parent 8293ee1 commit 5bef21b
Showing 2 changed files with 62 additions and 31 deletions.
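Background for the change (not part of the commit itself): the buffer slot is always computed as `idx as usize & MASK` (see the diff below), so the head and tail indices are free to keep counting and wrap on overflow. A stalled stealer is only fooled when the packed head returns to the exact same bit pattern between its load and its compare-and-swap. The sketch below is an illustrative back-of-the-envelope comparison of the wrap-around periods before and after the patch; the constants are for illustration only.

```rust
// Illustrative only: why widening the queue indices mitigates ABA.
// The buffer slot is `idx as usize & MASK`, so the index itself grows and
// wraps freely; ABA needs the *index value* to come back to the same bit
// pattern while a stealer holds a stale copy of the packed head.
fn main() {
    let wrap_before: u64 = 1 << 16; // u16 indices: head repeats every 65,536 ops
    let wrap_after: u64 = 1 << 32; // u32 indices: head repeats every ~4.3 billion ops
    println!(
        "index wrap period: {} ops before the patch, {} ops after",
        wrap_before, wrap_after
    );
}
```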
10 changes: 10 additions & 0 deletions tokio/src/loom/std/atomic_u32.rs
@@ -15,6 +15,16 @@ impl AtomicU32 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU32::new(val));
AtomicU32 { inner }
}

/// Performs an unsynchronized load.
///
/// # Safety
///
/// All mutations must have happened before the unsynchronized load.
/// Additionally, there must be no concurrent mutations.
pub(crate) unsafe fn unsync_load(&self) -> u32 {
*(*self.inner.get()).get_mut()
}
}

impl Deref for AtomicU32 {
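The safety contract above boils down to: only call `unsync_load` from the thread that owns all writes to the cell. Below is a minimal, self-contained sketch of that contract using a simplified stand-in for the loom wrapper (illustrative only, not tokio's actual API surface); the real call site appears in `queue.rs` below, where the producer reads its own `tail`.

```rust
use std::cell::UnsafeCell;

// Simplified stand-in for the loom AtomicU32 wrapper (illustrative only).
struct AtomicU32 {
    inner: UnsafeCell<std::sync::atomic::AtomicU32>,
}

impl AtomicU32 {
    fn new(val: u32) -> Self {
        AtomicU32 {
            inner: UnsafeCell::new(std::sync::atomic::AtomicU32::new(val)),
        }
    }

    /// Safety: all mutations happened-before this call and none are concurrent.
    unsafe fn unsync_load(&self) -> u32 {
        *(*self.inner.get()).get_mut()
    }
}

fn main() {
    let tail = AtomicU32::new(7);
    // Sound: this thread performed every mutation and no other thread can
    // touch `tail` concurrently, so the unsynchronized read cannot race.
    assert_eq!(unsafe { tail.unsync_load() }, 7);
}
```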
83 changes: 52 additions & 31 deletions tokio/src/runtime/scheduler/multi_thread/queue.rs
@@ -1,15 +1,30 @@
//! Run-queue structures to support a work-stealing scheduler
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicU16, AtomicU32};
use crate::loom::sync::Arc;
use crate::runtime::task::{self, Inject};
use crate::runtime::MetricsBatch;

use std::mem::MaybeUninit;
use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
type UnsignedShort = u32;
type UnsignedLong = u64;
type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
type UnsignedShort = u16;
type UnsignedLong = u32;
type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
inner: Arc<Inner<T>>,
@@ -21,19 +36,21 @@ pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);
pub(crate) struct Inner<T: 'static> {
/// Concurrently updated by many threads.
///
/// Contains two `u16` values. The LSB byte is the "real" head of the queue.
/// The `u16` in the MSB is set by a stealer in process of stealing values.
/// It represents the first value being stolen in the batch. `u16` is used
/// in order to distinguish between `head == tail` and `head == tail -
/// capacity`.
/// Contains two `UnsignedShort` values. The LSB byte is the "real" head of
/// the queue. The `UnsignedShort` in the MSB is set by a stealer in process
/// of stealing values. It represents the first value being stolen in the
/// batch. The `UnsignedShort` indices are intentionally wider than strictly
/// required for buffer indexing in order to provide ABA mitigation and make
/// it possible to distinguish between full and empty buffers.
///
/// When both `u16` values are the same, there is no active stealer.
/// When both `UnsignedShort` values are the same, there is no active
/// stealer.
///
/// Tracking an in-progress stealer prevents a wrapping scenario.
head: AtomicU32,
head: AtomicUnsignedLong,

/// Only updated by producer thread but read by many threads.
tail: AtomicU16,
tail: AtomicUnsignedShort,

/// Elements
buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
@@ -73,8 +90,8 @@ pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
}

let inner = Arc::new(Inner {
head: AtomicU32::new(0),
tail: AtomicU16::new(0),
head: AtomicUnsignedLong::new(0),
tail: AtomicUnsignedShort::new(0),
buffer: make_fixed_size(buffer.into_boxed_slice()),
});

@@ -115,7 +132,7 @@ impl<T> Local<T> {
// safety: this is the **only** thread that updates this cell.
let tail = unsafe { self.inner.tail.unsync_load() };

if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 {
if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
// There is capacity for the task
break tail;
} else if steal != real {
@@ -165,16 +182,16 @@ impl<T> Local<T> {
fn push_overflow(
&mut self,
task: task::Notified<T>,
head: u16,
tail: u16,
head: UnsignedShort,
tail: UnsignedShort,
inject: &Inject<T>,
metrics: &mut MetricsBatch,
) -> Result<(), task::Notified<T>> {
/// How many elements are we taking from the local queue.
///
/// This is one less than the number of tasks pushed to the inject
/// queue as we are also inserting the `task` argument.
const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16;
const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

assert_eq!(
tail.wrapping_sub(head) as usize,
@@ -219,15 +236,15 @@ impl<T> Local<T> {
/// An iterator that takes elements out of the run queue.
struct BatchTaskIter<'a, T: 'static> {
buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
head: u32,
i: u32,
head: UnsignedLong,
i: UnsignedLong,
}
impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
type Item = task::Notified<T>;

#[inline]
fn next(&mut self) -> Option<task::Notified<T>> {
if self.i == u32::from(NUM_TASKS_TAKEN) {
if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
None
} else {
let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
@@ -247,7 +264,7 @@ impl<T> Local<T> {
// values again, and we are the only producer.
let batch_iter = BatchTaskIter {
buffer: &*self.inner.buffer,
head: head as u32,
head: head as UnsignedLong,
i: 0,
};
inject.push_batch(batch_iter.chain(std::iter::once(task)));
@@ -320,7 +337,7 @@ impl<T> Steal<T> {
// from `dst` there may not be enough capacity to steal.
let (steal, _) = unpack(dst.inner.head.load(Acquire));

if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 {
if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
// we *could* try to steal less here, but for simplicity, we're just
// going to abort.
return None;
@@ -335,7 +352,7 @@ impl<T> Steal<T> {
return None;
}

dst_metrics.incr_steal_count(n);
dst_metrics.incr_steal_count(n as u16);

// We are returning a task here
n -= 1;
@@ -360,7 +377,7 @@ impl<T> Steal<T> {

// Steal tasks from `self`, placing them into `dst`. Returns the number of
// tasks that were stolen.
fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 {
fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
let mut prev_packed = self.0.head.load(Acquire);
let mut next_packed;

@@ -402,7 +419,11 @@ impl<T> Steal<T> {
}
};

assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n);
assert!(
n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
"actual = {}",
n
);

let (first, _) = unpack(next_packed);

@@ -479,7 +500,7 @@ impl<T> Drop for Local<T> {
}

impl<T> Inner<T> {
fn len(&self) -> u16 {
fn len(&self) -> UnsignedShort {
let (_, head) = unpack(self.head.load(Acquire));
let tail = self.tail.load(Acquire);

@@ -493,16 +514,16 @@ impl<T> Inner<T> {

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: u32) -> (u16, u16) {
let real = n & u16::MAX as u32;
let steal = n >> 16;
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
let real = n & UnsignedShort::MAX as UnsignedLong;
let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

(steal as u16, real as u16)
(steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: u16, real: u16) -> u32 {
(real as u32) | ((steal as u32) << 16)
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
(real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}

#[test]
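To make the packed-head layout concrete, here is a small, self-contained roundtrip of `pack`/`unpack` specialized to the 64-bit variant (`UnsignedShort = u32`, `UnsignedLong = u64`). It illustrates the helpers above and is not additional code from the commit.

```rust
// Illustrative roundtrip of the packed head: `real` occupies the low half,
// `steal` the high half (here specialized to u32 halves in a u64).
fn pack(steal: u32, real: u32) -> u64 {
    (real as u64) | ((steal as u64) << 32)
}

fn unpack(n: u64) -> (u32, u32) {
    let real = n & u32::MAX as u64;
    let steal = n >> 32;
    (steal as u32, real as u32)
}

fn main() {
    // A stealer in progress: it started at index 5 while the real head is 9.
    assert_eq!(unpack(pack(5, 9)), (5, 9));
    // No active stealer: both halves carry the same index.
    assert_eq!(unpack(pack(9, 9)), (9, 9));
}
```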
