Skip to content

Commit

Permalink
Rework packet waker atomic ops to be less error-prone to implement
Browse files Browse the repository at this point in the history
Add test-asan workflow
  • Loading branch information
h33p committed Oct 30, 2023
1 parent 13bd8fb commit 4faa599
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 66 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,38 @@ jobs:
- name: Run all tests
run: cargo test --workspace --all-features --verbose

# AddressSanitizer test job. Builds the workspace (including std, via
# -Zbuild-std) with ASan instrumentation and runs the test suite.
test-asan:
runs-on: ${{ matrix.os }}
env:
# Instrument both rustc and rustdoc (doctest) builds with ASan.
RUSTFLAGS: -Zsanitizer=address -C debuginfo=2 ${{ matrix.rustflags }}
RUSTDOCFLAGS: -Zsanitizer=address -C debuginfo=2 ${{ matrix.rustflags }}
CARGO_BUILD_RUSTFLAGS: -C debuginfo=2
# detect_leaks=0: this job hunts memory errors, not leaks.
ASAN_OPTIONS: symbolize=1 detect_leaks=0
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
# TODO: enable windows, macos
os: [ubuntu-latest]
# -Zsanitizer and -Zbuild-std require nightly; pinned for reproducibility.
toolchain: ["nightly-2023-09-01"]
# Run the suite both with and without mfio_assume_linear_types.
rustflags: ["--cfg mfio_assume_linear_types --cfg tokio_unstable", "--cfg tokio_unstable"]
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
toolchain: ${{ matrix.toolchain }}
override: true

# Expose the host triple as a step output for the explicit --target below.
- name: Get rustc target
run: |
echo "RUSTC_TARGET=$(rustc -vV | sed -n 's|host: ||p')" >> $GITHUB_OUTPUT
id: target
# llvm-symbolizer is needed to symbolize ASan reports (symbolize=1 above).
- name: Install llvm
run: sudo apt update && sudo apt install llvm-13
# rust-src component is required by -Zbuild-std.
- run: rustup component add rust-src
- name: Run all tests
run: cargo -Zbuild-std test --verbose --target ${{ steps.target.outputs.RUSTC_TARGET }}

