diff --git a/components/resource_metering/src/lib.rs b/components/resource_metering/src/lib.rs index 2a78a23a4a1..56ce7ed900d 100644 --- a/components/resource_metering/src/lib.rs +++ b/components/resource_metering/src/lib.rs @@ -259,11 +259,38 @@ impl SharedTagPtr { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; + use std::sync::{Mutex, MutexGuard}; + + use crate::threadlocal::take_thread_registrations; + use lazy_static::lazy_static; + + /// Tests that access [crate::threadlocal::THREAD_REGISTER_BUFFER] or [crate::config::GLOBAL_ENABLE] + /// need to be run sequentially. A helper function to + pub fn sequential_test() -> TestGuard { + TestGuard { + _guard: SEQ_LOCK.lock().unwrap(), + } + } + + lazy_static! { + static ref SEQ_LOCK: Mutex<()> = Mutex::new(()); + } + pub struct TestGuard { + _guard: MutexGuard<'static, ()>, + } + impl Drop for TestGuard { + fn drop(&mut self) { + take_thread_registrations(|_| {}); + } + } + #[test] fn test_attach() { + let _g = crate::tests::sequential_test(); + // Use a thread created by ourself. If we use unit test thread directly, // the test results may be affected by parallel testing. std::thread::spawn(|| { diff --git a/components/resource_metering/src/recorder/mod.rs b/components/resource_metering/src/recorder/mod.rs index 98c8b2d4dd0..98ea6595b40 100644 --- a/components/resource_metering/src/recorder/mod.rs +++ b/components/resource_metering/src/recorder/mod.rs @@ -1,7 +1,7 @@ // Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0. use crate::collector::{Collector, CollectorReg, COLLECTOR_REG_CHAN}; -use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg, ThreadLocalRef}; +use crate::threadlocal::{take_thread_registrations, ThreadLocalRef}; use crate::{RawRecords, SharedTagPtr}; use std::io; @@ -13,7 +13,6 @@ use std::thread::JoinHandle; use std::time::Duration; use collections::HashMap; -use crossbeam::channel::{unbounded, Receiver}; use tikv_util::time::Instant; mod cpu; @@ -95,9 +94,7 @@ pub struct Recorder { records: RawRecords, recorders: Vec>, collectors: HashMap>, - thread_rx: Receiver, thread_stores: HashMap, - destroyed_threads: Vec, last_collect: Instant, last_cleanup: Instant, } @@ -139,10 +136,9 @@ impl Recorder { fn cleanup(&mut self) { if self.last_cleanup.saturating_elapsed().as_secs() > CLEANUP_INTERVAL_SECS { - for id in &self.destroyed_threads { - self.thread_stores.remove(id); - } - self.destroyed_threads.clear(); + self.thread_stores + .drain_filter(|_, t| t.is_thread_down()) + .count(); if self.records.records.capacity() > RECORD_LEN_THRESHOLD && self.records.records.len() < (RECORD_LEN_THRESHOLD / 2) { @@ -190,21 +186,16 @@ impl Recorder { } fn handle_thread_registration(&mut self) { - while let Ok(msg) = self.thread_rx.try_recv() { - match msg { - ThreadLocalMsg::Created(tlr) => { - let id = tlr.id; - let tag = tlr.shared_ptr.clone(); - self.thread_stores.insert(id, tlr); - for r in &mut self.recorders { - r.thread_created(id, tag.clone()); - } - } - ThreadLocalMsg::Destroyed(id) => { - self.destroyed_threads.push(id); + take_thread_registrations(|tlrs| { + for tlr in tlrs { + let id = tlr.id; + let tag = tlr.shared_ptr.clone(); + self.thread_stores.insert(id, tlr); + for r in &mut self.recorders { + r.thread_created(id, tag.clone()); } } - } + }); } } @@ -251,8 +242,6 @@ impl RecorderBuilder { pub fn spawn(self) -> io::Result { let pause = Arc::new(AtomicBool::new(!self.enable)); let precision_ms = self.precision_ms.clone(); - let (tx, rx) = unbounded(); - register_thread_local_chan_tx(tx); let now = Instant::now(); let mut recorder = Recorder { pause: pause.clone(), @@ -260,9 +249,7 @@ impl RecorderBuilder { records: RawRecords::default(), recorders: self.recorders, collectors: HashMap::default(), - thread_rx: rx, thread_stores: HashMap::default(), - destroyed_threads: Vec::new(), last_collect: now, last_cleanup: now, }; @@ -400,8 +387,8 @@ mod tests { #[test] fn test_recorder() { - let (tx, rx) = unbounded(); - register_thread_local_chan_tx(tx); + let _g = crate::tests::sequential_test(); + std::thread::spawn(|| { LOCAL_DATA.with(|_| {}); }) @@ -414,9 +401,7 @@ mod tests { records: RawRecords::default(), recorders: vec![Box::new(MockSubRecorder)], collectors: HashMap::default(), - thread_rx: rx, thread_stores: HashMap::default(), - destroyed_threads: Vec::new(), last_collect: now, last_cleanup: now, }; diff --git a/components/resource_metering/src/recorder/summary.rs b/components/resource_metering/src/recorder/summary.rs index 07c3d76401d..fdc528ba586 100644 --- a/components/resource_metering/src/recorder/summary.rs +++ b/components/resource_metering/src/recorder/summary.rs @@ -70,18 +70,18 @@ impl SubRecorder for SummaryRecorder { #[cfg(test)] mod tests { use super::*; - use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg}; + + use crate::threadlocal::take_thread_registrations; use crate::{ResourceMeteringTag, TagInfos, GLOBAL_ENABLE}; - use crossbeam::channel::unbounded; + use std::sync::atomic::Ordering::SeqCst; use std::sync::Arc; #[test] fn test_collect() { - GLOBAL_ENABLE.store(true, SeqCst); - let (tx, rx) = unbounded(); - register_thread_local_chan_tx(tx); + let _g = crate::tests::sequential_test(); + GLOBAL_ENABLE.store(true, SeqCst); std::thread::spawn(|| { let tag = ResourceMeteringTag { infos: Arc::new(TagInfos { @@ -140,11 +140,12 @@ mod tests { let mut records = RawRecords::default(); let mut thread_stores = HashMap::default(); - while let Ok(msg) = rx.try_recv() { - if let ThreadLocalMsg::Created(tlr) = msg { + take_thread_registrations(|tlrs| { + for tlr in tlrs { thread_stores.insert(tlr.id, tlr); } - } + }); + let mut recorder = SummaryRecorder::default(); recorder.collect(&mut records, &mut thread_stores); assert!(!records.records.is_empty()); diff --git a/components/resource_metering/src/threadlocal.rs b/components/resource_metering/src/threadlocal.rs index 6fd96f0a9f6..2c2968ca2f2 100644 --- a/components/resource_metering/src/threadlocal.rs +++ b/components/resource_metering/src/threadlocal.rs @@ -4,23 +4,31 @@ use crate::model::SummaryRecord; use crate::{utils, ResourceMeteringTag, SharedTagPtr}; use std::cell::Cell; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; +use std::vec::Drain; use collections::HashMap; -use crossbeam::channel::Sender; use lazy_static::lazy_static; lazy_static! { - /// `THREAD_LOCAL_CHANS` is used to transfer the necessary thread registration events. - static ref THREAD_LOCAL_CHANS: Mutex>> = Mutex::new(Vec::new()); + /// `THREAD_REGISTER_BUFFER` is used to store the new thread registrations. + static ref THREAD_REGISTER_BUFFER: Mutex> = Mutex::new(Vec::new()); +} + +pub fn take_thread_registrations(mut consume: F) -> T +where + F: FnMut(Drain) -> T, +{ + consume(THREAD_REGISTER_BUFFER.lock().unwrap().drain(..)) } thread_local! { /// `LOCAL_DATA` is a thread-localized instance of [ThreadLocalData]. /// - /// When a new thread tries to read `LOCAL_DATA`, it will actively send a message - /// to [THREAD_LOCAL_CHANS] during the initialization phase of thread local storage. - /// The message([ThreadLocalRef]) contains the thread id and some references to + /// When a new thread tries to read `LOCAL_DATA`, it will actively store its [ThreadLocalRef] + /// to [THREAD_REGISTER_BUFFER] during the initialization phase of thread local storage. + /// The [ThreadLocalRef] contains the thread id and some references to /// the thread local fields. pub static LOCAL_DATA: ThreadLocalData = { let local_data = ThreadLocalData { @@ -28,15 +36,17 @@ thread_local! { shared_ptr: SharedTagPtr::default(), summary_cur_record: Arc::new(SummaryRecord::default()), summary_records: Arc::new(Mutex::new(HashMap::default())), + is_thread_down: Arc::new(AtomicBool::new(false)), }; - THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| { - tx.send(ThreadLocalMsg::Created(ThreadLocalRef{ + THREAD_REGISTER_BUFFER.lock().unwrap().push( + ThreadLocalRef { id: utils::thread_id(), shared_ptr: local_data.shared_ptr.clone(), summary_cur_record: local_data.summary_cur_record.clone(), summary_records: local_data.summary_records.clone(), - })).ok(); - }); + is_down: local_data.is_thread_down.clone(), + } + ); local_data }; } @@ -51,13 +61,12 @@ pub struct ThreadLocalData { pub shared_ptr: SharedTagPtr, pub summary_cur_record: Arc, pub summary_records: Arc>>, + pub is_thread_down: Arc, } impl Drop for ThreadLocalData { fn drop(&mut self) { - THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| { - tx.send(ThreadLocalMsg::Destroyed(utils::thread_id())).ok(); - }); + self.is_thread_down.store(true, Ordering::SeqCst); } } @@ -67,50 +76,53 @@ pub struct ThreadLocalRef { pub shared_ptr: SharedTagPtr, pub summary_cur_record: Arc, pub summary_records: Arc>>, + pub is_down: Arc, } -/// This enum is transmitted as a event in [THREAD_LOCAL_CHANS]. -/// -/// See [LOCAL_DATA] for more information. -#[derive(Debug)] -pub enum ThreadLocalMsg { - Created(ThreadLocalRef), - Destroyed(usize), -} - -/// Register a channel to notify thread creation & destruction events. -pub fn register_thread_local_chan_tx(tx: Sender) { - THREAD_LOCAL_CHANS.lock().unwrap().push(tx); +impl ThreadLocalRef { + pub fn is_thread_down(&self) -> bool { + self.is_down.load(Ordering::SeqCst) + } } #[cfg(test)] mod tests { use super::*; - use crossbeam::channel::unbounded; + use crossbeam::sync::WaitGroup; #[test] - fn test_thread_local_chan() { - let (tx, rx) = unbounded(); - register_thread_local_chan_tx(tx); - LOCAL_DATA.with(|_| {}); // Just to trigger registration. - std::thread::spawn(move || { + fn test_thread_local_registration() { + let _g = crate::tests::sequential_test(); + + let (next_step, stop_t0) = (WaitGroup::new(), WaitGroup::new()); + let (n, s) = (next_step.clone(), stop_t0.clone()); + let t0 = std::thread::spawn(move || { LOCAL_DATA.with(|_| {}); - }) - .join() - .unwrap(); - let mut count = 0; - while let Ok(msg) = rx.try_recv() { - match msg { - ThreadLocalMsg::Created(r) => { - assert_ne!(r.id, 0); - } - ThreadLocalMsg::Destroyed(id) => { - assert_ne!(id, 0); - } - } - count += 1; - } - // This value may be greater than 2 if other test threads access `LOCAL_DATA` in parallel. - assert!(count >= 2); + drop(n); + s.wait(); + }); + next_step.wait(); + + let (next_step, stop_t1) = (WaitGroup::new(), WaitGroup::new()); + let (n, s) = (next_step.clone(), stop_t1.clone()); + let t1 = std::thread::spawn(move || { + LOCAL_DATA.with(|_| {}); + drop(n); + s.wait(); + }); + next_step.wait(); + + let registrations = take_thread_registrations(|t| t.collect::>()); + assert_eq!(registrations.len(), 2); + assert!(!registrations[0].is_thread_down()); + assert!(!registrations[1].is_thread_down()); + + drop(stop_t0); + t0.join().unwrap(); + assert!(registrations[0].is_thread_down()); + + drop(stop_t1); + t1.join().unwrap(); + assert!(registrations[1].is_thread_down()); } }