diff --git a/library/std/src/sys/sgx/thread_local_key.rs b/library/std/src/sys/sgx/thread_local_key.rs index b21784475f0d2..c7a57d3a3d47e 100644 --- a/library/std/src/sys/sgx/thread_local_key.rs +++ b/library/std/src/sys/sgx/thread_local_key.rs @@ -21,8 +21,3 @@ pub unsafe fn get(key: Key) -> *mut u8 { pub unsafe fn destroy(key: Key) { Tls::destroy(AbiKey::from_usize(key)) } - -#[inline] -pub fn requires_synchronized_create() -> bool { - false -} diff --git a/library/std/src/sys/solid/thread_local_key.rs b/library/std/src/sys/solid/thread_local_key.rs index b17521f701daf..b37bf99969887 100644 --- a/library/std/src/sys/solid/thread_local_key.rs +++ b/library/std/src/sys/solid/thread_local_key.rs @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 { pub unsafe fn destroy(_key: Key) { panic!("should not be used on the solid target"); } - -#[inline] -pub fn requires_synchronized_create() -> bool { - panic!("should not be used on the solid target"); -} diff --git a/library/std/src/sys/unix/thread_local_key.rs b/library/std/src/sys/unix/thread_local_key.rs index 2c5b94b1e61e5..2b2d079ee4d01 100644 --- a/library/std/src/sys/unix/thread_local_key.rs +++ b/library/std/src/sys/unix/thread_local_key.rs @@ -27,8 +27,3 @@ pub unsafe fn destroy(key: Key) { let r = libc::pthread_key_delete(key); debug_assert_eq!(r, 0); } - -#[inline] -pub fn requires_synchronized_create() -> bool { - false -} diff --git a/library/std/src/sys/unsupported/thread_local_key.rs b/library/std/src/sys/unsupported/thread_local_key.rs index c31b61cbf56d3..b6e5e4cd2e197 100644 --- a/library/std/src/sys/unsupported/thread_local_key.rs +++ b/library/std/src/sys/unsupported/thread_local_key.rs @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 { pub unsafe fn destroy(_key: Key) { panic!("should not be used on this target"); } - -#[inline] -pub fn requires_synchronized_create() -> bool { - panic!("should not be used on this target"); -} diff --git a/library/std/src/sys/windows/c.rs b/library/std/src/sys/windows/c.rs index c61a7e7d1e4ab..732e227d7e277 100644 --- a/library/std/src/sys/windows/c.rs +++ b/library/std/src/sys/windows/c.rs @@ -71,6 +71,7 @@ pub type BCRYPT_ALG_HANDLE = LPVOID; pub type PCONDITION_VARIABLE = *mut CONDITION_VARIABLE; pub type PLARGE_INTEGER = *mut c_longlong; pub type PSRWLOCK = *mut SRWLOCK; +pub type LPINIT_ONCE = *mut INIT_ONCE; pub type SOCKET = crate::os::windows::raw::SOCKET; pub type socklen_t = c_int; @@ -194,6 +195,9 @@ pub const DUPLICATE_SAME_ACCESS: DWORD = 0x00000002; pub const CONDITION_VARIABLE_INIT: CONDITION_VARIABLE = CONDITION_VARIABLE { ptr: ptr::null_mut() }; pub const SRWLOCK_INIT: SRWLOCK = SRWLOCK { ptr: ptr::null_mut() }; +pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { ptr: ptr::null_mut() }; + +pub const INIT_ONCE_INIT_FAILED: DWORD = 0x00000004; pub const DETACHED_PROCESS: DWORD = 0x00000008; pub const CREATE_NEW_PROCESS_GROUP: DWORD = 0x00000200; @@ -565,6 +569,10 @@ pub struct CONDITION_VARIABLE { pub struct SRWLOCK { pub ptr: LPVOID, } +#[repr(C)] +pub struct INIT_ONCE { + pub ptr: LPVOID, +} #[repr(C)] pub struct REPARSE_MOUNTPOINT_DATA_BUFFER { @@ -955,6 +963,7 @@ extern "system" { pub fn TlsAlloc() -> DWORD; pub fn TlsGetValue(dwTlsIndex: DWORD) -> LPVOID; pub fn TlsSetValue(dwTlsIndex: DWORD, lpTlsvalue: LPVOID) -> BOOL; + pub fn TlsFree(dwTlsIndex: DWORD) -> BOOL; pub fn GetLastError() -> DWORD; pub fn QueryPerformanceFrequency(lpFrequency: *mut LARGE_INTEGER) -> BOOL; pub fn QueryPerformanceCounter(lpPerformanceCount: *mut LARGE_INTEGER) -> BOOL; @@ -1114,6 +1123,14 @@ extern "system" { pub fn TryAcquireSRWLockExclusive(SRWLock: PSRWLOCK) -> BOOLEAN; pub fn TryAcquireSRWLockShared(SRWLock: PSRWLOCK) -> BOOLEAN; + pub fn InitOnceBeginInitialize( + lpInitOnce: LPINIT_ONCE, + dwFlags: DWORD, + fPending: LPBOOL, + lpContext: *mut LPVOID, + ) -> BOOL; + pub fn InitOnceComplete(lpInitOnce: LPINIT_ONCE, dwFlags: DWORD, lpContext: LPVOID) -> BOOL; + pub fn CompareStringOrdinal( lpString1: LPCWSTR, cchCount1: c_int, diff --git a/library/std/src/sys/windows/thread_local_key.rs b/library/std/src/sys/windows/thread_local_key.rs index ec670238e6f0e..17628b7579b8d 100644 --- a/library/std/src/sys/windows/thread_local_key.rs +++ b/library/std/src/sys/windows/thread_local_key.rs @@ -1,11 +1,16 @@ -use crate::mem::ManuallyDrop; +use crate::cell::UnsafeCell; use crate::ptr; -use crate::sync::atomic::AtomicPtr; -use crate::sync::atomic::Ordering::SeqCst; +use crate::sync::atomic::{ + AtomicPtr, AtomicU32, + Ordering::{AcqRel, Acquire, Relaxed, Release}, +}; use crate::sys::c; -pub type Key = c::DWORD; -pub type Dtor = unsafe extern "C" fn(*mut u8); +#[cfg(test)] +mod tests; + +type Key = c::DWORD; +type Dtor = unsafe extern "C" fn(*mut u8); // Turns out, like pretty much everything, Windows is pretty close the // functionality that Unix provides, but slightly different! In the case of @@ -22,60 +27,109 @@ pub type Dtor = unsafe extern "C" fn(*mut u8); // To accomplish this feat, we perform a number of threads, all contained // within this module: // -// * All TLS destructors are tracked by *us*, not the windows runtime. This +// * All TLS destructors are tracked by *us*, not the Windows runtime. This // means that we have a global list of destructors for each TLS key that // we know about. // * When a thread exits, we run over the entire list and run dtors for all // non-null keys. This attempts to match Unix semantics in this regard. // -// This ends up having the overhead of using a global list, having some -// locks here and there, and in general just adding some more code bloat. We -// attempt to optimize runtime by forgetting keys that don't have -// destructors, but this only gets us so far. -// // For more details and nitty-gritty, see the code sections below! // // [1]: https://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way -// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base -// /threading/thread_local_storage_win.cc#L42 +// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base/threading/thread_local_storage_win.cc#L42 -// ------------------------------------------------------------------------- -// Native bindings -// -// This section is just raw bindings to the native functions that Windows -// provides, There's a few extra calls to deal with destructors. +pub struct StaticKey { + /// The key value shifted up by one. Since TLS_OUT_OF_INDEXES == DWORD::MAX + /// is not a valid key value, this allows us to use zero as sentinel value + /// without risking overflow. + key: AtomicU32, + dtor: Option, + next: AtomicPtr, + /// Currently, destructors cannot be unregistered, so we cannot use racy + /// initialization for keys. Instead, we need synchronize initialization. + /// Use the Windows-provided `Once` since it does not require TLS. + once: UnsafeCell, +} -#[inline] -pub unsafe fn create(dtor: Option) -> Key { - let key = c::TlsAlloc(); - assert!(key != c::TLS_OUT_OF_INDEXES); - if let Some(f) = dtor { - register_dtor(key, f); +impl StaticKey { + #[inline] + pub const fn new(dtor: Option) -> StaticKey { + StaticKey { + key: AtomicU32::new(0), + dtor, + next: AtomicPtr::new(ptr::null_mut()), + once: UnsafeCell::new(c::INIT_ONCE_STATIC_INIT), + } } - key -} -#[inline] -pub unsafe fn set(key: Key, value: *mut u8) { - let r = c::TlsSetValue(key, value as c::LPVOID); - debug_assert!(r != 0); -} + #[inline] + pub unsafe fn set(&'static self, val: *mut u8) { + let r = c::TlsSetValue(self.key(), val.cast()); + debug_assert_eq!(r, c::TRUE); + } -#[inline] -pub unsafe fn get(key: Key) -> *mut u8 { - c::TlsGetValue(key) as *mut u8 -} + #[inline] + pub unsafe fn get(&'static self) -> *mut u8 { + c::TlsGetValue(self.key()).cast() + } -#[inline] -pub unsafe fn destroy(_key: Key) { - rtabort!("can't destroy tls keys on windows") -} + #[inline] + unsafe fn key(&'static self) -> Key { + match self.key.load(Acquire) { + 0 => self.init(), + key => key - 1, + } + } + + #[cold] + unsafe fn init(&'static self) -> Key { + if self.dtor.is_some() { + let mut pending = c::FALSE; + let r = c::InitOnceBeginInitialize(self.once.get(), 0, &mut pending, ptr::null_mut()); + assert_eq!(r, c::TRUE); -#[inline] -pub fn requires_synchronized_create() -> bool { - true + if pending == c::FALSE { + // Some other thread initialized the key, load it. + self.key.load(Relaxed) - 1 + } else { + let key = c::TlsAlloc(); + if key == c::TLS_OUT_OF_INDEXES { + // Wakeup the waiting threads before panicking to avoid deadlock. + c::InitOnceComplete(self.once.get(), c::INIT_ONCE_INIT_FAILED, ptr::null_mut()); + panic!("out of TLS indexes"); + } + + self.key.store(key + 1, Release); + register_dtor(self); + + let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut()); + debug_assert_eq!(r, c::TRUE); + + key + } + } else { + // If there is no destructor to clean up, we can use racy initialization. + + let key = c::TlsAlloc(); + assert_ne!(key, c::TLS_OUT_OF_INDEXES, "out of TLS indexes"); + + match self.key.compare_exchange(0, key + 1, AcqRel, Acquire) { + Ok(_) => key, + Err(new) => { + // Some other thread completed initialization first, so destroy + // our key and use theirs. + let r = c::TlsFree(key); + debug_assert_eq!(r, c::TRUE); + new - 1 + } + } + } + } } +unsafe impl Send for StaticKey {} +unsafe impl Sync for StaticKey {} + // ------------------------------------------------------------------------- // Dtor registration // @@ -96,29 +150,21 @@ pub fn requires_synchronized_create() -> bool { // Typically processes have a statically known set of TLS keys which is pretty // small, and we'd want to keep this memory alive for the whole process anyway // really. -// -// Perhaps one day we can fold the `Box` here into a static allocation, -// expanding the `StaticKey` structure to contain not only a slot for the TLS -// key but also a slot for the destructor queue on windows. An optimization for -// another day! - -static DTORS: AtomicPtr = AtomicPtr::new(ptr::null_mut()); - -struct Node { - dtor: Dtor, - key: Key, - next: *mut Node, -} -unsafe fn register_dtor(key: Key, dtor: Dtor) { - let mut node = ManuallyDrop::new(Box::new(Node { key, dtor, next: ptr::null_mut() })); +static DTORS: AtomicPtr = AtomicPtr::new(ptr::null_mut()); - let mut head = DTORS.load(SeqCst); +/// Should only be called once per key, otherwise loops or breaks may occur in +/// the linked list. +unsafe fn register_dtor(key: &'static StaticKey) { + let this = <*const StaticKey>::cast_mut(key); + // Use acquire ordering to pass along the changes done by the previously + // registered keys when we store the new head with release ordering. + let mut head = DTORS.load(Acquire); loop { - node.next = head; - match DTORS.compare_exchange(head, &mut **node, SeqCst, SeqCst) { - Ok(_) => return, // nothing to drop, we successfully added the node to the list - Err(cur) => head = cur, + key.next.store(head, Relaxed); + match DTORS.compare_exchange_weak(head, this, Release, Acquire) { + Ok(_) => break, + Err(new) => head = new, } } } @@ -214,25 +260,29 @@ unsafe extern "system" fn on_tls_callback(h: c::LPVOID, dwReason: c::DWORD, pv: unsafe fn reference_tls_used() {} } -#[allow(dead_code)] // actually called above +#[allow(dead_code)] // actually called below unsafe fn run_dtors() { - let mut any_run = true; for _ in 0..5 { - if !any_run { - break; - } - any_run = false; - let mut cur = DTORS.load(SeqCst); + let mut any_run = false; + + // Use acquire ordering to observe key initialization. + let mut cur = DTORS.load(Acquire); while !cur.is_null() { - let ptr = c::TlsGetValue((*cur).key); + let key = (*cur).key.load(Relaxed) - 1; + let dtor = (*cur).dtor.unwrap(); + let ptr = c::TlsGetValue(key); if !ptr.is_null() { - c::TlsSetValue((*cur).key, ptr::null_mut()); - ((*cur).dtor)(ptr as *mut _); + c::TlsSetValue(key, ptr::null_mut()); + dtor(ptr as *mut _); any_run = true; } - cur = (*cur).next; + cur = (*cur).next.load(Relaxed); + } + + if !any_run { + break; } } } diff --git a/library/std/src/sys/windows/thread_local_key/tests.rs b/library/std/src/sys/windows/thread_local_key/tests.rs new file mode 100644 index 0000000000000..c95f383fb90e3 --- /dev/null +++ b/library/std/src/sys/windows/thread_local_key/tests.rs @@ -0,0 +1,53 @@ +use super::StaticKey; +use crate::ptr; + +#[test] +fn smoke() { + static K1: StaticKey = StaticKey::new(None); + static K2: StaticKey = StaticKey::new(None); + + unsafe { + assert!(K1.get().is_null()); + assert!(K2.get().is_null()); + K1.set(ptr::invalid_mut(1)); + K2.set(ptr::invalid_mut(2)); + assert_eq!(K1.get() as usize, 1); + assert_eq!(K2.get() as usize, 2); + } +} + +#[test] +fn destructors() { + use crate::mem::ManuallyDrop; + use crate::sync::Arc; + use crate::thread; + + unsafe extern "C" fn destruct(ptr: *mut u8) { + drop(Arc::from_raw(ptr as *const ())); + } + + static KEY: StaticKey = StaticKey::new(Some(destruct)); + + let shared1 = Arc::new(()); + let shared2 = Arc::clone(&shared1); + + unsafe { + assert!(KEY.get().is_null()); + KEY.set(Arc::into_raw(shared1) as *mut u8); + } + + thread::spawn(move || unsafe { + assert!(KEY.get().is_null()); + KEY.set(Arc::into_raw(shared2) as *mut u8); + }) + .join() + .unwrap(); + + // Leak the Arc, let the TLS destructor clean it up. + let shared1 = unsafe { ManuallyDrop::new(Arc::from_raw(KEY.get() as *const ())) }; + assert_eq!( + Arc::strong_count(&shared1), + 1, + "destructor should have dropped the other reference on thread exit" + ); +} diff --git a/library/std/src/sys_common/mod.rs b/library/std/src/sys_common/mod.rs index e4dd0253668b8..8c19f9332dc56 100644 --- a/library/std/src/sys_common/mod.rs +++ b/library/std/src/sys_common/mod.rs @@ -34,10 +34,17 @@ pub mod rwlock; pub mod thread; pub mod thread_info; pub mod thread_local_dtor; -pub mod thread_local_key; pub mod thread_parker; pub mod wtf8; +cfg_if::cfg_if! { + if #[cfg(target_os = "windows")] { + pub use crate::sys::thread_local_key; + } else { + pub mod thread_local_key; + } +} + cfg_if::cfg_if! { if #[cfg(any(target_os = "l4re", target_os = "hermit", diff --git a/library/std/src/sys_common/thread_local_key.rs b/library/std/src/sys_common/thread_local_key.rs index 032bf604d7388..747579f178127 100644 --- a/library/std/src/sys_common/thread_local_key.rs +++ b/library/std/src/sys_common/thread_local_key.rs @@ -53,7 +53,6 @@ mod tests; use crate::sync::atomic::{self, AtomicUsize, Ordering}; use crate::sys::thread_local_key as imp; -use crate::sys_common::mutex::StaticMutex; /// A type for TLS keys that are statically allocated. /// @@ -151,25 +150,6 @@ impl StaticKey { } unsafe fn lazy_init(&self) -> usize { - // Currently the Windows implementation of TLS is pretty hairy, and - // it greatly simplifies creation if we just synchronize everything. - // - // Additionally a 0-index of a tls key hasn't been seen on windows, so - // we just simplify the whole branch. - if imp::requires_synchronized_create() { - // We never call `INIT_LOCK.init()`, so it is UB to attempt to - // acquire this mutex reentrantly! - static INIT_LOCK: StaticMutex = StaticMutex::new(); - let _guard = INIT_LOCK.lock(); - let mut key = self.key.load(Ordering::SeqCst); - if key == 0 { - key = imp::create(self.dtor) as usize; - self.key.store(key, Ordering::SeqCst); - } - rtassert!(key != 0); - return key; - } - // POSIX allows the key created here to be 0, but the compare_exchange // below relies on using 0 as a sentinel value to check who won the // race to set the shared TLS key. As far as I know, there is no @@ -232,8 +212,6 @@ impl Key { impl Drop for Key { fn drop(&mut self) { - // Right now Windows doesn't support TLS key destruction, but this also - // isn't used anywhere other than tests, so just leak the TLS key. - // unsafe { imp::destroy(self.key) } + unsafe { imp::destroy(self.key) } } }