Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace the linked list with a safer and less allocation-heavy alternative #38

Merged
merged 5 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ __test = []
[dependencies]
crossbeam-utils = { version = "0.8.12", default-features = false }
parking = { version = "2.0.0", optional = true }
slab = { version = "0.4.7", default-features = false }

[dev-dependencies]
criterion = "0.3.4"
Expand Down
27 changes: 7 additions & 20 deletions src/inner.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! The inner mechanism powering the `Event` type.

use crate::list::{Entry, List};
use crate::list::List;
use crate::node::Node;
use crate::queue::Queue;
use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
Expand All @@ -11,7 +11,6 @@ use alloc::vec;
use alloc::vec::Vec;

use core::ops;
use core::ptr::NonNull;

/// Inner state of [`Event`].
pub(crate) struct Inner {
Expand All @@ -25,14 +24,6 @@ pub(crate) struct Inner {

/// Queue of nodes waiting to be processed.
queue: Queue,

/// A single cached list entry to avoid allocations on the fast path of the insertion.
///
/// This field can only be written to when the `cache_used` field in the `list` structure
/// is false, or the user has a pointer to the `Entry` identical to this one and that user
/// has exclusive access to that `Entry`. An immutable pointer to this field is kept in
/// the `list` structure when it is in use.
cache: UnsafeCell<Entry>,
}

impl Inner {
Expand All @@ -42,7 +33,6 @@ impl Inner {
notified: AtomicUsize::new(core::usize::MAX),
list: Mutex::new(List::new()),
queue: Queue::new(),
cache: UnsafeCell::new(Entry::new()),
}
}

Expand All @@ -62,12 +52,6 @@ impl Inner {
// Acquire and drop the lock to make sure that the queue is flushed.
let _guard = self.lock();
}

/// Returns the pointer to the single cached list entry.
#[inline(always)]
pub(crate) fn cache_ptr(&self) -> NonNull<Entry> {
unsafe { NonNull::new_unchecked(self.cache.get()) }
}
}

/// The guard returned by [`Inner::lock`].
Expand All @@ -88,11 +72,11 @@ impl ListGuard<'_> {
guard: &mut MutexGuard<'_, List>,
) {
// Process the start node.
tasks.extend(start_node.apply(guard, self.inner));
tasks.extend(start_node.apply(guard));

// Process all remaining nodes.
while let Some(node) = self.inner.queue.pop() {
tasks.extend(node.apply(guard, self.inner));
tasks.extend(node.apply(guard));
}
}
}
Expand Down Expand Up @@ -125,7 +109,7 @@ impl Drop for ListGuard<'_> {
}

// Update the atomic `notified` counter.
let notified = if list.notified < list.len {
let notified = if list.notified < list.len() {
list.notified
} else {
core::usize::MAX
Expand Down Expand Up @@ -224,3 +208,6 @@ impl<'a, T> ops::DerefMut for MutexGuard<'a, T> {
unsafe { &mut *self.mutex.value.get() }
}
}

unsafe impl<T: Send> Send for Mutex<T> {}
unsafe impl<T: Send> Sync for Mutex<T> {}
158 changes: 92 additions & 66 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ use alloc::sync::Arc;

use core::fmt;
use core::future::Future;
use core::mem::ManuallyDrop;
use core::mem::{self, ManuallyDrop};
use core::num::NonZeroUsize;
use core::pin::Pin;
use core::ptr::{self, NonNull};
use core::ptr;
use core::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering};
use core::task::{Context, Poll, Waker};
use core::usize;
Expand All @@ -92,7 +93,7 @@ use std::time::{Duration, Instant};

use inner::Inner;
use list::{Entry, State};
use node::Node;
use node::{Node, TaskWaiting};

#[cfg(feature = "std")]
use parking::Unparker;
Expand Down Expand Up @@ -168,9 +169,6 @@ pub struct Event {
inner: AtomicPtr<Inner>,
}

unsafe impl Send for Event {}
unsafe impl Sync for Event {}

