Skip to content

Commit

Permalink
Do not early return on null bucket_ptr
Browse files Browse the repository at this point in the history
Buckets are allocated on demand based on `Thread::bucket`.
This means that when only threads with a high `id` (and thus high `bucket`) are writing entries into the `ThreadLocal`,
only higher `buckets` will be allocated, and lower buckets will be `null`.

Thus we must not early-return when encounting a `null` bucket.
  • Loading branch information
Swatinem authored and Amanieu committed Feb 20, 2024
1 parent b285630 commit b197719
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 26 deletions.
72 changes: 47 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<T: Send> Drop for ThreadLocal<T> {
let this_bucket_size = 1 << i;

if bucket_ptr.is_null() {
break;
continue;
}

unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
Expand Down Expand Up @@ -205,7 +205,7 @@ impl<T: Send> ThreadLocal<T> {
return Ok(val);
}

Ok(self.insert(create()?))
Ok(self.insert(thread, create()?))
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -226,8 +226,7 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
fn insert(&self, data: T) -> &T {
let thread = thread_id::get();
fn insert(&self, thread: Thread, data: T) -> &T {
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

Expand Down Expand Up @@ -372,16 +371,14 @@ impl RawIter {
let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
let bucket = bucket.load(Ordering::Acquire);

if bucket.is_null() {
return None;
}

while self.index < self.bucket_size {
let entry = unsafe { &*bucket.add(self.index) };
self.index += 1;
if entry.present.load(Ordering::Acquire) {
self.yielded += 1;
return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
if !bucket.is_null() {
while self.index < self.bucket_size {
let entry = unsafe { &*bucket.add(self.index) };
self.index += 1;
if entry.present.load(Ordering::Acquire) {
self.yielded += 1;
return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
}
}
}

Expand All @@ -401,16 +398,14 @@ impl RawIter {
let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
let bucket = *bucket.get_mut();

if bucket.is_null() {
return None;
}

while self.index < self.bucket_size {
let entry = unsafe { &mut *bucket.add(self.index) };
self.index += 1;
if *entry.present.get_mut() {
self.yielded += 1;
return Some(entry);
if !bucket.is_null() {
while self.index < self.bucket_size {
let entry = unsafe { &mut *bucket.add(self.index) };
self.index += 1;
if *entry.present.get_mut() {
self.yielded += 1;
return Some(entry);
}
}
}

Expand Down Expand Up @@ -525,7 +520,8 @@ unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {

#[cfg(test)]
mod tests {
use super::ThreadLocal;
use super::*;

use std::cell::RefCell;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
Expand Down Expand Up @@ -627,6 +623,32 @@ mod tests {
assert_eq!(dropped.load(Relaxed), 1);
}

#[test]
fn test_earlyreturn_buckets() {
struct Dropped(Arc<AtomicUsize>);
impl Drop for Dropped {
fn drop(&mut self) {
self.0.fetch_add(1, Relaxed);
}
}
let dropped = Arc::new(AtomicUsize::new(0));

// We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used.
// Neither iteration nor `Drop` must early-return on `null` buckets that are used for lower `buckets`.
let thread = Thread::new(1234);
assert!(thread.bucket > 1);

let mut local = ThreadLocal::new();
local.insert(thread, Dropped(dropped.clone()));

let item = local.iter().next().unwrap();
assert_eq!(item.0.load(Relaxed), 0);
let item = local.iter_mut().next().unwrap();
assert_eq!(item.0.load(Relaxed), 0);
drop(local);
assert_eq!(dropped.load(Relaxed), 1);
}

#[test]
fn is_sync() {
fn foo<T: Sync>() {}
Expand Down
2 changes: 1 addition & 1 deletion src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub(crate) struct Thread {
pub(crate) index: usize,
}
impl Thread {
fn new(id: usize) -> Self {
pub(crate) fn new(id: usize) -> Self {
let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
let bucket_size = 1 << bucket;
let index = id - (bucket_size - 1);
Expand Down

0 comments on commit b197719

Please sign in to comment.