Skip to content

Commit

Permalink
Fixed soundness hole in hold account
Browse files Browse the repository at this point in the history
  • Loading branch information
agerasev committed Aug 18, 2023
1 parent 146b79d commit 44e6806
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 46 deletions.
4 changes: 2 additions & 2 deletions async/src/rb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ impl<S: Storage> Consumer for AsyncRb<S> {
}
impl<S: Storage> RingBuffer for AsyncRb<S> {
#[inline]
fn hold_read(&self, flag: bool) {
unsafe fn hold_read(&self, flag: bool) {
self.base.hold_read(flag);
self.read.wake()
}
#[inline]
fn hold_write(&self, flag: bool) {
unsafe fn hold_write(&self, flag: bool) {
self.base.hold_write(flag);
self.write.wake()
}
Expand Down
8 changes: 4 additions & 4 deletions async/src/wrap/cons.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ use std::io;
impl<R: AsyncRbRef> Consumer for AsyncCons<R> {
#[inline]
unsafe fn set_read_index(&self, value: usize) {
self.base.set_read_index(value)
self.base().set_read_index(value)
}
#[inline]
fn try_pop(&mut self) -> Option<Self::Item> {
self.base.try_pop()
self.base_mut().try_pop()
}
#[inline]
fn pop_slice(&mut self, elems: &mut [Self::Item]) -> usize
where
Self::Item: Copy,
{
self.base.pop_slice(elems)
self.base_mut().pop_slice(elems)
}
}

Expand All @@ -38,7 +38,7 @@ impl<R: AsyncRbRef> AsyncConsumer for AsyncCons<R> {

#[inline]
fn close(&mut self) {
self.base.close();
drop(self.base.take());
}
}

Expand Down
31 changes: 20 additions & 11 deletions async/src/wrap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,38 @@ use core::{mem::MaybeUninit, num::NonZeroUsize};
use ringbuf::{rb::traits::ToRbRef, traits::Observer, wrap::caching::Caching, Obs};

pub struct AsyncWrap<R: AsyncRbRef, const P: bool, const C: bool> {
base: Caching<R, P, C>,
base: Option<Caching<R, P, C>>,
}

pub type AsyncProd<R> = AsyncWrap<R, true, false>;
pub type AsyncCons<R> = AsyncWrap<R, false, true>;

impl<R: AsyncRbRef, const P: bool, const C: bool> AsyncWrap<R, P, C> {
pub unsafe fn new(rb: R) -> Self {
Self { base: Caching::new(rb) }
Self {
base: Some(Caching::new(rb)),
}
}

fn base(&self) -> &Caching<R, P, C> {
self.base.as_ref().unwrap()
}
fn base_mut(&mut self) -> &mut Caching<R, P, C> {
self.base.as_mut().unwrap()
}

pub fn observe(&self) -> Obs<R> {
self.base.observe()
self.base().observe()
}
}

impl<R: AsyncRbRef, const P: bool, const C: bool> ToRbRef for AsyncWrap<R, P, C> {
type RbRef = R;
fn rb_ref(&self) -> &R {
self.base.rb_ref()
self.base().rb_ref()
}
fn into_rb_ref(self) -> R {
self.base.into_rb_ref()
self.base.unwrap().into_rb_ref()
}
}

Expand All @@ -39,27 +48,27 @@ impl<R: AsyncRbRef, const P: bool, const C: bool> Observer for AsyncWrap<R, P, C

#[inline]
fn capacity(&self) -> NonZeroUsize {
self.base.capacity()
self.base().capacity()
}
#[inline]
fn read_index(&self) -> usize {
self.base.read_index()
self.base().read_index()
}
#[inline]
fn write_index(&self) -> usize {
self.base.write_index()
self.base().write_index()
}
#[inline]
unsafe fn unsafe_slices(&self, start: usize, end: usize) -> (&mut [MaybeUninit<Self::Item>], &mut [MaybeUninit<Self::Item>]) {
self.base.unsafe_slices(start, end)
self.base().unsafe_slices(start, end)
}

#[inline]
fn read_is_held(&self) -> bool {
self.base.read_is_held()
self.base().read_is_held()
}
#[inline]
fn write_is_held(&self) -> bool {
self.base.write_is_held()
self.base().write_is_held()
}
}
10 changes: 5 additions & 5 deletions async/src/wrap/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ use std::io;
impl<R: AsyncRbRef> Producer for AsyncProd<R> {
#[inline]
unsafe fn set_write_index(&self, value: usize) {
self.base.set_write_index(value)
self.base().set_write_index(value)
}

#[inline]
fn try_push(&mut self, elem: Self::Item) -> Result<(), Self::Item> {
self.base.try_push(elem)
self.base_mut().try_push(elem)
}
#[inline]
fn push_iter<I: Iterator<Item = Self::Item>>(&mut self, iter: I) -> usize {
self.base.push_iter(iter)
self.base_mut().push_iter(iter)
}
#[inline]
fn push_slice(&mut self, elems: &[Self::Item]) -> usize
where
Self::Item: Copy,
{
self.base.push_slice(elems)
self.base_mut().push_slice(elems)
}
}

Expand All @@ -43,7 +43,7 @@ impl<R: AsyncRbRef> AsyncProducer for AsyncProd<R> {

#[inline]
fn close(&mut self) {
self.base.close();
drop(self.base.take());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/rb/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ impl<S: Storage> Consumer for LocalRb<S> {

impl<S: Storage> RingBuffer for LocalRb<S> {
#[inline]
fn hold_read(&self, flag: bool) {
unsafe fn hold_read(&self, flag: bool) {
self.read.held.set(flag)
}
#[inline]
fn hold_write(&self, flag: bool) {
unsafe fn hold_write(&self, flag: bool) {
self.write.held.set(flag)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/rb/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ impl<S: Storage> Consumer for SharedRb<S> {

impl<S: Storage> RingBuffer for SharedRb<S> {
#[inline]
fn hold_read(&self, flag: bool) {
unsafe fn hold_read(&self, flag: bool) {
self.read_held.store(flag, Ordering::Relaxed)
}
#[inline]
fn hold_write(&self, flag: bool) {
unsafe fn hold_write(&self, flag: bool) {
self.write_held.store(flag, Ordering::Relaxed)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/traits/ring_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ pub trait RingBuffer: Observer + Consumer + Producer {
});
}

fn hold_read(&self, flag: bool);
fn hold_write(&self, flag: bool);
unsafe fn hold_read(&self, flag: bool);
unsafe fn hold_write(&self, flag: bool);
}
4 changes: 0 additions & 4 deletions src/wrap/caching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ impl<R: RbRef, const P: bool, const C: bool> Caching<R, P, C> {
pub fn freeze(self) -> Frozen<R, P, C> {
self.frozen
}

pub fn close(&mut self) {
self.frozen.close();
}
}

impl<R: RbRef, const P: bool, const C: bool> ToRbRef for Caching<R, P, C> {
Expand Down
19 changes: 12 additions & 7 deletions src/wrap/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ impl<R: RbRef, const P: bool, const C: bool> Direct<R, P, C> {
pub fn new(rb: R) -> Self {
if P {
assert!(!rb.deref().write_is_held());
rb.deref().hold_write(true);
unsafe { rb.deref().hold_write(true) };
}
if C {
assert!(!rb.deref().read_is_held());
rb.deref().hold_read(true);
unsafe { rb.deref().hold_read(true) };
}
Self { rb }
}
Expand All @@ -55,7 +55,10 @@ impl<R: RbRef, const P: bool, const C: bool> Direct<R, P, C> {
unsafe { Frozen::new_unchecked(ptr::read(&this.rb)) }
}

pub fn close(&mut self) {
/// # Safety
///
/// Must not be used after this call.
unsafe fn close(&mut self) {
if P {
self.rb().hold_write(false);
}
Expand All @@ -71,9 +74,11 @@ impl<R: RbRef, const P: bool, const C: bool> ToRbRef for Direct<R, P, C> {
&self.rb
}
fn into_rb_ref(mut self) -> R {
self.close();
let this = ManuallyDrop::new(self);
unsafe { ptr::read(&this.rb) }
unsafe {
self.close();
let this = ManuallyDrop::new(self);
ptr::read(&this.rb)
}
}
}

Expand Down Expand Up @@ -133,7 +138,7 @@ impl<R: RbRef> Consumer for Cons<R> {

impl<R: RbRef, const P: bool, const C: bool> Drop for Direct<R, P, C> {
fn drop(&mut self) {
self.close();
unsafe { self.close() };
}
}

Expand Down
16 changes: 9 additions & 7 deletions src/wrap/frozen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ impl<R: RbRef, const P: bool, const C: bool> Frozen<R, P, C> {
pub fn new(rb: R) -> Self {
if P {
assert!(!rb.deref().write_is_held());
rb.deref().hold_write(true);
unsafe { rb.deref().hold_write(true) };
}
if C {
assert!(!rb.deref().read_is_held());
rb.deref().hold_read(true);
unsafe { rb.deref().hold_read(true) };
}
unsafe { Self::new_unchecked(rb) }
}
Expand All @@ -62,7 +62,7 @@ impl<R: RbRef, const P: bool, const C: bool> Frozen<R, P, C> {
Obs::new(self.rb.clone())
}

pub fn close(&mut self) {
unsafe fn close(&mut self) {
if P {
self.rb().hold_write(false);
}
Expand All @@ -80,9 +80,11 @@ impl<R: RbRef, const P: bool, const C: bool> ToRbRef for Frozen<R, P, C> {
}
fn into_rb_ref(mut self) -> R {
self.commit();
self.close();
let this = ManuallyDrop::new(self);
unsafe { ptr::read(&this.rb) }
unsafe {
self.close();
let this = ManuallyDrop::new(self);
ptr::read(&this.rb)
}
}
}

Expand Down Expand Up @@ -187,7 +189,7 @@ impl<R: RbRef> Consumer for FrozenCons<R> {
impl<R: RbRef, const P: bool, const C: bool> Drop for Frozen<R, P, C> {
fn drop(&mut self) {
self.commit();
self.close();
unsafe { self.close() };
}
}

Expand Down

0 comments on commit 44e6806

Please sign in to comment.