lint:
runs-on: ${{ matrix.os }}
env:
Expand Down
16 changes: 8 additions & 8 deletions mfio/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub trait PacketIoExt<Perms: PacketPerms, Param>: PacketIo<Perms, Param> {
//IoFut::NewId(self, param, packet.stack())
IoFut {
pkt: UnsafeCell::new(Some(packet.stack())),
initial_state: Some((self, param)),
initial_state: UnsafeCell::new(Some((self, param))),
_phantom: PhantomData,
}
}
Expand All @@ -116,7 +116,7 @@ pub trait PacketIoExt<Perms: PacketPerms, Param>: PacketIo<Perms, Param> {
//IoFut::NewId(self, param, packet.stack())
IoToFut {
pkt_out: UnsafeCell::new(Some((packet.stack(), output.stack()))),
initial_state: Some((self, param)),
initial_state: UnsafeCell::new(Some((self, param))),
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -177,7 +177,7 @@ impl NoPos {

pub struct IoFut<'a, T, Perms: PacketPerms, Param, Packet: PacketStore<'a, Perms>> {
pkt: UnsafeCell<Option<Packet::StackReq<'a>>>,
initial_state: Option<(&'a T, Param)>,
initial_state: UnsafeCell<Option<(&'a T, Param)>>,
_phantom: PhantomData<Perms>,
}

Expand All @@ -187,10 +187,10 @@ impl<'a, T: PacketIo<Perms, Param>, Perms: PacketPerms, Param, Pkt: PacketStore<
type Output = Pkt::StackReq<'a>;

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let state = unsafe { self.get_unchecked_mut() };
let state: &Self = unsafe { core::mem::transmute(self) };

loop {
match state.initial_state.take() {
match unsafe { (*state.initial_state.get()).take() } {
Some((io, param)) => {
// SAFETY: this packet's existence is tied to 'a lifetime, meaning it will be valid
// throughout 'a.
Expand Down Expand Up @@ -230,7 +230,7 @@ pub struct IoToFut<
Output: OutputStore<'a, Perms>,
> {
pkt_out: UnsafeCell<Option<(Packet::StackReq<'a>, Output::StackReq<'a>)>>,
initial_state: Option<(&'a T, Param)>,
initial_state: UnsafeCell<Option<(&'a T, Param)>>,
_phantom: PhantomData<Perms>,
}

Expand All @@ -244,9 +244,9 @@ impl<
> IoToFut<'a, T, Perms, Param, Pkt, Out>
{
pub fn submit(self: Pin<&mut Self>) -> &Out::StackReq<'a> {
let state = unsafe { self.get_unchecked_mut() };
let state: &Self = unsafe { core::mem::transmute(self) };

if let Some((io, param)) = state.initial_state.take() {
if let Some((io, param)) = unsafe { (*state.initial_state.get()).take() } {
// SAFETY: this packet's existence is tied to 'a lifetime, meaning it will be valid
// throughout 'a.
let (pkt, out): &'a mut (Pkt::StackReq<'a>, Out::StackReq<'a>) =
Expand Down
175 changes: 119 additions & 56 deletions mfio/src/io/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,102 @@ pub use output::*;
mod view;
pub use view::*;

/// Set while a thread holds the spinlock guarding the waker slot.
const LOCK_BIT: u64 = 1 << 63;
/// Set when the waker slot currently holds an initialized `CWaker`.
const HAS_WAKER_BIT: u64 = 1 << 62;
/// Set once the last reference holder has finished touching the waker,
/// unblocking `RcAndWaker::wait_finalize`.
const FINALIZED_BIT: u64 = 1 << 61;
/// Mask of all flag bits; the remaining low bits hold the reference count.
const ALL_BITS: u64 = LOCK_BIT | HAS_WAKER_BIT | FINALIZED_BIT;

/// Combined reference count, state flags, and waker storage for a packet.
///
/// The low bits of `rc_and_flags` hold the reference count, while the top
/// three bits are `LOCK_BIT`, `HAS_WAKER_BIT` and `FINALIZED_BIT`.
struct RcAndWaker {
    // Low bits: refcount; top 3 bits: lock / has-waker / finalized flags.
    rc_and_flags: AtomicU64,
    // Initialized if and only if HAS_WAKER_BIT is set; only accessed while
    // holding the LOCK_BIT spinlock.
    waker: UnsafeCell<MaybeUninit<CWaker>>,
}

impl Default for RcAndWaker {
fn default() -> Self {
Self {
rc_and_flags: 0.into(),
waker: UnsafeCell::new(MaybeUninit::uninit()),
}
}
}

impl core::fmt::Debug for RcAndWaker {
    /// Formats as `true`/`false` depending on whether a waker is currently
    /// stored (i.e. whether `HAS_WAKER_BIT` is set).
    fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        let bits = self.rc_and_flags.load(Ordering::Relaxed);
        let has_waker = (bits & HAS_WAKER_BIT) != 0;
        write!(fmt, "{}", has_waker)
    }
}

impl RcAndWaker {
    /// Acquires the spinlock guarding the waker slot.
    ///
    /// Spins until this thread successfully sets `LOCK_BIT`. Returns `true`
    /// if `HAS_WAKER_BIT` was set at the moment of acquisition, i.e. the
    /// slot holds an initialized waker.
    fn acquire(&self) -> bool {
        (loop {
            // Try to grab the lock; the previous value tells us whether it
            // was already held by another thread.
            let flags = self.rc_and_flags.fetch_or(LOCK_BIT, Ordering::AcqRel);
            if (flags & LOCK_BIT) == 0 {
                break flags;
            }
            // Lock contended — spin on a cheap relaxed load until released,
            // then retry the RMW above.
            while self.rc_and_flags.load(Ordering::Relaxed) & LOCK_BIT != 0 {
                core::hint::spin_loop();
            }
        } & HAS_WAKER_BIT)
            != 0
    }

    /// Removes and returns the stored waker, if any.
    ///
    /// On return the slot is logically empty: both the lock and
    /// `HAS_WAKER_BIT` are cleared in a single atomic step.
    pub fn take(&self) -> Option<CWaker> {
        let ret = if self.acquire() {
            // SAFETY: acquire() returned true, so HAS_WAKER_BIT was set and
            // the slot holds an initialized waker; we hold the lock, so no
            // other thread can touch the slot concurrently.
            Some(unsafe { (*self.waker.get()).assume_init_read() })
        } else {
            None
        };
        // Release the lock and mark the slot empty in one operation.
        self.rc_and_flags
            .fetch_and(!(LOCK_BIT | HAS_WAKER_BIT), Ordering::Release);
        ret
    }

    /// Stores `waker`, dropping any previously stored waker first.
    ///
    /// Returns the reference count (flag bits masked off) as observed by the
    /// lock-releasing atomic operation.
    pub fn write(&self, waker: CWaker) -> u64 {
        if self.acquire() {
            // A waker is already present — drop it before overwriting.
            // SAFETY: HAS_WAKER_BIT was set and we hold the lock.
            unsafe { core::ptr::drop_in_place((*self.waker.get()).as_mut_ptr()) }
        }

        // SAFETY: we hold the lock, so we have exclusive access to the slot.
        unsafe { *self.waker.get() = MaybeUninit::new(waker) };

        // NOTE(review): Relaxed here appears to rely on the AcqRel
        // lock-release below to publish the flag to other threads — confirm.
        self.rc_and_flags.fetch_or(HAS_WAKER_BIT, Ordering::Relaxed);
        self.rc_and_flags.fetch_and(!LOCK_BIT, Ordering::AcqRel) & !ALL_BITS
    }

    /// Returns the current reference count (flag bits masked off).
    pub fn acquire_rc(&self) -> u64 {
        self.rc_and_flags.load(Ordering::Acquire) & !ALL_BITS
    }

    /// Decrements the reference count.
    ///
    /// Returns the count *before* the decrement (flag bits masked off) and
    /// whether `HAS_WAKER_BIT` was set at that point.
    pub fn dec_rc(&self) -> (u64, bool) {
        let ret = self.rc_and_flags.fetch_sub(1, Ordering::AcqRel);
        (ret & !ALL_BITS, (ret & HAS_WAKER_BIT) != 0)
    }

    /// Increments the reference count, returning the count *before* the
    /// increment (flag bits masked off).
    pub fn inc_rc(&self) -> u64 {
        self.rc_and_flags.fetch_add(1, Ordering::AcqRel) & !ALL_BITS
    }

    /// Marks the waker handoff as complete, unblocking `wait_finalize`.
    pub fn finalize(&self) {
        self.rc_and_flags.fetch_or(FINALIZED_BIT, Ordering::Release);
    }

    /// Spins until `finalize` has been called (by the thread that dropped
    /// the last reference).
    pub fn wait_finalize(&self) {
        // FIXME: in theory, wait_finalize should only wait for the FINALIZED_BIT, but not deal
        // with the locking and the waker. However, something is making us have to take the waker,
        // to make these atomic ops sound (however, even then I doubt this is fully sound, but is
        // merely moving probability of desync lower).
        // Either way, we should be able to have this waker mechanism be way more optimized,
        // without atomic locks.
        self.take();
        while (self.rc_and_flags.load(Ordering::Acquire) & FINALIZED_BIT) == 0 {
            core::hint::spin_loop();
        }
    }
}

/// Describes a full packet.
///
/// This packet is considered simple.
Expand Down Expand Up @@ -548,9 +644,7 @@ pub struct Packet<Perms: PacketPerms> {
///
/// return true
/// ```
rc_and_flags: AtomicUsize,
/// Waker to be triggered, upon `rc` dropping down to 0.
waker: UnsafeCell<MaybeUninit<CWaker>>,
rc_and_waker: RcAndWaker,
/// What was the smallest position that resulted in an error.
///
/// This value is initialized to !0, and upon each errored packet segment, is minned
Expand All @@ -573,17 +667,8 @@ unsafe impl<Perms: PacketPerms> Sync for Packet<Perms> {}

impl<Perms: PacketPerms> Drop for Packet<Perms> {
fn drop(&mut self) {
let loaded = self.rc_and_flags.load(Ordering::Acquire);
assert_eq!(
loaded & !(0b11 << 62),
0,
"The packet has in-flight segments."
);
if loaded >> 62 == 0b11 {
unsafe {
core::ptr::drop_in_place(self.waker.get_mut().as_mut_ptr());
}
}
let loaded = self.rc_and_waker.acquire_rc();
assert_eq!(loaded, 0, "The packet has in-flight segments.");
}
}

Expand All @@ -593,46 +678,26 @@ impl<'a, Perms: PacketPerms> Future for &'a Packet<Perms> {
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);

// Clear the flag bits, because we want the end writing bit be properly set
let flags = this.rc_and_flags.fetch_and(!(0b11 << 62), Ordering::AcqRel);
let rc = this.rc_and_waker.write(cx.waker().clone().into());

// Drop the old waker
if (flags >> 62) == 0b11 {
unsafe {
core::ptr::drop_in_place((*this.waker.get()).as_mut_ptr());
}
}
if rc == 0 {
// Synchronize the thread that last decremented the refcount.
// If we don't, we risk a race condition where we drop the packet, while the packet
// reference is still being used to take the waker.
this.rc_and_waker.wait_finalize();

// Load in the start writing bit
let loaded = this.rc_and_flags.fetch_or(0b1 << 63, Ordering::AcqRel);

if loaded & !(0b11 << 62) == 0 {
// no more packets left, we don't need to write anything
return Poll::Ready(());
}

unsafe {
*this.waker.get() = MaybeUninit::new(cx.waker().clone().into());
}

// Load in the end writing bit.
let loaded = this.rc_and_flags.fetch_or(0b1 << 62, Ordering::AcqRel);

if loaded & !(0b11 << 62) == 0 {
// no more packets left, we wrote uselessly
// The waker will be freed in packet drop...
return Poll::Ready(());
}

// true indicates the waker was installed and we can go to sleep.
Poll::Pending
}
}

impl<Perms: PacketPerms> Packet<Perms> {
/// Current reference count of the packet.
pub fn rc(&self) -> usize {
self.rc_and_flags.load(Ordering::Relaxed) & !(0b11 << 62)
(self.rc_and_waker.acquire_rc()) as usize
}

unsafe fn on_output(&self, error: Option<(u64, NonZeroI32)>) -> Option<CWaker> {
Expand All @@ -642,29 +707,28 @@ impl<Perms: PacketPerms> Packet<Perms> {
}
}

let loaded = self.rc_and_flags.fetch_sub(1, Ordering::AcqRel);
let (prev, has_waker) = self.rc_and_waker.dec_rc();

// Do nothing, because we are either:
//
// - Not the last packet (any of the first 62 bits set).
// - The waker was not fully written yet (the last 2 bits are not 0b11). This case will be
// handled by the polling thread appropriately.
if loaded != (0b11 << 62) + 1 {
// Do nothing, because we are not the last packet (any of the first 62 bits set).
if prev != 1 {
return None;
}

if self.rc_and_flags.fetch_and(!(0b11 << 62), Ordering::AcqRel) >> 62 == 0b11 {
// FIXME: dial this atomic codepath in, because we've seen uninitialized reads.
Some(core::ptr::read(self.waker.get()).assume_init())
let ret = if has_waker {
self.rc_and_waker.take()
} else {
None
}
};

self.rc_and_waker.finalize();

ret
}

unsafe fn on_add_to_view(&self) {
let rc = self.rc_and_flags.fetch_add(1, Ordering::AcqRel) & !(0b11 << 62);
let rc = self.rc_and_waker.inc_rc();
if rc != 0 {
self.rc_and_flags.fetch_sub(1, Ordering::AcqRel);
self.rc_and_waker.dec_rc();
assert_eq!(rc, 0);
}
}
Expand All @@ -688,8 +752,7 @@ impl<Perms: PacketPerms> Packet<Perms> {
pub unsafe fn new_hdr(vtbl: PacketVtblRef<Perms>) -> Self {
Packet {
vtbl,
rc_and_flags: AtomicUsize::new(0),
waker: UnsafeCell::new(MaybeUninit::uninit()),
rc_and_waker: Default::default(),
error_clamp: (!0u64).into(),
min_error: 0.into(),
}
Expand Down
4 changes: 2 additions & 2 deletions mfio/src/io/packet/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> {
assert!(pos < self.len());

// TODO: maybe relaxed is enough here?
self.pkt().rc_and_flags.fetch_add(1, Ordering::Release);
self.pkt().rc_and_waker.inc_rc();

let Self {
pkt,
Expand Down Expand Up @@ -425,7 +425,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> {
///
/// Please see [`BoundPacketView::extract_packet`] documentation for details.
pub unsafe fn extract_packet(&self, offset: u64, len: u64) -> Self {
self.pkt().rc_and_flags.fetch_add(1, Ordering::AcqRel);
self.pkt().rc_and_waker.inc_rc();

let Self {
pkt, tag, start, ..
Expand Down

0 comments on commit 4faa599

Please sign in to comment.