Skip to content

Commit

Permalink
Merge pull request #2 from zhongzc/pr-summary
Browse files Browse the repository at this point in the history
opt thread registration
  • Loading branch information
mornyx authored Oct 29, 2021
2 parents 24742f0 + de0316e commit af14b32
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 86 deletions.
29 changes: 28 additions & 1 deletion components/resource_metering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,38 @@ impl SharedTagPtr {
}

#[cfg(test)]
mod tests {
pub(crate) mod tests {
use super::*;

use std::sync::{Mutex, MutexGuard};

use crate::threadlocal::take_thread_registrations;
use lazy_static::lazy_static;

/// Tests that access [crate::threadlocal::THREAD_REGISTER_BUFFER] or [crate::config::GLOBAL_ENABLE]
/// need to be run sequentially. A helper function to
pub fn sequential_test() -> TestGuard {
TestGuard {
_guard: SEQ_LOCK.lock().unwrap(),
}
}

lazy_static! {
static ref SEQ_LOCK: Mutex<()> = Mutex::new(());
}
pub struct TestGuard {
_guard: MutexGuard<'static, ()>,
}
impl Drop for TestGuard {
fn drop(&mut self) {
take_thread_registrations(|_| {});
}
}

#[test]
fn test_attach() {
let _g = crate::tests::sequential_test();

// Use a thread created by ourself. If we use unit test thread directly,
// the test results may be affected by parallel testing.
std::thread::spawn(|| {
Expand Down
43 changes: 14 additions & 29 deletions components/resource_metering/src/recorder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0.

use crate::collector::{Collector, CollectorReg, COLLECTOR_REG_CHAN};
use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg, ThreadLocalRef};
use crate::threadlocal::{take_thread_registrations, ThreadLocalRef};
use crate::{RawRecords, SharedTagPtr};

use std::io;
Expand All @@ -13,7 +13,6 @@ use std::thread::JoinHandle;
use std::time::Duration;

use collections::HashMap;
use crossbeam::channel::{unbounded, Receiver};
use tikv_util::time::Instant;

mod cpu;
Expand Down Expand Up @@ -95,9 +94,7 @@ pub struct Recorder {
records: RawRecords,
recorders: Vec<Box<dyn SubRecorder + Send>>,
collectors: HashMap<u64, Box<dyn Collector>>,
thread_rx: Receiver<ThreadLocalMsg>,
thread_stores: HashMap<usize, ThreadLocalRef>,
destroyed_threads: Vec<usize>,
last_collect: Instant,
last_cleanup: Instant,
}
Expand Down Expand Up @@ -139,10 +136,9 @@ impl Recorder {

fn cleanup(&mut self) {
if self.last_cleanup.saturating_elapsed().as_secs() > CLEANUP_INTERVAL_SECS {
for id in &self.destroyed_threads {
self.thread_stores.remove(id);
}
self.destroyed_threads.clear();
self.thread_stores
.drain_filter(|_, t| t.is_thread_down())
.count();
if self.records.records.capacity() > RECORD_LEN_THRESHOLD
&& self.records.records.len() < (RECORD_LEN_THRESHOLD / 2)
{
Expand Down Expand Up @@ -190,21 +186,16 @@ impl Recorder {
}

fn handle_thread_registration(&mut self) {
while let Ok(msg) = self.thread_rx.try_recv() {
match msg {
ThreadLocalMsg::Created(tlr) => {
let id = tlr.id;
let tag = tlr.shared_ptr.clone();
self.thread_stores.insert(id, tlr);
for r in &mut self.recorders {
r.thread_created(id, tag.clone());
}
}
ThreadLocalMsg::Destroyed(id) => {
self.destroyed_threads.push(id);
take_thread_registrations(|tlrs| {
for tlr in tlrs {
let id = tlr.id;
let tag = tlr.shared_ptr.clone();
self.thread_stores.insert(id, tlr);
for r in &mut self.recorders {
r.thread_created(id, tag.clone());
}
}
}
});
}
}

Expand Down Expand Up @@ -251,18 +242,14 @@ impl RecorderBuilder {
pub fn spawn(self) -> io::Result<RecorderHandle> {
let pause = Arc::new(AtomicBool::new(!self.enable));
let precision_ms = self.precision_ms.clone();
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let now = Instant::now();
let mut recorder = Recorder {
pause: pause.clone(),
precision_ms: precision_ms.clone(),
records: RawRecords::default(),
recorders: self.recorders,
collectors: HashMap::default(),
thread_rx: rx,
thread_stores: HashMap::default(),
destroyed_threads: Vec::new(),
last_collect: now,
last_cleanup: now,
};
Expand Down Expand Up @@ -400,8 +387,8 @@ mod tests {

#[test]
fn test_recorder() {
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let _g = crate::tests::sequential_test();

std::thread::spawn(|| {
LOCAL_DATA.with(|_| {});
})
Expand All @@ -414,9 +401,7 @@ mod tests {
records: RawRecords::default(),
recorders: vec![Box::new(MockSubRecorder)],
collectors: HashMap::default(),
thread_rx: rx,
thread_stores: HashMap::default(),
destroyed_threads: Vec::new(),
last_collect: now,
last_cleanup: now,
};
Expand Down
17 changes: 9 additions & 8 deletions components/resource_metering/src/recorder/summary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ impl SubRecorder for SummaryRecorder {
#[cfg(test)]
mod tests {
use super::*;
use crate::threadlocal::{register_thread_local_chan_tx, ThreadLocalMsg};

use crate::threadlocal::take_thread_registrations;
use crate::{ResourceMeteringTag, TagInfos, GLOBAL_ENABLE};
use crossbeam::channel::unbounded;

use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;

#[test]
fn test_collect() {
GLOBAL_ENABLE.store(true, SeqCst);
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
let _g = crate::tests::sequential_test();

GLOBAL_ENABLE.store(true, SeqCst);
std::thread::spawn(|| {
let tag = ResourceMeteringTag {
infos: Arc::new(TagInfos {
Expand Down Expand Up @@ -140,11 +140,12 @@ mod tests {

let mut records = RawRecords::default();
let mut thread_stores = HashMap::default();
while let Ok(msg) = rx.try_recv() {
if let ThreadLocalMsg::Created(tlr) = msg {
take_thread_registrations(|tlrs| {
for tlr in tlrs {
thread_stores.insert(tlr.id, tlr);
}
}
});

let mut recorder = SummaryRecorder::default();
recorder.collect(&mut records, &mut thread_stores);
assert!(!records.records.is_empty());
Expand Down
108 changes: 60 additions & 48 deletions components/resource_metering/src/threadlocal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,49 @@ use crate::model::SummaryRecord;
use crate::{utils, ResourceMeteringTag, SharedTagPtr};

use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::vec::Drain;

use collections::HashMap;
use crossbeam::channel::Sender;
use lazy_static::lazy_static;

lazy_static! {
/// `THREAD_LOCAL_CHANS` is used to transfer the necessary thread registration events.
static ref THREAD_LOCAL_CHANS: Mutex<Vec<Sender<ThreadLocalMsg>>> = Mutex::new(Vec::new());
/// `THREAD_REGISTER_BUFFER` is used to store the new thread registrations.
static ref THREAD_REGISTER_BUFFER: Mutex<Vec<ThreadLocalRef>> = Mutex::new(Vec::new());
}

pub fn take_thread_registrations<F, T>(mut consume: F) -> T
where
F: FnMut(Drain<ThreadLocalRef>) -> T,
{
consume(THREAD_REGISTER_BUFFER.lock().unwrap().drain(..))
}

thread_local! {
/// `LOCAL_DATA` is a thread-localized instance of [ThreadLocalData].
///
/// When a new thread tries to read `LOCAL_DATA`, it will actively send a message
/// to [THREAD_LOCAL_CHANS] during the initialization phase of thread local storage.
/// The message([ThreadLocalRef]) contains the thread id and some references to
/// When a new thread tries to read `LOCAL_DATA`, it will actively store its [ThreadLocalRef]
/// to [THREAD_REGISTER_BUFFER] during the initialization phase of thread local storage.
/// The [ThreadLocalRef] contains the thread id and some references to
/// the thread local fields.
pub static LOCAL_DATA: ThreadLocalData = {
let local_data = ThreadLocalData {
is_set: Cell::new(false),
shared_ptr: SharedTagPtr::default(),
summary_cur_record: Arc::new(SummaryRecord::default()),
summary_records: Arc::new(Mutex::new(HashMap::default())),
is_thread_down: Arc::new(AtomicBool::new(false)),
};
THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| {
tx.send(ThreadLocalMsg::Created(ThreadLocalRef{
THREAD_REGISTER_BUFFER.lock().unwrap().push(
ThreadLocalRef {
id: utils::thread_id(),
shared_ptr: local_data.shared_ptr.clone(),
summary_cur_record: local_data.summary_cur_record.clone(),
summary_records: local_data.summary_records.clone(),
})).ok();
});
is_down: local_data.is_thread_down.clone(),
}
);
local_data
};
}
Expand All @@ -51,13 +61,12 @@ pub struct ThreadLocalData {
pub shared_ptr: SharedTagPtr,
pub summary_cur_record: Arc<SummaryRecord>,
pub summary_records: Arc<Mutex<HashMap<ResourceMeteringTag, SummaryRecord>>>,
pub is_thread_down: Arc<AtomicBool>,
}

impl Drop for ThreadLocalData {
fn drop(&mut self) {
THREAD_LOCAL_CHANS.lock().unwrap().iter().for_each(|tx| {
tx.send(ThreadLocalMsg::Destroyed(utils::thread_id())).ok();
});
self.is_thread_down.store(true, Ordering::SeqCst);
}
}

Expand All @@ -67,50 +76,53 @@ pub struct ThreadLocalRef {
pub shared_ptr: SharedTagPtr,
pub summary_cur_record: Arc<SummaryRecord>,
pub summary_records: Arc<Mutex<HashMap<ResourceMeteringTag, SummaryRecord>>>,
pub is_down: Arc<AtomicBool>,
}

/// This enum is transmitted as a event in [THREAD_LOCAL_CHANS].
///
/// See [LOCAL_DATA] for more information.
#[derive(Debug)]
pub enum ThreadLocalMsg {
Created(ThreadLocalRef),
Destroyed(usize),
}

/// Register a channel to notify thread creation & destruction events.
pub fn register_thread_local_chan_tx(tx: Sender<ThreadLocalMsg>) {
THREAD_LOCAL_CHANS.lock().unwrap().push(tx);
impl ThreadLocalRef {
pub fn is_thread_down(&self) -> bool {
self.is_down.load(Ordering::SeqCst)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crossbeam::channel::unbounded;
use crossbeam::sync::WaitGroup;

#[test]
fn test_thread_local_chan() {
let (tx, rx) = unbounded();
register_thread_local_chan_tx(tx);
LOCAL_DATA.with(|_| {}); // Just to trigger registration.
std::thread::spawn(move || {
fn test_thread_local_registration() {
let _g = crate::tests::sequential_test();

let (next_step, stop_t0) = (WaitGroup::new(), WaitGroup::new());
let (n, s) = (next_step.clone(), stop_t0.clone());
let t0 = std::thread::spawn(move || {
LOCAL_DATA.with(|_| {});
})
.join()
.unwrap();
let mut count = 0;
while let Ok(msg) = rx.try_recv() {
match msg {
ThreadLocalMsg::Created(r) => {
assert_ne!(r.id, 0);
}
ThreadLocalMsg::Destroyed(id) => {
assert_ne!(id, 0);
}
}
count += 1;
}
// This value may be greater than 2 if other test threads access `LOCAL_DATA` in parallel.
assert!(count >= 2);
drop(n);
s.wait();
});
next_step.wait();

let (next_step, stop_t1) = (WaitGroup::new(), WaitGroup::new());
let (n, s) = (next_step.clone(), stop_t1.clone());
let t1 = std::thread::spawn(move || {
LOCAL_DATA.with(|_| {});
drop(n);
s.wait();
});
next_step.wait();

let registrations = take_thread_registrations(|t| t.collect::<Vec<_>>());
assert_eq!(registrations.len(), 2);
assert!(!registrations[0].is_thread_down());
assert!(!registrations[1].is_thread_down());

drop(stop_t0);
t0.join().unwrap();
assert!(registrations[0].is_thread_down());

drop(stop_t1);
t1.join().unwrap();
assert!(registrations[1].is_thread_down());
}
}

0 comments on commit af14b32

Please sign in to comment.