#[cfg(feature = "std")]
impl UnwindSafe for Event {}
#[cfg(feature = "std")]
Expand Down Expand Up @@ -210,31 +208,31 @@ impl Event {
let inner = self.inner();

// Try to acquire a lock in the inner list.
let entry = unsafe {
if let Some(mut lock) = (*inner).lock() {
let entry = lock.alloc((*inner).cache_ptr());
lock.insert(entry);
let state = {
let inner = unsafe { &*inner };
if let Some(mut lock) = inner.lock() {
let entry = lock.insert(Entry::new());

entry
ListenerState::HasNode(entry)
} else {
// Push entries into the queue indicating that we want to push a listener.
let (node, entry) = Node::listener();
(*inner).push(node);
inner.push(node);

// Indicate that there are nodes waiting to be notified.
(*inner)
inner
.notified
.compare_exchange(usize::MAX, 0, Ordering::AcqRel, Ordering::Relaxed)
.ok();

entry
ListenerState::Queued(entry)
}
};

// Register the listener.
let listener = EventListener {
inner: unsafe { Arc::clone(&ManuallyDrop::new(Arc::from_raw(inner))) },
entry: Some(entry),
state,
};

// Make sure the listener is registered before whatever happens next.
Expand Down Expand Up @@ -529,12 +527,20 @@ pub struct EventListener {
/// A reference to [`Event`]'s inner state.
inner: Arc<Inner>,

/// A pointer to this listener's entry in the linked list.
entry: Option<NonNull<Entry>>,
/// The current state of the listener.
state: ListenerState,
}

unsafe impl Send for EventListener {}
unsafe impl Sync for EventListener {}
enum ListenerState {
/// The listener has a node inside of the linked list.
HasNode(NonZeroUsize),

/// The listener has already been notified and has discarded its entry.
Discarded,

/// The listener has an entry in the queue that may or may not have a task waiting.
Queued(Arc<TaskWaiting>),
}

#[cfg(feature = "std")]
impl UnwindSafe for EventListener {}
Expand Down Expand Up @@ -605,11 +611,26 @@ impl EventListener {

fn wait_internal(mut self, deadline: Option<Instant>) -> bool {
// Take out the entry pointer and set it to `None`.
let entry = match self.entry.take() {
None => unreachable!("cannot wait twice on an `EventListener`"),
Some(entry) => entry,
};
let (parker, unparker) = parking::pair();
let entry = match self.state.take() {
ListenerState::HasNode(entry) => entry,
ListenerState::Queued(task_waiting) => {
// This listener is stuck in the backup queue.
// Wait for the task to be notified.
loop {
match task_waiting.status() {
Some(entry_id) => break entry_id,
None => {
// Register a task and park until it is notified.
task_waiting.register(Task::Thread(unparker.clone()));

parker.park();
}
}
}
}
ListenerState::Discarded => panic!("Cannot wait on a discarded listener"),
};

// Wait for the lock to be available.
let lock = || {
Expand All @@ -628,22 +649,15 @@ impl EventListener {

// Set this listener's state to `Waiting`.
{
let e = unsafe { entry.as_ref() };

if e.is_queued() {
// Write a task to be woken once the lock is acquired.
e.write_task(Task::Thread(unparker));
} else {
let mut list = lock();
let mut list = lock();

// If the listener was notified, we're done.
match e.state().replace(State::Notified(false)) {
State::Notified(_) => {
list.remove(entry, self.inner.cache_ptr());
return true;
}
_ => e.state().set(State::Task(Task::Thread(unparker))),
// If the listener was notified, we're done.
match list.state(entry).replace(State::Notified(false)) {
State::Notified(_) => {
list.remove(entry);
return true;
}
_ => list.state(entry).set(State::Task(Task::Thread(unparker))),
}
}

Expand All @@ -658,7 +672,7 @@ impl EventListener {
if now >= deadline {
// Remove the entry and check if notified.
let mut list = lock();
let state = list.remove(entry, self.inner.cache_ptr());
let state = list.remove(entry);
return state.is_notified();
}

Expand All @@ -668,17 +682,16 @@ impl EventListener {
}

let mut list = lock();
let e = unsafe { entry.as_ref() };

// Do a dummy replace operation in order to take out the state.
match e.state().replace(State::Notified(false)) {
match list.state(entry).replace(State::Notified(false)) {
State::Notified(_) => {
// If this listener has been notified, remove it from the list and return.
list.remove(entry, self.inner.cache_ptr());
list.remove(entry);
return true;
}
// Otherwise, set the state back to `Waiting`.
state => e.state().set(state),
state => list.state(entry).set(state),
}
}
}
Expand Down Expand Up @@ -706,10 +719,10 @@ impl EventListener {
/// ```
pub fn discard(mut self) -> bool {
// If this listener has never picked up a notification...
if let Some(entry) = self.entry.take() {
if let ListenerState::HasNode(entry) = self.state.take() {
// Remove the listener from the list and return `true` if it was notified.
if let Some(mut lock) = self.inner.lock() {
let state = lock.remove(entry, self.inner.cache_ptr());
let state = lock.remove(entry);

if let State::Notified(_) = state {
return true;
Expand Down Expand Up @@ -772,6 +785,30 @@ impl Future for EventListener {

#[allow(unreachable_patterns)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let entry = match self.state {
ListenerState::Discarded => {
unreachable!("cannot poll a completed `EventListener` future")
}
ListenerState::HasNode(ref entry) => *entry,
ListenerState::Queued(ref task_waiting) => {
loop {
// See if the task waiting has been completed.
match task_waiting.status() {
Some(entry_id) => {
self.state = ListenerState::HasNode(entry_id);
break entry_id;
}
None => {
// If not, wait for it to complete.
task_waiting.register(Task::Waker(cx.waker().clone()));
if task_waiting.status().is_none() {
return Poll::Pending;
}
}
}
}
}
};
let mut list = match self.inner.lock() {
Some(list) => list,
None => {
Expand All @@ -787,20 +824,15 @@ impl Future for EventListener {
}
}
};

let entry = match self.entry {
None => unreachable!("cannot poll a completed `EventListener` future"),
Some(entry) => entry,
};
let state = unsafe { entry.as_ref().state() };
let state = list.state(entry);

// Do a dummy replace operation in order to take out the state.
match state.replace(State::Notified(false)) {
State::Notified(_) => {
// If this listener has been notified, remove it from the list and return.
list.remove(entry, self.inner.cache_ptr());
list.remove(entry);
drop(list);
self.entry = None;
self.state = ListenerState::Discarded;
return Poll::Ready(());
}
State::Created => {
Expand All @@ -827,12 +859,11 @@ impl Future for EventListener {
impl Drop for EventListener {
fn drop(&mut self) {
// If this listener has never picked up a notification...
if let Some(entry) = self.entry.take() {
if let ListenerState::HasNode(entry) = self.state.take() {
match self.inner.lock() {
Some(mut list) => {
// But if a notification was delivered to it...
if let State::Notified(additional) = list.remove(entry, self.inner.cache_ptr())
{
if let State::Notified(additional) = list.remove(entry) {
// Then pass it on to another active listener.
list.notify(1, additional);
}
Expand All @@ -849,6 +880,12 @@ impl Drop for EventListener {
}
}

impl ListenerState {
fn take(&mut self) -> Self {
mem::replace(self, ListenerState::Discarded)
}
}

/// Equivalent to `atomic::fence(Ordering::SeqCst)`, but in some cases faster.
#[inline]
fn full_fence() {
Expand Down Expand Up @@ -877,17 +914,6 @@ fn full_fence() {
}
}

/// Indicate that we're using spin-based contention and that we should yield the CPU.
#[inline]
fn yield_now() {
#[cfg(feature = "std")]
std::thread::yield_now();

#[cfg(not(feature = "std"))]
#[allow(deprecated)]
sync::atomic::spin_loop_hint();
}

#[cfg(any(feature = "__test", test))]
impl Event {
/// Locks the event.
Expand Down
Loading