Skip to content

Commit

Permalink
📝 add more pointer check
Browse files Browse the repository at this point in the history
  • Loading branch information
Xudong-Huang committed Nov 16, 2024
1 parent 4f5475a commit ab0b24b
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ const LEADING_BITS: usize = 0;
#[cfg(not(target_pointer_width = "64"))]
const ALIGN_BITS: usize = 2;

const LOWER_MASK: usize = (1 << ALIGN_BITS) - 1;
const HIGHER_MASK: usize = !((1 << (usize::MAX.leading_ones() as usize - LEADING_BITS)) - 1);
const REFCOUNT_MASK: usize = (1 << (LEADING_BITS + ALIGN_BITS)) - 1;

//---------------------------------------------------------------------------------------
// LinkWrapper
//---------------------------------------------------------------------------------------

#[repr(C)]
union Ptr<T> {
addr: usize,
ptr: *const T,
Expand All @@ -40,14 +43,19 @@ impl<T> LinkWrapper<T> {
#[inline]
const fn new(ptr: *const T) -> Self {
let addr: usize = unsafe { Ptr { ptr }.addr };
debug_assert!(addr & LOWER_MASK == 0);
debug_assert!(addr & HIGHER_MASK == 0);
LinkWrapper {
ptr: AtomicUsize::new(addr << LEADING_BITS),
phantom: PhantomData,
}
}

fn update(&self, ptr: *const T) -> Option<Arc<T>> {
let new = unsafe { Ptr { ptr }.addr } << LEADING_BITS;
let addr = unsafe { Ptr { ptr }.addr };
debug_assert!(addr & LOWER_MASK == 0);
debug_assert!(addr & HIGHER_MASK == 0);
let new = addr << LEADING_BITS;
let mut old = self.ptr.load(Ordering::Relaxed) & !REFCOUNT_MASK;

while let Err(addr) =
Expand All @@ -58,7 +66,10 @@ impl<T> LinkWrapper<T> {
core::hint::spin_loop();
}

let ptr = (old >> LEADING_BITS) as *const T;
debug_assert!(old & LOWER_MASK == 0);
debug_assert!(old & HIGHER_MASK == 0);
let addr = old >> LEADING_BITS;
let ptr = unsafe { Ptr { addr }.ptr };
Self::ptr_to_arc(ptr)
}

Expand All @@ -74,11 +85,11 @@ impl<T> LinkWrapper<T> {

#[inline]
fn inc_ref(&self) -> *const T {
let addr = self.ptr.fetch_add(1, Ordering::Relaxed);
let addr = self.ptr.fetch_add(1, Ordering::Release);
let refs = addr & REFCOUNT_MASK;
assert!(refs < REFCOUNT_MASK);
let ptr = (addr & !REFCOUNT_MASK) >> LEADING_BITS;
ptr as *const T
assert!(refs < REFCOUNT_MASK, "Too many references");
let addr = (addr & !REFCOUNT_MASK) >> LEADING_BITS;
unsafe { Ptr { addr }.ptr }
}

#[inline]
Expand Down Expand Up @@ -212,6 +223,9 @@ mod test {

#[test]
fn simple_drop() {
let ptr = Arc::into_raw(Arc::new(10));
let _a = unsafe { Arc::from_raw(ptr) };

static REF: AtomicUsize = AtomicUsize::new(0);
struct Foo(usize);
impl Foo {
Expand Down

0 comments on commit ab0b24b

Please sign in to comment.