Skip to content

Commit 7e21850

Browse files
committed
Update ReentrantLock implementation, add CURRENT_ID thread local.
This changes `ReentrantLock` to use `ThreadId` for the thread ownership check instead of the address of a thread local. Unlike TLS blocks, `ThreadId` is guaranteed to be unique across the lifetime of the process, so if any thread ever terminates while holding a `ReentrantLockGuard`, no other thread may ever acquire that lock again. On platforms with 64-bit atomics, this is a very simple change. On other platforms, the approach used is slightly more involved, as explained in the module comment. This also adds a `CURRENT_ID` thread local in addition to the already existing `CURRENT`. This allows us to access the current `ThreadId` without the relatively heavy machinery used by `thread::current().id()`.
1 parent 567096d commit 7e21850

File tree

2 files changed

+144
-26
lines changed

2 files changed

+144
-26
lines changed

std/src/sync/reentrant_lock.rs

+115-23
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#[cfg(all(test, not(target_os = "emscripten")))]
22
mod tests;
33

4+
use cfg_if::cfg_if;
5+
46
use crate::cell::UnsafeCell;
57
use crate::fmt;
68
use crate::ops::Deref;
79
use crate::panic::{RefUnwindSafe, UnwindSafe};
8-
use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
910
use crate::sys::sync as sys;
11+
use crate::thread::{current_id, ThreadId};
1012

1113
/// A re-entrant mutual exclusion lock
1214
///
@@ -53,8 +55,8 @@ use crate::sys::sync as sys;
5355
//
5456
// The 'owner' field tracks which thread has locked the mutex.
5557
//
56-
// We use current_thread_unique_ptr() as the thread identifier,
57-
// which is just the address of a thread local variable.
58+
// We use thread::current_id() as the thread identifier, which is just the
59+
// current thread's ThreadId, so it's unique across the process lifetime.
5860
//
5961
// If `owner` is set to the identifier of the current thread,
6062
// we assume the mutex is already locked and instead of locking it again,
@@ -72,14 +74,109 @@ use crate::sys::sync as sys;
7274
// since we're not dealing with multiple threads. If it's not equal,
7375
// synchronization is left to the mutex, making relaxed memory ordering for
7476
// the `owner` field fine in all cases.
77+
//
78+
// On systems without 64 bit atomics we also store the address of a TLS variable
79+
// along the 64-bit TID. We then first check that address against the address
80+
// of that variable on the current thread, and only if they compare equal do we
81+
// compare the actual TIDs. Because we only ever read the TID on the same thread
82+
// that it was written on (or a thread sharing the TLS block with that writer thread),
83+
// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
84+
// non-atomic accesses.
7585
#[unstable(feature = "reentrant_lock", issue = "121440")]
7686
pub struct ReentrantLock<T: ?Sized> {
7787
mutex: sys::Mutex,
78-
owner: AtomicUsize,
88+
owner: Tid,
7989
lock_count: UnsafeCell<u32>,
8090
data: T,
8191
}
8292

