Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

opt thread registration #2

Merged
merged 2 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion components/resource_metering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,38 @@ impl SharedTagPtr {
}

#[cfg(test)]
mod tests {
pub(crate) mod tests {
use super::*;

use std::sync::{Mutex, MutexGuard};

use crate::threadlocal::take_thread_registrations;
use lazy_static::lazy_static;

/// Tests that access [crate::threadlocal::THREAD_REGISTER_BUFFER] or [crate::config::GLOBAL_ENABLE]
/// need to be run sequentially. A helper function to
pub fn sequential_test() -> TestGuard {
TestGuard {
_guard: SEQ_LOCK.lock().unwrap(),
}
}

lazy_static! {
static ref SEQ_LOCK: Mutex<()> = Mutex::new(());
}
pub struct TestGuard {
_guard: MutexGuard<'static, ()>,
}
impl Drop for TestGuard {
fn drop(&mut self) {
take_thread_registrations(|_| {});
}
}

#[test]
fn test_attach() {
let _g = crate::tests::sequential_test();

// Use a thread created by ourself. If we use unit test thread directly,
// the test results may be affected by parallel testing.
std::thread::spawn(|| {
Expand Down
43 changes: 14 additions & 29 deletions components/resource_metering/src/recorder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0.

use crate::collector::{Collector, CollectorReg, COLLECTOR_REG_CHAN};
use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg, ThreadLocalRef};
use crate::threadlocal::{take_thread_registrations, ThreadLocalRef};
use crate::{RawRecords, SharedTagPtr};

use std::io;
Expand All @@ -13,7 +13,6 @@ use std::thread::JoinHandle;
use std::time::Duration;

use collections::HashMap;
use crossbeam::channel::{unbounded, Receiver};
use tikv_util::time::Instant;

mod cpu;
Expand Down Expand Up @@ -95,9 +94,7 @@ pub struct Recorder {
records: RawRecords,
recorders: Vec<Box<dyn SubRecorder + Send>>,
collectors: HashMap<u64, Box<dyn Collector>>,
thread_rx: Receiver<ThreadLocalMsg>,
thread_stores: HashMap<usize, ThreadLocalRef>,
destroyed_threads: Vec<usize>,
last_collect: Instant,
last_cleanup: Instant,
}
Expand Down Expand Up @@ -139,10 +136,9 @@ impl Recorder {

fn cleanup(&mut self) {
if self.last_cleanup.saturating_elapsed().as_secs() > CLEANUP_INTERVAL_SECS {
for id in &self.destroyed_threads {
self.thread_stores.remove(id);
}
self.destroyed_threads.clear();
self.thread_stores
.drain_filter(|_, t| t.is_thread_down())
.count();
if self.records.records.capacity() > RECORD_LEN_THRESHOLD
&& self.records.records.len() < (RECORD_LEN_THRESHOLD / 2)
{
Expand Down Expand Up @@ -190,21 +186,16 @@ impl Recorder {
}

fn handle_thread_registration(&mut self) {
while let Ok(msg) = self.thread_rx.try_recv() {
match msg {
ThreadLocalMsg::Created(tlr) => {
let id = tlr.id;
let tag = tlr.shared_ptr.clone();
self.thread_stores.insert(id, tlr);
for r in &mut self.recorders {
r.thread_created(id, tag.clone());
}
}
ThreadLocalMsg::Destroyed(id) => {
self.destroyed_threads.push(id);
take_thread_registrations(|tlrs| {
for tlr in tlrs {
let id = tlr.id;
let tag = tlr.shared_ptr.clone();
self.thread_stores.insert(id, tlr);
for r in &mut self.recorders {
r.thread_created(id, tag.clone());
}
}
}
});
}
}

Expand Down Expand Up @@ -251,18 +242,14 @@ impl RecorderBuilder {
pub fn spawn(self) -> io::Result<RecorderHandle> {
let pause = Arc::new(AtomicBool::new(!self.enable));
let precision_ms = self.precision_ms.clone();
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let now = Instant::now();
let mut recorder = Recorder {
pause: pause.clone(),
precision_ms: precision_ms.clone(),
records: RawRecords::default(),
recorders: self.recorders,
collectors: HashMap::default(),
thread_rx: rx,
thread_stores: HashMap::default(),
destroyed_threads: Vec::new(),
last_collect: now,
last_cleanup: now,
};
Expand Down Expand Up @@ -400,8 +387,8 @@ mod tests {

#[test]
fn test_recorder() {
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let _g = crate::tests::sequential_test();

std::thread::spawn(|| {
LOCAL_DATA.with(|_| {});
})
Expand All @@ -414,9 +401,7 @@ mod tests {
records: RawRecords::default(),
recorders: vec![Box::new(MockSubRecorder)],
collectors: HashMap::default(),
thread_rx: rx,
thread_stores: HashMap::default(),
destroyed_threads: Vec::new(),
last_collect: now,
last_cleanup: now,
};
Expand Down
17 changes: 9 additions & 8 deletions components/resource_metering/src/recorder/summary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ impl SubRecorder for SummaryRecorder {
#[cfg(test)]
mod tests {
use super::*;
use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg};

use crate::threadlocal::take_thread_registrations;
use crate::{ResourceMeteringTag, TagInfos, GLOBAL_ENABLE};
use crossbeam::channel::unbounded;

use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;

#[test]
fn test_collect() {
GLOBAL_ENABLE.store(true, SeqCst);
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let _g = crate::tests::sequential_test();

GLOBAL_ENABLE.store(true, SeqCst);
std::thread::spawn(|| {
let tag = ResourceMeteringTag {
infos: Arc::new(TagInfos {
Expand Down Expand Up @@ -140,11 +140,12 @@ mod tests {

let mut records = RawRecords::default();
let mut thread_stores = HashMap::default();
while let Ok(msg) = rx.try_recv() {
if let ThreadLocalMsg::Created(tlr) = msg {
take_thread_registrations(|tlrs| {
for tlr in tlrs {
thread_stores.insert(tlr.id, tlr);
}
}
});

let mut recorder = SummaryRecorder::default();
recorder.collect(&mut records, &mut thread_stores);
assert!(!records.records.is_empty());
Expand Down
108 changes: 60 additions & 48 deletions components/resource_metering/src/threadlocal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,49 @@ use crate::model::SummaryRecord;
use crate::{utils, ResourceMeteringTag, SharedTagPtr};

use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::vec::Drain;

use collections::HashMap;
use crossbeam::channel::Sender;
use lazy_static::lazy_static;

lazy_static! {
/// `THREAD_LOCAL_CHANS` is used to transfer the necessary thread registration events.
static ref THREAD_LOCAL_CHANS: Mutex<Vec<Sender<ThreadLocalMsg>>> = Mutex::new(Vec::new());
/// `THREAD_REGISTER_BUFFER` is used to store the new thread registrations.
static ref THREAD_REGISTER_BUFFER: Mutex<Vec<ThreadLocalRef>> = Mutex::new(Vec::new());
}

pub fn take_thread_registrations<F, T>(mut consume: F) -> T
where
F: FnMut(Drain<ThreadLocalRef>) -> T,
{
consume(THREAD_REGISTER_BUFFER.lock().unwrap().drain(..))
}

thread_local! {
/// `LOCAL_DATA` is a thread-localized instance of [ThreadLocalData].
///
/// When a new thread tries to read `LOCAL_DATA`, it will actively send a message
/// to [THREAD_LOCAL_CHANS] during the initialization phase of thread local storage.
/// The message([ThreadLocalRef]) contains the thread id and some references to
/// When a new thread tries to read `LOCAL_DATA`, it will actively store its [ThreadLocalRef]
/// to [THREAD_REGISTER_BUFFER] during the initialization phase of thread local storage.
/// The [ThreadLocalRef] contains the thread id and some references to
/// the thread local fields.
pub static LOCAL_DATA: ThreadLocalData = {
let local_data = ThreadLocalData {
is_set: Cell::new(false),
shared_ptr: SharedTagPtr::default(),
summary_cur_record: Arc::new(SummaryRecord::default()),
summary_records: Arc::new(Mutex::new(HashMap::default())),
is_thread_down: Arc::new(AtomicBool::new(false)),
};
THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| {
tx.send(ThreadLocalMsg::Created(ThreadLocalRef{
THREAD_REGISTER_BUFFER.lock().unwrap().push(
ThreadLocalRef {
id: utils::thread_id(),
shared_ptr: local_data.shared_ptr.clone(),
summary_cur_record: local_data.summary_cur_record.clone(),
summary_records: local_data.summary_records.clone(),
})).ok();
});
is_down: local_data.is_thread_down.clone(),
}
);
local_data
};
}
Expand All @@ -51,13 +61,12 @@ pub struct ThreadLocalData {
pub shared_ptr: SharedTagPtr,
pub summary_cur_record: Arc<SummaryRecord>,
pub summary_records: Arc<Mutex<HashMap<ResourceMeteringTag, SummaryRecord>>>,
pub is_thread_down: Arc<AtomicBool>,
}

impl Drop for ThreadLocalData {
fn drop(&mut self) {
THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| {
tx.send(ThreadLocalMsg::Destroyed(utils::thread_id())).ok();
});
self.is_thread_down.store(true, Ordering::SeqCst);
}
}

Expand All @@ -67,50 +76,53 @@ pub struct ThreadLocalRef {
pub shared_ptr: SharedTagPtr,
pub summary_cur_record: Arc<SummaryRecord>,
pub summary_records: Arc<Mutex<HashMap<ResourceMeteringTag, SummaryRecord>>>,
pub is_down: Arc<AtomicBool>,
}

/// This enum is transmitted as a event in [THREAD_LOCAL_CHANS].
///
/// See [LOCAL_DATA] for more information.
#[derive(Debug)]
pub enum ThreadLocalMsg {
Created(ThreadLocalRef),
Destroyed(usize),
}

/// Register a channel to notify thread creation & destruction events.
pub fn register_thread_local_chan_tx(tx: Sender<ThreadLocalMsg>) {
THREAD_LOCAL_CHANS.lock().unwrap().push(tx);
impl ThreadLocalRef {
pub fn is_thread_down(&self) -> bool {
self.is_down.load(Ordering::SeqCst)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crossbeam::channel::unbounded;
use crossbeam::sync::WaitGroup;

#[test]
fn test_thread_local_chan() {
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
LOCAL_DATA.with(|_| {}); // Just to trigger registration.
std::thread::spawn(move || {
fn test_thread_local_registration() {
let _g = crate::tests::sequential_test();

let (next_step, stop_t0) = (WaitGroup::new(), WaitGroup::new());
let (n, s) = (next_step.clone(), stop_t0.clone());
let t0 = std::thread::spawn(move || {
LOCAL_DATA.with(|_| {});
})
.join()
.unwrap();
let mut count = 0;
while let Ok(msg) = rx.try_recv() {
match msg {
ThreadLocalMsg::Created(r) => {
assert_ne!(r.id, 0);
}
ThreadLocalMsg::Destroyed(id) => {
assert_ne!(id, 0);
}
}
count += 1;
}
// This value may be greater than 2 if other test threads access `LOCAL_DATA` in parallel.
assert!(count >= 2);
drop(n);
s.wait();
});
next_step.wait();

let (next_step, stop_t1) = (WaitGroup::new(), WaitGroup::new());
let (n, s) = (next_step.clone(), stop_t1.clone());
let t1 = std::thread::spawn(move || {
LOCAL_DATA.with(|_| {});
drop(n);
s.wait();
});
next_step.wait();

let registrations = take_thread_registrations(|t| t.collect::<Vec<_>>());
assert_eq!(registrations.len(), 2);
assert!(!registrations[0].is_thread_down());
assert!(!registrations[1].is_thread_down());

drop(stop_t0);
t0.join().unwrap();
assert!(registrations[0].is_thread_down());

drop(stop_t1);
t1.join().unwrap();
assert!(registrations[1].is_thread_down());
}
}