93+
cfg_if!(
94+
if #[cfg(target_has_atomic = "64")] {
95+
use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};
96+
97+
struct Tid(AtomicU64);
98+
99+
impl Tid {
100+
const fn new() -> Self {
101+
Self(AtomicU64::new(0))
102+
}
103+
104+
#[inline]
105+
fn contains(&self, owner: ThreadId) -> bool {
106+
owner.as_u64().get() == self.0.load(Relaxed)
107+
}
108+
109+
#[inline]
110+
// This is just unsafe to match the API of the Tid type below.
111+
unsafe fn set(&self, tid: Option<ThreadId>) {
112+
let value = tid.map_or(0, |tid| tid.as_u64().get());
113+
self.0.store(value, Relaxed);
114+
}
115+
}
116+
} else {
117+
/// Returns the address of a TLS variable. This is guaranteed to
118+
/// be unique across all currently alive threads.
119+
fn tls_addr() -> usize {
120+
thread_local! { static X: u8 = const { 0u8 } };
121+
122+
X.with(|p| <*const u8>::addr(p))
123+
}
124+
125+
use crate::sync::atomic::{
126+
AtomicUsize,
127+
Ordering,
128+
};
129+
130+
struct Tid {
131+
// When a thread calls `set()`, this value gets updated to
132+
// the address of a thread local on that thread. This is
133+
// used as a first check in `contains()`; if the `tls_addr`
134+
// doesn't match the TLS address of the current thread, then
135+
// the ThreadId also can't match. Only if the TLS addresses do
136+
// match do we read out the actual TID.
137+
// Note also that we can use relaxed atomic operations here, because
138+
// we only ever read from the tid if `tls_addr` matches the current
139+
// TLS address. In that case, either the the tid has been set by
140+
// the current thread, or by a thread that has terminated before
141+
// the current thread was created. In either case, no further
142+
// synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
143+
tls_addr: AtomicUsize,
144+
tid: UnsafeCell<u64>,
145+
}
146+
147+
unsafe impl Send for Tid {}
148+
unsafe impl Sync for Tid {}
149+
150+
impl Tid {
151+
const fn new() -> Self {
152+
Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
153+
}
154+
155+
#[inline]
156+
// NOTE: This assumes that `owner` is the ID of the current
157+
// thread, and may spuriously return `false` if that's not the case.
158+
fn contains(&self, owner: ThreadId) -> bool {
159+
// SAFETY: See the comments in the struct definition.
160+
self.tls_addr.load(Ordering::Relaxed) == tls_addr()
161+
&& unsafe { *self.tid.get() } == owner.as_u64().get()
162+
}
163+
164+
#[inline]
165+
// This may only be called by one thread at a time, and can lead to
166+
// race conditions otherwise.
167+
unsafe fn set(&self, tid: Option<ThreadId>) {
168+
// It's important that we set `self.tls_addr` to 0 if the tid is
169+
// cleared. Otherwise, there might be race conditions between
170+
// `set()` and `get()`.
171+
let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
172+
let value = tid.map_or(0, |tid| tid.as_u64().get());
173+
self.tls_addr.store(tls_addr, Ordering::Relaxed);
174+
unsafe { *self.tid.get() = value };
175+
}
176+
}
177+
}
178+
);
179+
83180
#[unstable(feature = "reentrant_lock", issue = "121440")]
84181
unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
85182
#[unstable(feature = "reentrant_lock", issue = "121440")]
@@ -131,7 +228,7 @@ impl<T> ReentrantLock<T> {
131228
pub const fn new(t: T) -> ReentrantLock<T> {
132229
ReentrantLock {
133230
mutex: sys::Mutex::new(),
134-
owner: AtomicUsize::new(0),
231+
owner: Tid::new(),
135232
lock_count: UnsafeCell::new(0),
136233
data: t,
137234
}
@@ -181,14 +278,16 @@ impl<T: ?Sized> ReentrantLock<T> {
181278
/// assert_eq!(lock.lock().get(), 10);
182279
/// ```
183280
pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
184-
let this_thread = current_thread_unique_ptr();
185-
// Safety: We only touch lock_count when we own the lock.
281+
let this_thread = current_id();
282+
// Safety: We only touch lock_count when we own the inner mutex.
283+
// Additionally, we only call `self.owner.set()` while holding
284+
// the inner mutex, so no two threads can call it concurrently.
186285
unsafe {
187-
if self.owner.load(Relaxed) == this_thread {
286+
if self.owner.contains(this_thread) {
188287
self.increment_lock_count().expect("lock count overflow in reentrant mutex");
189288
} else {
190289
self.mutex.lock();
191-
self.owner.store(this_thread, Relaxed);
290+
self.owner.set(Some(this_thread));
192291
debug_assert_eq!(*self.lock_count.get(), 0);
193292
*self.lock_count.get() = 1;
194293
}
@@ -223,14 +322,16 @@ impl<T: ?Sized> ReentrantLock<T> {
223322
///
224323
/// This function does not block.
225324
pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
226-
let this_thread = current_thread_unique_ptr();
227-
// Safety: We only touch lock_count when we own the lock.
325+
let this_thread = current_id();
326+
// Safety: We only touch lock_count when we own the inner mutex.
327+
// Additionally, we only call `self.owner.set()` while holding
328+
// the inner mutex, so no two threads can call it concurrently.
228329
unsafe {
229-
if self.owner.load(Relaxed) == this_thread {
330+
if self.owner.contains(this_thread) {
230331
self.increment_lock_count()?;
231332
Some(ReentrantLockGuard { lock: self })
232333
} else if self.mutex.try_lock() {
233-
self.owner.store(this_thread, Relaxed);
334+
self.owner.set(Some(this_thread));
234335
debug_assert_eq!(*self.lock_count.get(), 0);
235336
*self.lock_count.get() = 1;
236337
Some(ReentrantLockGuard { lock: self })
@@ -303,18 +404,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
303404
unsafe {
304405
*self.lock.lock_count.get() -= 1;
305406
if *self.lock.lock_count.get() == 0 {
306-
self.lock.owner.store(0, Relaxed);
407+
self.lock.owner.set(None);
307408
self.lock.mutex.unlock();
308409
}
309410
}
310411
}
311412
}
312-
313-
/// Get an address that is unique per running thread.
314-
///
315-
/// This can be used as a non-null usize-sized ID.
316-
pub(crate) fn current_thread_unique_ptr() -> usize {
317-
// Use a non-drop type to make sure it's still available during thread destruction.
318-
thread_local! { static X: u8 = const { 0 } }
319-
X.with(|x| <*const _>::addr(x))
320-
}

std/src/thread/mod.rs

+29-3
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@
159159
mod tests;
160160

161161
use crate::any::Any;
162-
use crate::cell::{OnceCell, UnsafeCell};
162+
use crate::cell::{Cell, OnceCell, UnsafeCell};
163163
use crate::env;
164164
use crate::ffi::{CStr, CString};
165165
use crate::fmt;
@@ -698,17 +698,22 @@ where
698698
}
699699

700700
thread_local! {
701+
// Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together.
702+
// If `CURRENT` is initialized, then `CURRENT_ID` will hold the same value
703+
// as `CURRENT.id()`.
701704
static CURRENT: OnceCell<Thread> = const { OnceCell::new() };
705+
static CURRENT_ID: Cell<Option<ThreadId>> = const { Cell::new(None) };
702706
}
703707

704708
/// Sets the thread handle for the current thread.
705709
///
706710
/// Aborts if the handle has been set already to reduce code size.
707711
pub(crate) fn set_current(thread: Thread) {
712+
let tid = thread.id();
708713
// Using `unwrap` here can add ~3kB to the binary size. We have complete
709714
// control over where this is called, so just abort if there is a bug.
710715
CURRENT.with(|current| match current.set(thread) {
711-
Ok(()) => {}
716+
Ok(()) => CURRENT_ID.set(Some(tid)),
712717
Err(_) => rtabort!("thread::set_current should only be called once per thread"),
713718
});
714719
}
@@ -718,7 +723,28 @@ pub(crate) fn set_current(thread: Thread) {
718723
/// In contrast to the public `current` function, this will not panic if called
719724
/// from inside a TLS destructor.
720725
pub(crate) fn try_current() -> Option<Thread> {
721-
CURRENT.try_with(|current| current.get_or_init(|| Thread::new_unnamed()).clone()).ok()
726+
CURRENT
727+
.try_with(|current| {
728+
current
729+
.get_or_init(|| {
730+
let thread = Thread::new_unnamed();
731+
CURRENT_ID.set(Some(thread.id()));
732+
thread
733+
})
734+
.clone()
735+
})
736+
.ok()
737+
}
738+
739+
/// Gets the id of the thread that invokes it.
740+
#[inline]
741+
pub(crate) fn current_id() -> ThreadId {
742+
CURRENT_ID.get().unwrap_or_else(|| {
743+
// If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized.
744+
// `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to
745+
// `current_id()` will succeed immediately.
746+
current().id()
747+
})
722748
}
723749

724750
/// Gets a handle to the thread that invokes it.

0 commit comments

Comments
 (0)