diff --git a/Cargo.lock b/Cargo.lock index b962d93610d4a..f43a7d0259a26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -556,7 +556,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "334d75cf09b33bede6cbc20e52515853ae7bee3d4eadd9540e13ce92af983d34" dependencies = [ "event-listener 3.1.0", - "event-listener-strategy 0.1.0", + "event-listener-strategy", "futures-core", ] @@ -571,19 +571,6 @@ dependencies = [ "futures-core", ] -[[package]] -name = "async-channel" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" -dependencies = [ - "concurrent-queue", - "event-listener 4.0.0", - "event-listener-strategy 0.4.0", - "futures-core", - "pin-project-lite", -] - [[package]] name = "async-compression" version = "0.4.1" @@ -636,7 +623,7 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" dependencies = [ - "async-channel 1.8.0", + "async-channel", "async-executor", "async-io", "async-lock", @@ -691,7 +678,7 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d" dependencies = [ - "async-channel 1.8.0", + "async-channel", "async-global-executor", "async-io", "async-lock", @@ -1124,7 +1111,7 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c67b173a56acffd6d2326fb7ab938ba0b00a71480e14902b2591c87bc5741e8" dependencies = [ - "async-channel 1.8.0", + "async-channel", "async-lock", "async-task", "atomic-waker", @@ -2469,7 +2456,6 @@ dependencies = [ "anyerror", "anyhow", "async-backtrace", - "async-channel 2.1.1", "async-trait-fn", "bytesize", "ctrlc", @@ -3218,7 +3204,7 @@ name = "databend-common-pipeline-sinks" version = "0.1.0" dependencies = [ "async-backtrace", - "async-channel 1.8.0", + "async-channel", "async-trait-fn", "databend-common-base", "databend-common-catalog", @@ -3233,7 +3219,7 @@ name = "databend-common-pipeline-sources" version = "0.1.0" dependencies = [ "async-backtrace", - "async-channel 1.8.0", + "async-channel", "async-trait-fn", "bstr 1.6.2", "csv-core", @@ -4193,7 +4179,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "async-backtrace", - "async-channel 1.8.0", + "async-channel", "async-stream", "async-trait-fn", "backoff", @@ -5123,17 +5109,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "event-listener" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - [[package]] name = "event-listener-strategy" version = "0.1.0" @@ -5144,16 +5119,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "event-listener-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" -dependencies = [ - "event-listener 4.0.0", - "pin-project-lite", -] - [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -7483,7 +7448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" dependencies = [ "anyhow", - "async-channel 1.8.0", + "async-channel", "base64 0.13.1", 
"futures-lite", "http", diff --git a/src/binaries/query/ee_main.rs b/src/binaries/query/ee_main.rs index 1e28e09ca6934..408f5bd1aa936 100644 --- a/src/binaries/query/ee_main.rs +++ b/src/binaries/query/ee_main.rs @@ -19,6 +19,7 @@ mod entry; use databend_common_base::mem_allocator::GlobalAllocator; use databend_common_base::runtime::Runtime; +use databend_common_base::runtime::ThreadTracker; use databend_common_config::InnerConfig; use databend_common_exception::Result; use databend_enterprise_query::enterprise_services::EnterpriseServices; @@ -31,6 +32,8 @@ use crate::entry::start_services; pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator; fn main() { + ThreadTracker::init(); + match Runtime::with_default_worker_threads() { Err(cause) => { eprintln!("Databend Query start failure, cause: {:?}", cause); diff --git a/src/binaries/query/entry.rs b/src/binaries/query/entry.rs index d616e2d34dc92..ffba334b918f3 100644 --- a/src/binaries/query/entry.rs +++ b/src/binaries/query/entry.rs @@ -16,8 +16,8 @@ use std::env; use std::time::Duration; use databend_common_base::mem_allocator::GlobalAllocator; +use databend_common_base::runtime::set_alloc_error_hook; use databend_common_base::runtime::GLOBAL_MEM_STAT; -use databend_common_base::set_alloc_error_hook; use databend_common_config::Commands; use databend_common_config::InnerConfig; use databend_common_config::DATABEND_COMMIT_VERSION; diff --git a/src/binaries/query/oss_main.rs b/src/binaries/query/oss_main.rs index 64f24f6d9690a..c3a63b4f6a845 100644 --- a/src/binaries/query/oss_main.rs +++ b/src/binaries/query/oss_main.rs @@ -19,6 +19,7 @@ mod entry; use databend_common_base::mem_allocator::GlobalAllocator; use databend_common_base::runtime::Runtime; +use databend_common_base::runtime::ThreadTracker; use databend_common_config::InnerConfig; use databend_common_exception::Result; use databend_common_license::license_manager::LicenseManager; @@ -32,6 +33,8 @@ use crate::entry::start_services; pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator; fn main() { + ThreadTracker::init(); + match Runtime::with_default_worker_threads() { Err(cause) => { eprintln!("Databend Query start failure, cause: {:?}", cause); diff --git a/src/common/base/Cargo.toml b/src/common/base/Cargo.toml index c5bc34cf1bcbc..041733f143ff0 100644 --- a/src/common/base/Cargo.toml +++ b/src/common/base/Cargo.toml @@ -28,7 +28,6 @@ databend-common-exception = { path = "../exception" } # Crates.io dependencies async-backtrace = { workspace = true } -async-channel = "2" async-trait = { workspace = true } bytesize = "1.1.0" ctrlc = { version = "3.2.3", features = ["termination"] } diff --git a/src/common/base/src/runtime/memory/alloc_error_hook.rs b/src/common/base/src/runtime/memory/alloc_error_hook.rs new file mode 100644 index 0000000000000..29487d573495d --- /dev/null +++ b/src/common/base/src/runtime/memory/alloc_error_hook.rs @@ -0,0 +1,30 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use crate::runtime::LimitMemGuard; +use crate::runtime::ThreadTracker; + +pub fn set_alloc_error_hook() { + std::alloc::set_alloc_error_hook(|layout| { + let _guard = LimitMemGuard::enter_unlimited(); + + let out_of_limit_desc = ThreadTracker::replace_error_message(None); + + panic!( + "{}", + out_of_limit_desc + .unwrap_or_else(|| format!("memory allocation of {} bytes failed", layout.size())) + ); + }) +} diff --git a/src/common/base/src/runtime/memory/mem_stat.rs b/src/common/base/src/runtime/memory/mem_stat.rs new file mode 100644 index 0000000000000..7c0f2190d09d1 --- /dev/null +++ b/src/common/base/src/runtime/memory/mem_stat.rs @@ -0,0 +1,416 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Debug; +use std::fmt::Formatter; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use bytesize::ByteSize; +use log::info; + +/// The program mem stat +/// +/// Every alloc/dealloc stat will be fed to this mem stat. +pub static GLOBAL_MEM_STAT: MemStat = MemStat::global(); + +const MINIMUM_MEMORY_LIMIT: i64 = 256 * 1024 * 1024; + +/// Memory allocation stat. +/// +/// - A MemStat have child MemStat. +/// - Every stat that is fed to a child is also fed to its parent. +/// - A MemStat has at most one parent. +pub struct MemStat { + #[allow(dead_code)] + global: bool, + + name: Option, + + pub(crate) used: AtomicI64, + + pub(crate) peak_used: AtomicI64, + + /// The limit of max used memory for this tracker. + /// + /// Set to 0 to disable the limit. + limit: AtomicI64, + + parent_memory_stat: Option>, +} + +impl MemStat { + pub const fn global() -> Self { + Self { + name: None, + global: true, + used: AtomicI64::new(0), + limit: AtomicI64::new(0), + peak_used: AtomicI64::new(0), + parent_memory_stat: None, + } + } + + pub fn create(name: String) -> Arc { + MemStat::create_child(name, None) + } + + pub fn create_child(name: String, parent_memory_stat: Option>) -> Arc { + Arc::new(MemStat { + global: false, + name: Some(name), + used: AtomicI64::new(0), + limit: AtomicI64::new(0), + peak_used: AtomicI64::new(0), + parent_memory_stat, + }) + } + + pub fn set_limit(&self, mut size: i64) { + // It may cause the process unable to run if memory limit is too low. + if size > 0 && size < MINIMUM_MEMORY_LIMIT { + size = MINIMUM_MEMORY_LIMIT; + } + + self.limit.store(size, Ordering::Relaxed); + } + + /// Feed memory usage stat to MemStat and return if it exceeds the limit. + /// + /// It feeds `state` to the this tracker and all of its ancestors, including GLOBAL_TRACKER. 
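The hook above relies on `std::alloc::set_alloc_error_hook`, which is still unstable and only available on a nightly toolchain behind the `alloc_error_hook` feature. A stripped-down sketch of the same mechanism, without the `ThreadTracker`/`LimitMemGuard` plumbing, purely for illustration:

```rust
// Nightly only: the allocation-error hook API is unstable.
#![feature(alloc_error_hook)]

use std::alloc::Layout;

fn on_alloc_error(layout: Layout) {
    // Panicking here (instead of the default abort) lets the message carry
    // more context, e.g. which memory limit was exceeded and by whom.
    panic!("memory allocation of {} bytes failed", layout.size());
}

fn main() {
    std::alloc::set_alloc_error_hook(on_alloc_error);

    // Force an allocation failure: a reservation this large cannot succeed,
    // so the hook above fires and the process panics with our message.
    let _huge: Vec<u8> = Vec::with_capacity(usize::MAX / 2);
}
```

In the patch itself the panic message is taken from `ThreadTracker::replace_error_message`, so an out-of-limit allocation reports the offending limit rather than just the raw byte count.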
+ #[inline] + pub fn record_memory( + &self, + batch_memory_used: i64, + current_memory_alloc: i64, + ) -> Result<(), OutOfLimit> { + let mut used = self.used.fetch_add(batch_memory_used, Ordering::Relaxed); + + used += batch_memory_used; + self.peak_used.fetch_max(used, Ordering::Relaxed); + + if let Some(parent_memory_stat) = self.parent_memory_stat.as_deref() { + if let Err(cause) = parent_memory_stat + .record_memory::(batch_memory_used, current_memory_alloc) + { + if NEED_ROLLBACK { + // We only roll back the memory that alloc failed + self.used.fetch_sub(current_memory_alloc, Ordering::Relaxed); + self.peak_used + .store(used - current_memory_alloc, Ordering::Relaxed); + } + + return Err(cause); + } + } + + if let Err(cause) = self.check_limit(used) { + if NEED_ROLLBACK { + // NOTE: we cannot rollback peak_used of parent mem stat in this case + // self.peak_used.store(peak_used, Ordering::Relaxed); + self.rollback(current_memory_alloc); + } + + return Err(cause); + } + + Ok(()) + } + + pub fn rollback(&self, memory_usage: i64) { + self.used.fetch_sub(memory_usage, Ordering::Relaxed); + + if let Some(parent_memory_stat) = self.parent_memory_stat.as_deref() { + parent_memory_stat.rollback(memory_usage) + } + } + + /// Check if used memory is out of the limit. + #[inline] + fn check_limit(&self, used: i64) -> Result<(), OutOfLimit> { + let limit = self.limit.load(Ordering::Relaxed); + + // No limit + if limit == 0 { + return Ok(()); + } + + if used <= limit { + return Ok(()); + } + + Err(OutOfLimit::new(used, limit)) + } + + #[inline] + pub fn get_memory_usage(&self) -> i64 { + self.used.load(Ordering::Relaxed) + } + + #[inline] + #[allow(unused)] + pub fn get_peak_memory_usage(&self) -> i64 { + self.peak_used.load(Ordering::Relaxed) + } + + #[allow(unused)] + pub fn log_memory_usage(&self) { + let name = self.name.clone().unwrap_or_else(|| String::from("global")); + let memory_usage = self.used.load(Ordering::Relaxed); + let memory_usage = std::cmp::max(0, memory_usage) as u64; + info!( + "Current memory usage({}): {}.", + name, + ByteSize::b(memory_usage) + ); + } + + #[allow(unused)] + pub fn log_peek_memory_usage(&self) { + let name = self.name.clone().unwrap_or_else(|| String::from("global")); + let peak_memory_usage = self.peak_used.load(Ordering::Relaxed); + let peak_memory_usage = std::cmp::max(0, peak_memory_usage) as u64; + info!( + "Peak memory usage({}): {}.", + name, + ByteSize::b(peak_memory_usage) + ); + } +} + +/// Error of exceeding limit. 
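Taken together, the API above forms a tree of stats: `record_memory` updates the local counters, recurses into the parent, and — with `NEED_ROLLBACK = true` — undoes the current allocation when any level rejects it. A short usage sketch against exactly the methods shown above (crate paths assume the re-exports added later in this patch; treat it as illustrative):

```rust
use databend_common_base::runtime::MemStat;

fn main() {
    // Query-level stat; set_limit clamps anything below 256 MiB up to that minimum.
    let query_stat = MemStat::create("Query".to_string());
    query_stat.set_limit(1024 * 1024 * 1024); // 1 GiB

    // Operator-level child: everything recorded here is also fed into `query_stat`.
    let join_stat = MemStat::create_child("HashJoin".to_string(), Some(query_stat.clone()));

    // A 2 GiB batch exceeds the query limit; with NEED_ROLLBACK = true the
    // rejected allocation is rolled back from both stats' `used` counters.
    let two_gib = 2 * 1024 * 1024 * 1024_i64;
    if let Err(out_of_limit) = join_stat.record_memory::<true>(two_gib, two_gib) {
        eprintln!("allocation rejected: {:?}", out_of_limit);
    }

    // A batch within the limit succeeds and shows up at every level.
    // Negative batches account for deallocations in the same way.
    join_stat
        .record_memory::<true>(64 * 1024 * 1024, 64 * 1024 * 1024)
        .unwrap();
    println!(
        "query usage: {} bytes, operator usage: {} bytes",
        query_stat.get_memory_usage(),
        join_stat.get_memory_usage()
    );
}
```

Note that `peak_used` is deliberately not rolled back on the parent when a child fails its own check; the unit tests below make that asymmetry explicit.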
+#[derive(Clone)] +pub struct OutOfLimit { + pub value: V, + pub limit: V, +} + +impl OutOfLimit { + pub const fn new(value: V, limit: V) -> Self { + Self { value, limit } + } +} + +impl Debug for OutOfLimit { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "memory usage {}({}) exceeds limit {}({})", + ByteSize::b(self.value as u64), + self.value, + ByteSize::b(self.limit as u64), + self.limit, + ) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + + use databend_common_exception::Result; + + use crate::runtime::memory::mem_stat::MINIMUM_MEMORY_LIMIT; + use crate::runtime::MemStat; + + #[test] + fn test_single_level_mem_stat() -> Result<()> { + let mem_stat = MemStat::create("TEST".to_string()); + + mem_stat.record_memory::(1, 1).unwrap(); + mem_stat.record_memory::(2, 2).unwrap(); + mem_stat.record_memory::(-1, -1).unwrap(); + + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 2); + assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 3); + + Ok(()) + } + + #[test] + fn test_single_level_mem_stat_with_check_limit() -> Result<()> { + let mem_stat = MemStat::create("TEST".to_string()); + mem_stat.set_limit(MINIMUM_MEMORY_LIMIT); + + mem_stat.record_memory::(1, 1).unwrap(); + assert!( + mem_stat + .record_memory::(MINIMUM_MEMORY_LIMIT, MINIMUM_MEMORY_LIMIT) + .is_err() + ); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + + assert!(mem_stat.record_memory::(1, 1).is_err()); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + + assert!(mem_stat.record_memory::(1, 1).is_err()); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + + assert!(mem_stat.record_memory::(-1, -1).is_err()); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + + assert!(mem_stat.record_memory::(-1, -1).is_err()); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + ); + + Ok(()) + } + + #[test] + fn test_multiple_level_mem_stat() -> Result<()> { + let mem_stat = MemStat::create("TEST".to_string()); + let child_mem_stat = + MemStat::create_child("TEST_CHILD".to_string(), Some(mem_stat.clone())); + + mem_stat.record_memory::(1, 1).unwrap(); + mem_stat.record_memory::(2, 2).unwrap(); + mem_stat.record_memory::(-1, -1).unwrap(); + + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 2); + assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 3); + assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); + + child_mem_stat.record_memory::(1, 1).unwrap(); + child_mem_stat.record_memory::(2, 2).unwrap(); + child_mem_stat.record_memory::(-1, -1).unwrap(); + + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 4); + assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 5); + assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 2); + assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 3); + + Ok(()) + } + + #[test] + fn 
test_multiple_level_mem_stat_with_check_limit() -> Result<()> { + let mem_stat = MemStat::create("TEST".to_string()); + mem_stat.set_limit(MINIMUM_MEMORY_LIMIT * 2); + let child_mem_stat = + MemStat::create_child("TEST_CHILD".to_string(), Some(mem_stat.clone())); + child_mem_stat.set_limit(MINIMUM_MEMORY_LIMIT); + + mem_stat.record_memory::(1, 1).unwrap(); + assert!( + mem_stat + .record_memory::(MINIMUM_MEMORY_LIMIT, MINIMUM_MEMORY_LIMIT) + .is_ok() + ); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); + + child_mem_stat.record_memory::(1, 1).unwrap(); + assert!( + child_mem_stat + .record_memory::(MINIMUM_MEMORY_LIMIT, MINIMUM_MEMORY_LIMIT) + .is_err() + ); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + child_mem_stat.used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + assert_eq!( + child_mem_stat.peak_used.load(Ordering::Relaxed), + 1 + MINIMUM_MEMORY_LIMIT + ); + + // parent failure + let mem_stat = MemStat::create("TEST".to_string()); + mem_stat.set_limit(MINIMUM_MEMORY_LIMIT); + let child_mem_stat = + MemStat::create_child("TEST_CHILD".to_string(), Some(mem_stat.clone())); + child_mem_stat.set_limit(MINIMUM_MEMORY_LIMIT * 2); + + assert!( + child_mem_stat + .record_memory::(1 + MINIMUM_MEMORY_LIMIT, 1 + MINIMUM_MEMORY_LIMIT) + .is_err() + ); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); + + // child failure + let mem_stat = MemStat::create("TEST".to_string()); + mem_stat.set_limit(MINIMUM_MEMORY_LIMIT * 2); + let child_mem_stat = + MemStat::create_child("TEST_CHILD".to_string(), Some(mem_stat.clone())); + child_mem_stat.set_limit(MINIMUM_MEMORY_LIMIT); + + assert!( + child_mem_stat + .record_memory::(1 + MINIMUM_MEMORY_LIMIT, 1 + MINIMUM_MEMORY_LIMIT) + .is_err() + ); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + // assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); + + Ok(()) + } +} diff --git a/src/common/base/src/runtime/memory/mod.rs b/src/common/base/src/runtime/memory/mod.rs new file mode 100644 index 0000000000000..80be2e68348e3 --- /dev/null +++ b/src/common/base/src/runtime/memory/mod.rs @@ -0,0 +1,23 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +mod alloc_error_hook; +mod mem_stat; +mod stat_buffer; + +pub use alloc_error_hook::set_alloc_error_hook; +pub use mem_stat::MemStat; +pub use mem_stat::OutOfLimit; +pub use mem_stat::GLOBAL_MEM_STAT; +pub use stat_buffer::StatBuffer; diff --git a/src/common/base/src/runtime/memory/stat_buffer.rs b/src/common/base/src/runtime/memory/stat_buffer.rs new file mode 100644 index 0000000000000..d38749caf637a --- /dev/null +++ b/src/common/base/src/runtime/memory/stat_buffer.rs @@ -0,0 +1,220 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ptr::addr_of_mut; +use std::sync::atomic::Ordering; + +use databend_common_exception::Result; + +use crate::runtime::memory::mem_stat::OutOfLimit; +use crate::runtime::memory::MemStat; +use crate::runtime::LimitMemGuard; +use crate::runtime::ThreadTracker; +use crate::runtime::GLOBAL_MEM_STAT; + +#[thread_local] +static mut STAT_BUFFER: StatBuffer = StatBuffer::empty(&GLOBAL_MEM_STAT); + +static MEM_STAT_BUFFER_SIZE: i64 = 4 * 1024 * 1024; + +/// Buffering memory allocation stats. +/// +/// A StatBuffer buffers stats changes in local variables, and periodically flush them to other storage such as an `Arc` shared by several threads. +#[derive(Clone)] +pub struct StatBuffer { + memory_usage: i64, + // Whether to allow unlimited memory. Alloc memory will not panic if it is true. + unlimited_flag: bool, + global_mem_stat: &'static MemStat, + destroyed_thread_local_macro: bool, +} + +impl StatBuffer { + pub const fn empty(global_mem_stat: &'static MemStat) -> Self { + Self { + memory_usage: 0, + global_mem_stat, + unlimited_flag: false, + destroyed_thread_local_macro: false, + } + } + + pub fn current() -> &'static mut StatBuffer { + unsafe { &mut *addr_of_mut!(STAT_BUFFER) } + } + + pub fn is_unlimited(&self) -> bool { + self.unlimited_flag + } + + pub fn set_unlimited_flag(&mut self, flag: bool) -> bool { + let old = self.unlimited_flag; + self.unlimited_flag = flag; + old + } + + pub fn incr(&mut self, bs: i64) -> i64 { + self.memory_usage += bs; + self.memory_usage + } + + /// Flush buffered stat to MemStat it belongs to. + pub fn flush(&mut self, alloc: i64) -> Result<(), OutOfLimit> { + match std::mem::take(&mut self.memory_usage) { + 0 => Ok(()), + usage => { + if let Err(e) = self.global_mem_stat.record_memory::(usage, alloc) { + if !ROLLBACK { + let _ = ThreadTracker::record_memory::(usage, alloc); + } + + return Err(e); + } + + if let Err(e) = ThreadTracker::record_memory::(usage, alloc) { + if ROLLBACK { + self.global_mem_stat.rollback(alloc); + return Err(e); + } + } + + Ok(()) + } + } + } + + pub fn alloc(&mut self, memory_usage: i64) -> Result<(), OutOfLimit> { + // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. + // This is the boundary of thread exit. It may be dangerous to throw mistakes here. 
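The point of `StatBuffer` is to keep the allocation hot path thread-local: deltas accumulate in a plain field and only touch the shared `MemStat` atomics once roughly 4 MiB of change has built up. A self-contained stand-in for that buffering pattern — names here are illustrative, and the real buffer additionally routes the flush through the limit checks, handles rollback, and special-cases thread destruction:

```rust
use std::sync::atomic::{AtomicI64, Ordering};

// Shared counter that several threads feed into (stand-in for `MemStat`).
static GLOBAL_USED: AtomicI64 = AtomicI64::new(0);

// Flush threshold, mirroring MEM_STAT_BUFFER_SIZE (4 MiB) in the patch.
const BUFFER_SIZE: i64 = 4 * 1024 * 1024;

/// Thread-local buffer: cheap local adds, one atomic RMW per ~4 MiB of churn.
struct LocalBuffer {
    pending: i64,
}

impl LocalBuffer {
    const fn new() -> Self {
        Self { pending: 0 }
    }

    fn record(&mut self, bytes: i64) {
        self.pending += bytes;
        // Only touch the shared atomic once the local delta gets large,
        // in either direction (alloc or dealloc).
        if self.pending.abs() > BUFFER_SIZE {
            GLOBAL_USED.fetch_add(std::mem::take(&mut self.pending), Ordering::Relaxed);
        }
    }
}

fn main() {
    let mut buf = LocalBuffer::new();
    buf.record(1024); // buffered locally, the shared counter stays at 0
    assert_eq!(GLOBAL_USED.load(Ordering::Relaxed), 0);

    buf.record(BUFFER_SIZE); // crosses the threshold, flushed
    assert_eq!(GLOBAL_USED.load(Ordering::Relaxed), BUFFER_SIZE + 1024);
}
```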
+ if self.destroyed_thread_local_macro { + let used = self + .global_mem_stat + .used + .fetch_add(memory_usage, Ordering::Relaxed); + self.global_mem_stat + .peak_used + .fetch_max(used + memory_usage, Ordering::Relaxed); + return Ok(()); + } + + match self.incr(memory_usage) <= MEM_STAT_BUFFER_SIZE { + true => Ok(()), + false => self.flush::(memory_usage), + } + } + + pub fn dealloc(&mut self, memory_usage: i64) { + // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. + if self.destroyed_thread_local_macro { + self.global_mem_stat + .used + .fetch_add(-memory_usage, Ordering::Relaxed); + return; + } + + if self.incr(-memory_usage) < -MEM_STAT_BUFFER_SIZE { + let _ = self.flush::(memory_usage); + } + + // NOTE: De-allocation does not panic + // even when it's possible exceeding the limit + // due to other threads sharing the same MemStat may have allocated a lot of memory. + } + + pub fn mark_destroyed(&mut self) { + let _guard = LimitMemGuard::enter_unlimited(); + let memory_usage = std::mem::take(&mut self.memory_usage); + + // Memory operations during destruction will be recorded to global stat. + self.destroyed_thread_local_macro = true; + let _ = self.global_mem_stat.record_memory::(memory_usage, 0); + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + + use databend_common_exception::Result; + + use crate::runtime::memory::stat_buffer::MEM_STAT_BUFFER_SIZE; + use crate::runtime::memory::MemStat; + use crate::runtime::memory::StatBuffer; + + #[test] + fn test_alloc() -> Result<()> { + static TEST_MEM_STATE: MemStat = MemStat::global(); + let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + + buffer.alloc(1).unwrap(); + assert_eq!(buffer.memory_usage, 1); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), 0); + + buffer.destroyed_thread_local_macro = true; + buffer.alloc(2).unwrap(); + assert_eq!(buffer.memory_usage, 1); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), 2); + + buffer.destroyed_thread_local_macro = false; + buffer.alloc(MEM_STAT_BUFFER_SIZE).unwrap(); + assert_eq!(buffer.memory_usage, 0); + assert_eq!( + TEST_MEM_STATE.used.load(Ordering::Relaxed), + MEM_STAT_BUFFER_SIZE + 1 + 2 + ); + + Ok(()) + } + + #[test] + fn test_dealloc() -> Result<()> { + static TEST_MEM_STATE: MemStat = MemStat::global(); + let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + + buffer.dealloc(1); + assert_eq!(buffer.memory_usage, -1); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), 0); + + buffer.destroyed_thread_local_macro = true; + buffer.dealloc(2); + assert_eq!(buffer.memory_usage, -1); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), -2); + + buffer.destroyed_thread_local_macro = false; + buffer.dealloc(MEM_STAT_BUFFER_SIZE); + assert_eq!(buffer.memory_usage, 0); + assert_eq!( + TEST_MEM_STATE.used.load(Ordering::Relaxed), + -(MEM_STAT_BUFFER_SIZE + 1 + 2) + ); + + Ok(()) + } + + #[test] + fn test_mark_destroyed() -> Result<()> { + static TEST_MEM_STATE: MemStat = MemStat::global(); + + let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + + assert!(!buffer.destroyed_thread_local_macro); + buffer.alloc(1).unwrap(); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), 0); + buffer.mark_destroyed(); + assert!(buffer.destroyed_thread_local_macro); + assert_eq!(TEST_MEM_STATE.used.load(Ordering::Relaxed), 1); + + Ok(()) + } +} diff --git a/src/common/base/src/runtime/mod.rs b/src/common/base/src/runtime/mod.rs index 946a0fee792d9..30d5ad00bafd6 100644 --- 
a/src/common/base/src/runtime/mod.rs +++ b/src/common/base/src/runtime/mod.rs @@ -15,6 +15,7 @@ mod backtrace; mod catch_unwind; mod global_runtime; +mod memory; #[allow(clippy::module_inception)] mod runtime; mod runtime_tracker; @@ -27,6 +28,9 @@ pub use catch_unwind::catch_unwind; pub use catch_unwind::CatchUnwindFuture; pub use global_runtime::GlobalIORuntime; pub use global_runtime::GlobalQueryRuntime; +pub use memory::set_alloc_error_hook; +pub use memory::MemStat; +pub use memory::GLOBAL_MEM_STAT; pub use runtime::block_on; pub use runtime::execute_futures_in_parallel; pub use runtime::match_join_handle; @@ -39,12 +43,8 @@ pub use runtime::Dropper; pub use runtime::Runtime; pub use runtime::TrySpawn; pub use runtime::GLOBAL_TASK; -pub use runtime_tracker::set_alloc_error_hook; pub use runtime_tracker::LimitMemGuard; -pub use runtime_tracker::MemStat; pub use runtime_tracker::ThreadTracker; -pub use runtime_tracker::TrackedFuture; pub use runtime_tracker::UnlimitedFuture; -pub use runtime_tracker::GLOBAL_MEM_STAT; pub use thread::Thread; pub use thread::ThreadJoinHandle; diff --git a/src/common/base/src/runtime/runtime.rs b/src/common/base/src/runtime/runtime.rs index cec2364b6d19d..e9b0d99d59acb 100644 --- a/src/common/base/src/runtime/runtime.rs +++ b/src/common/base/src/runtime/runtime.rs @@ -31,9 +31,10 @@ use tokio::sync::Semaphore; use tokio::task::JoinHandle; use crate::runtime::catch_unwind::CatchUnwindFuture; -use crate::runtime::MemStat; +use crate::runtime::memory::MemStat; use crate::runtime::Thread; use crate::runtime::ThreadJoinHandle; +use crate::runtime::ThreadTracker; /// Methods to spawn tasks. pub trait TrySpawn { @@ -130,15 +131,6 @@ impl Runtime { }) } - fn tracker_builder(mem_stat: Arc) -> tokio::runtime::Builder { - let mut builder = tokio::runtime::Builder::new_multi_thread(); - builder - .enable_all() - .on_thread_start(mem_stat.on_start_thread()); - - builder - } - pub fn get_tracker(&self) -> Arc { self.tracker.clone() } @@ -148,7 +140,7 @@ impl Runtime { /// its executor. 
pub fn with_default_worker_threads() -> Result { let mem_stat = MemStat::create(String::from("UnnamedRuntime")); - let mut runtime_builder = Self::tracker_builder(mem_stat.clone()); + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); #[cfg(debug_assertions)] { @@ -162,7 +154,13 @@ impl Runtime { runtime_builder.thread_stack_size(20 * 1024 * 1024); } - Self::create(None, mem_stat, &mut runtime_builder) + Self::create( + None, + mem_stat, + runtime_builder + .enable_all() + .on_thread_start(ThreadTracker::init), + ) } #[allow(unused_mut)] @@ -174,7 +172,7 @@ impl Runtime { } let mem_stat = MemStat::create(mem_stat_name); - let mut runtime_builder = Self::tracker_builder(mem_stat.clone()); + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); #[cfg(debug_assertions)] { @@ -195,7 +193,10 @@ impl Runtime { Self::create( thread_name, mem_stat, - runtime_builder.worker_threads(workers), + runtime_builder + .enable_all() + .on_thread_start(ThreadTracker::init) + .worker_threads(workers), ) } diff --git a/src/common/base/src/runtime/runtime_tracker.rs b/src/common/base/src/runtime/runtime_tracker.rs index 51a4ca7ffd173..4f2a082fbf4bb 100644 --- a/src/common/base/src/runtime/runtime_tracker.rs +++ b/src/common/base/src/runtime/runtime_tracker.rs @@ -44,26 +44,17 @@ use std::alloc::AllocError; use std::cell::RefCell; -use std::fmt::Debug; -use std::fmt::Formatter; use std::future::Future; -use std::mem::take; use std::pin::Pin; -use std::ptr::addr_of_mut; -use std::sync::atomic::AtomicI64; -use std::sync::atomic::Ordering; use std::sync::Arc; use std::task::Context; use std::task::Poll; -use bytesize::ByteSize; -use log::info; use pin_project_lite::pin_project; -/// The root tracker. -/// -/// Every alloc/dealloc stat will be fed to this tracker. -pub static GLOBAL_MEM_STAT: MemStat = MemStat::global(); +use crate::runtime::memory::MemStat; +use crate::runtime::memory::OutOfLimit; +use crate::runtime::memory::StatBuffer; // For implemented and needs to call drop, we cannot use the attribute tag thread local. // https://play.rust-lang.org/?version=nightly&mode=debug&edition=2021&gist=ea33533387d401e86423df1a764b5609 @@ -71,97 +62,31 @@ thread_local! { static TRACKER: RefCell = const { RefCell::new(ThreadTracker::empty()) }; } -#[thread_local] -static mut STAT_BUFFER: StatBuffer = StatBuffer::empty(); - -static MEM_STAT_BUFFER_SIZE: i64 = 4 * 1024 * 1024; - -pub fn set_alloc_error_hook() { - std::alloc::set_alloc_error_hook(|layout| { - let _guard = LimitMemGuard::enter_unlimited(); - - let out_of_limit_desc = ThreadTracker::replace_error_message(None); - - panic!( - "{}", - out_of_limit_desc - .unwrap_or_else(|| format!("memory allocation of {} bytes failed", layout.size())) - ); - }) -} - -/// A guard that restores the thread local tracker to the `saved` when dropped. 
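The builder changes above replace the old `tracker_builder` helper: instead of attaching the runtime's `MemStat` in `on_thread_start`, every worker thread now simply runs `ThreadTracker::init`. A minimal sketch of that tokio wiring (assumes a `tokio` dependency with the multi-thread runtime enabled; `init_thread_tracker` stands in for `ThreadTracker::init`):

```rust
// Every tokio worker thread eagerly initializes its thread-local tracker
// before it starts polling tasks.
fn init_thread_tracker() {
    // In the real code this touches the `thread_local!` TRACKER so the
    // per-thread state (and its destructor) exists before anything is tracked.
}

fn main() -> std::io::Result<()> {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(4)
        .on_thread_start(init_thread_tracker)
        .build()?;

    runtime.block_on(async {
        println!("worker threads started with the tracker initialized");
    });
    Ok(())
}
```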
-pub struct Entered { - /// Saved tracker for restoring - saved: Option>, -} - -impl Drop for Entered { - fn drop(&mut self) { - unsafe { - let _ = STAT_BUFFER.flush::(); - ThreadTracker::replace_mem_stat(self.saved.take()); - } - } -} - pub struct LimitMemGuard { saved: bool, } impl LimitMemGuard { pub fn enter_unlimited() -> Self { - unsafe { - let saved = STAT_BUFFER.unlimited_flag; - STAT_BUFFER.unlimited_flag = true; - Self { saved } + Self { + saved: StatBuffer::current().set_unlimited_flag(true), } } pub fn enter_limited() -> Self { - unsafe { - let saved = STAT_BUFFER.unlimited_flag; - STAT_BUFFER.unlimited_flag = false; - Self { saved } + Self { + saved: StatBuffer::current().set_unlimited_flag(false), } } pub(crate) fn is_unlimited() -> bool { - unsafe { STAT_BUFFER.unlimited_flag } + StatBuffer::current().is_unlimited() } } impl Drop for LimitMemGuard { fn drop(&mut self) { - unsafe { - STAT_BUFFER.unlimited_flag = self.saved; - } - } -} - -/// Error of exceeding limit. -#[derive(Clone)] -pub struct OutOfLimit { - pub value: V, - pub limit: V, -} - -impl OutOfLimit { - pub const fn new(value: V, limit: V) -> Self { - Self { value, limit } - } -} - -impl Debug for OutOfLimit { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "memory usage {}({}) exceeds limit {}({})", - ByteSize::b(self.value as u64), - self.value, - ByteSize::b(self.limit as u64), - self.limit, - ) + StatBuffer::current().set_unlimited_flag(self.saved); } } @@ -174,14 +99,7 @@ pub struct ThreadTracker { impl Drop for ThreadTracker { fn drop(&mut self) { - unsafe { - let _guard = LimitMemGuard::enter_unlimited(); - let memory_usage = take(&mut STAT_BUFFER.memory_usage); - - // Memory operations during destruction will be recorded to global stat. - STAT_BUFFER.destroyed_thread_local_macro = true; - let _ = MemStat::record_memory::(&None, memory_usage); - } + StatBuffer::current().mark_destroyed(); } } @@ -197,6 +115,14 @@ impl ThreadTracker { } } + // rust style thread local is always lazy init. + // need to be called immediately after the threads start + pub fn init() { + TRACKER.with(|x| { + let _ = x.borrow_mut(); + }) + } + /// Replace the `mem_stat` with the current thread's. pub fn replace_mem_stat(mem_state: Option>) -> Option> { TRACKER.with(|v: &RefCell| { @@ -217,39 +143,12 @@ impl ThreadTracker { }) } - /// Enters a context in which it reports memory stats to `mem stat` and returns a guard that restores the previous mem stat when being dropped. - pub fn enter(mem_state: Option>) -> Entered { - unsafe { - let _ = STAT_BUFFER.flush::(); - Entered { - saved: ThreadTracker::replace_mem_stat(mem_state), - } - } - } - /// Accumulate stat about allocated memory. /// /// `size` is the positive number of allocated bytes. #[inline] pub fn alloc(size: i64) -> Result<(), AllocError> { - let state_buffer = unsafe { &mut *addr_of_mut!(STAT_BUFFER) }; - - // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. - // This is the boundary of thread exit. It may be dangerous to throw mistakes here. 
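`LimitMemGuard` is now a thin RAII wrapper over the thread-local `StatBuffer` flag shown above. The typical use is exactly what the alloc error hook and `mark_destroyed` do: temporarily lift the limit while doing work whose own allocations must not re-trigger the limit check. A small usage sketch, assuming `databend-common-base` as a dependency:

```rust
use databend_common_base::runtime::LimitMemGuard;

fn report_out_of_memory(details: &str) {
    // Error reporting itself allocates (formatting, backtraces, ...).
    // The guard flips the thread-local `unlimited_flag` so these allocations
    // cannot recursively trip the limit check; the previous value is restored
    // when `_guard` is dropped.
    let _guard = LimitMemGuard::enter_unlimited();
    eprintln!("query aborted: {}", details);
}

fn main() {
    report_out_of_memory("memory usage 2.1 GiB exceeds limit 2.0 GiB");
}
```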
- if state_buffer.destroyed_thread_local_macro { - let used = GLOBAL_MEM_STAT.used.fetch_add(size, Ordering::Relaxed); - GLOBAL_MEM_STAT - .peak_used - .fetch_max(used + size, Ordering::Relaxed); - return Ok(()); - } - - let has_oom = match state_buffer.incr(size) <= MEM_STAT_BUFFER_SIZE { - true => Ok(()), - false => state_buffer.flush::(), - }; - - if let Err(out_of_limit) = has_oom { + if let Err(out_of_limit) = StatBuffer::current().alloc(size) { // https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=03d21a15e52c7c0356fca04ece283cf9 if !std::thread::panicking() && !LimitMemGuard::is_unlimited() { let _guard = LimitMemGuard::enter_unlimited(); @@ -266,274 +165,26 @@ impl ThreadTracker { /// `size` is positive number of bytes of the memory to deallocate. #[inline] pub fn dealloc(size: i64) { - let state_buffer = unsafe { &mut *addr_of_mut!(STAT_BUFFER) }; - - // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. - if state_buffer.destroyed_thread_local_macro { - GLOBAL_MEM_STAT.used.fetch_add(-size, Ordering::Relaxed); - return; - } - - if state_buffer.incr(-size) < -MEM_STAT_BUFFER_SIZE { - let _ = state_buffer.flush::(); - } - - // NOTE: De-allocation does not panic - // even when it's possible exceeding the limit - // due to other threads sharing the same MemStat may have allocated a lot of memory. + StatBuffer::current().dealloc(size) } -} -/// Buffering memory allocation stats. -/// -/// A StatBuffer buffers stats changes in local variables, and periodically flush them to other storage such as an `Arc` shared by several threads. -#[derive(Clone, Debug, Default)] -pub struct StatBuffer { - memory_usage: i64, - // Whether to allow unlimited memory. Alloc memory will not panic if it is true. - unlimited_flag: bool, - destroyed_thread_local_macro: bool, -} - -impl StatBuffer { - pub const fn empty() -> Self { - Self { - memory_usage: 0, - unlimited_flag: false, - destroyed_thread_local_macro: false, - } - } - - pub fn incr(&mut self, bs: i64) -> i64 { - self.memory_usage += bs; - self.memory_usage - } - - /// Flush buffered stat to MemStat it belongs to. - pub fn flush(&mut self) -> Result<(), OutOfLimit> { - let memory_usage = take(&mut self.memory_usage); + pub fn record_memory(batch: i64, cur: i64) -> Result<(), OutOfLimit> { let has_thread_local = TRACKER.try_with(|tracker: &RefCell| { // We need to ensure no heap memory alloc or dealloc. it will cause panic of borrow recursive call. - MemStat::record_memory::(&tracker.borrow().mem_stat, memory_usage) + let tracker = tracker.borrow(); + match tracker.mem_stat.as_deref() { + None => Ok(()), + Some(mem_stat) => mem_stat.record_memory::(batch, cur), + } }); match has_thread_local { - Ok(Ok(_)) => Ok(()), + Ok(Ok(_)) | Err(_) => Ok(()), Ok(Err(oom)) => Err(oom), - Err(_access_error) => MemStat::record_memory::(&None, memory_usage), - } - } -} - -/// Memory allocation stat. -/// -/// - A MemStat have child MemStat. -/// - Every stat that is fed to a child is also fed to its parent. -/// - A MemStat has at most one parent. -pub struct MemStat { - name: Option, - - used: AtomicI64, - - peak_used: AtomicI64, - - /// The limit of max used memory for this tracker. - /// - /// Set to 0 to disable the limit. 
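With `TrackedFuture` and the per-thread child `MemStat` gone, attribution is driven by whatever stat is attached to the current thread via `ThreadTracker::replace_mem_stat`; the new `ThreadTracker::record_memory` above is what the `StatBuffer` flush calls into. A sketch of attaching a query-scoped stat to the current thread — it assumes the crate's tracking `GlobalAllocator` is installed as the global allocator, as the query binaries do:

```rust
use databend_common_base::mem_allocator::GlobalAllocator;
use databend_common_base::runtime::MemStat;
use databend_common_base::runtime::ThreadTracker;

// Mirrors the query binaries: route all allocations through the tracking allocator.
#[global_allocator]
static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator;

fn main() {
    // Attach a query-scoped stat to this thread. From here on, flushes of the
    // thread-local StatBuffer feed `query_stat` in addition to GLOBAL_MEM_STAT.
    let query_stat = MemStat::create("Query-example".to_string());
    let previous = ThreadTracker::replace_mem_stat(Some(query_stat.clone()));

    // 8 MiB is above the 4 MiB buffer threshold, so this allocation is flushed
    // (and therefore visible on `query_stat`) before it is freed again.
    let data = vec![0u8; 8 * 1024 * 1024];
    drop(data);

    // Restore whatever was attached before (usually None).
    ThreadTracker::replace_mem_stat(previous);

    println!("peak query memory: {} bytes", query_stat.get_peak_memory_usage());
}
```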
- limit: AtomicI64, - - parent_memory_stat: Option>, -} - -impl MemStat { - pub const fn global() -> Self { - Self { - name: None, - used: AtomicI64::new(0), - limit: AtomicI64::new(0), - peak_used: AtomicI64::new(0), - parent_memory_stat: None, - } - } - - pub fn create(name: String) -> Arc { - let parent = MemStat::current(); - MemStat::create_child(name, parent) - } - - pub fn create_child(name: String, parent_memory_stat: Option>) -> Arc { - Arc::new(MemStat { - name: Some(name), - used: AtomicI64::new(0), - limit: AtomicI64::new(0), - peak_used: AtomicI64::new(0), - parent_memory_stat, - }) - } - - pub fn set_limit(&self, mut size: i64) { - // It may cause the process unable to run if memory limit is too low. - const LOWEST: i64 = 256 * 1024 * 1024; - - if size > 0 && size < LOWEST { - size = LOWEST; - } - - self.limit.store(size, Ordering::Relaxed); - } - - /// Feed memory usage stat to MemStat and return if it exceeds the limit. - /// - /// It feeds `state` to the this tracker and all of its ancestors, including GLOBAL_TRACKER. - #[inline] - pub fn record_memory( - mem_stat: &Option>, - memory_usage: i64, - ) -> Result<(), OutOfLimit> { - let mut is_root = false; - - let mem_stat = match mem_stat { - Some(x) => x, - None => { - // No parent, report to GLOBAL_TRACKER - is_root = true; - &GLOBAL_MEM_STAT - } - }; - - let mut used = mem_stat.used.fetch_add(memory_usage, Ordering::Relaxed); - - used += memory_usage; - mem_stat.peak_used.fetch_max(used, Ordering::Relaxed); - - if !is_root { - if let Err(cause) = - Self::record_memory::(&mem_stat.parent_memory_stat, memory_usage) - { - if NEED_ROLLBACK { - let used = mem_stat.used.fetch_sub(memory_usage, Ordering::Relaxed); - mem_stat - .peak_used - .fetch_max(used - memory_usage, Ordering::Relaxed); - } - - return Err(cause); - } - } - - if let Err(cause) = mem_stat.check_limit(used) { - if NEED_ROLLBACK { - let used = mem_stat.used.fetch_sub(memory_usage, Ordering::Relaxed); - mem_stat - .peak_used - .fetch_max(used - memory_usage, Ordering::Relaxed); - } - - return Err(cause); - } - - Ok(()) - } - - /// Check if used memory is out of the limit. 
- #[inline] - fn check_limit(&self, used: i64) -> Result<(), OutOfLimit> { - let limit = self.limit.load(Ordering::Relaxed); - - // No limit - if limit == 0 { - return Ok(()); - } - - if used <= limit { - return Ok(()); - } - - Err(OutOfLimit::new(used, limit)) - } - - #[inline] - pub fn current() -> Option> { - TRACKER.with(|f: &RefCell| f.borrow().mem_stat.clone()) - } - - #[inline] - pub fn get_memory_usage(&self) -> i64 { - self.used.load(Ordering::Relaxed) - } - - #[inline] - #[allow(unused)] - pub fn get_peak_memory_usage(&self) -> i64 { - self.peak_used.load(Ordering::Relaxed) - } - - #[allow(unused)] - pub fn log_memory_usage(&self) { - let name = self.name.clone().unwrap_or_else(|| String::from("global")); - let memory_usage = self.used.load(Ordering::Relaxed); - let memory_usage = std::cmp::max(0, memory_usage) as u64; - info!( - "Current memory usage({}): {}.", - name, - ByteSize::b(memory_usage) - ); - } - - #[allow(unused)] - pub fn log_peek_memory_usage(&self) { - let name = self.name.clone().unwrap_or_else(|| String::from("global")); - let peak_memory_usage = self.peak_used.load(Ordering::Relaxed); - let peak_memory_usage = std::cmp::max(0, peak_memory_usage) as u64; - info!( - "Peak memory usage({}): {}.", - name, - ByteSize::b(peak_memory_usage) - ); - } - - pub fn on_start_thread(self: &Arc) -> impl Fn() { - let mem_stat = self.clone(); - - move || { - let s = ThreadTracker::replace_mem_stat(Some(mem_stat.clone())); - - debug_assert!(s.is_none(), "a new thread must have no tracker"); } } } -pin_project! { - /// A [`Future`] that enters its thread tracker when being polled. - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct TrackedFuture { - #[pin] - inner: T, - - mem_stat: Option>, - } -} - -impl TrackedFuture { - pub fn create(inner: T) -> TrackedFuture { - Self::create_with_mem_stat(MemStat::current(), inner) - } - - pub fn create_with_mem_stat(mem_stat: Option>, inner: T) -> Self { - Self { inner, mem_stat } - } -} - -impl Future for TrackedFuture { - type Output = T::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let _g = ThreadTracker::enter(this.mem_stat.clone()); - this.inner.poll(cx) - } -} - pin_project! { /// A [`Future`] that enters its thread tracker when being polled. #[must_use = "futures do nothing unless you `.await` or poll them"] @@ -559,135 +210,3 @@ impl Future for UnlimitedFuture { this.inner.poll(cx) } } - -#[cfg(test)] -mod tests { - mod async_thread_tracker { - use std::future::Future; - use std::pin::Pin; - use std::task::Context; - use std::task::Poll; - - use crate::runtime::runtime_tracker::STAT_BUFFER; - use crate::runtime::MemStat; - use crate::runtime::TrackedFuture; - - struct Foo { - i: usize, - } - - impl Future for Foo { - type Output = Vec; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let _ = cx; - let v = Vec::with_capacity(self.i * 1024 * 1024); - - Poll::Ready(v) - } - } - - #[test] - fn test_async_thread_tracker_normal_quit() -> anyhow::Result<()> { - // A future alloc memory and it should be tracked. - // The memory is passed out and is de-allocated outside the future and should not be tracked. 
- - let mem_stat = MemStat::create("test_async_thread_tracker_normal_quit".to_string()); - - let f = Foo { i: 3 }; - let f = TrackedFuture::create_with_mem_stat(Some(mem_stat.clone()), f); - - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build()?; - - let v = rt.block_on(f); - - let used = mem_stat.get_memory_usage(); - assert_eq!( - 3 * 1024 * 1024, - used, - "when future dropped, mem stat buffer is flushed" - ); - - drop(v); - - unsafe { - let _ = STAT_BUFFER.flush::(); - } - - let used = mem_stat.get_memory_usage(); - assert_eq!( - 3 * 1024 * 1024, - used, - "can not see mem dropped outside the future" - ); - - Ok(()) - } - } - - mod async_thread_tracker_panic { - use std::future::Future; - use std::pin::Pin; - use std::sync::Arc; - use std::task::Context; - use std::task::Poll; - - use crate::runtime::MemStat; - use crate::runtime::TrackedFuture; - - struct Foo { - i: usize, - } - - impl Future for Foo { - type Output = Vec; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let _ = cx; - let _v: Vec = Vec::with_capacity(self.i * 1024 * 1024); - panic!("foo"); - } - } - - #[test] - fn test_async_thread_tracker_panic() -> anyhow::Result<()> { - // A future alloc memory then panic. - // The memory stat should revert to 0. - // - // But it looks panicking allocates some memory. - // The used memory after the first panic stays stable. - - // Run a future in a one-shot runtime, return the used memory. - fn run_fut_in_rt(mem_stat: &Arc) -> i64 { - let f = Foo { i: 8 }; - let f = TrackedFuture::create_with_mem_stat(Some(mem_stat.clone()), f); - - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(5) - .enable_all() - .build() - .unwrap(); - - rt.block_on(async { - let h = crate::runtime::spawn(f); - let res = h.await; - assert!(res.is_err(), "panicked"); - }); - mem_stat.get_memory_usage() - } - - let mem_stat = MemStat::create("test_async_thread_tracker_panic".to_string()); - - let used0 = run_fut_in_rt(&mem_stat); - let used1 = run_fut_in_rt(&mem_stat); - - // The constantly used memory is about 1MB. 
- assert!(used1 - used0 < 1024 * 1024); - assert!(used0 - used1 < 1024 * 1024); - - Ok(()) - } - } -} diff --git a/src/common/base/src/runtime/thread.rs b/src/common/base/src/runtime/thread.rs index cc3899dc3dd4e..cded605b8ef01 100644 --- a/src/common/base/src/runtime/thread.rs +++ b/src/common/base/src/runtime/thread.rs @@ -18,7 +18,6 @@ use std::thread::JoinHandle; use databend_common_exception::ErrorCode; use databend_common_exception::Result; -use crate::runtime::MemStat; use crate::runtime::ThreadTracker; pub struct Thread; @@ -67,25 +66,18 @@ impl Thread { thread_builder = thread_builder.stack_size(5 * 1024 * 1024); } - let mut mem_stat_name = String::from("UnnamedThread"); - if let Some(named) = name.take() { - mem_stat_name = format!("{}Thread", named); thread_builder = thread_builder.name(named); } - ThreadJoinHandle::create(match MemStat::current() { - None => thread_builder.spawn(f).unwrap(), - Some(memory_tracker) => thread_builder + ThreadJoinHandle::create( + thread_builder .spawn(move || { - let c = MemStat::create_child(mem_stat_name, Some(memory_tracker)); - let s = ThreadTracker::replace_mem_stat(Some(c)); - debug_assert!(s.is_none(), "a new thread must have no tracker"); - + ThreadTracker::init(); f() }) .unwrap(), - }) + ) } pub fn spawn(f: F) -> ThreadJoinHandle diff --git a/src/common/base/tests/it/main.rs b/src/common/base/tests/it/main.rs index 7554579f57e1f..b01b56d905ab5 100644 --- a/src/common/base/tests/it/main.rs +++ b/src/common/base/tests/it/main.rs @@ -22,7 +22,6 @@ mod pool_retry; mod progress; mod range_merger; mod runtime; -mod runtime_tracker; mod stoppable; mod string; diff --git a/src/common/base/tests/it/runtime_tracker.rs b/src/common/base/tests/it/runtime_tracker.rs deleted file mode 100644 index 3d33dc858a651..0000000000000 --- a/src/common/base/tests/it/runtime_tracker.rs +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2022 Datafuse Labs. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
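`Thread::named_spawn` follows the same rule as the runtime workers: no more per-thread child `MemStat`, just name the OS thread and call `ThreadTracker::init()` before running the closure. A self-contained approximation with `std::thread` (the stand-in `init_thread_tracker` marks where the real init call goes):

```rust
fn init_thread_tracker() {
    // In the real code: ThreadTracker::init(), which touches the thread_local
    // TRACKER so its state and destructor exist before anything is tracked.
}

fn named_spawn<F, T>(name: Option<String>, f: F) -> std::thread::JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let mut builder = std::thread::Builder::new();
    if let Some(name) = name {
        builder = builder.name(name);
    }

    builder
        .spawn(move || {
            init_thread_tracker();
            f()
        })
        .expect("failed to spawn thread")
}

fn main() {
    let handle = named_spawn(Some("Compactor".to_string()), || 40 + 2);
    assert_eq!(handle.join().unwrap(), 42);
}
```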
- -use std::time::Duration; - -use databend_common_base::runtime::MemStat; -use databend_common_base::runtime::Runtime; -use databend_common_base::runtime::TrackedFuture; -use databend_common_base::runtime::TrySpawn; -use databend_common_base::GLOBAL_TASK; -use databend_common_exception::Result; - -#[tokio::test(flavor = "multi_thread", worker_threads = 8)] -async fn test_async_thread_tracker() -> Result<()> { - let (out_tx, out_rx) = async_channel::bounded(10); - let (inner_tx, inner_rx) = async_channel::bounded(10); - - let outer_runtime = Runtime::with_worker_threads(2, Some(String::from("Outer")))?; - let inner_runtime = Runtime::with_worker_threads(2, Some(String::from("Inner")))?; - - let memory_tracker = MemStat::create("test_async_thread_tracker".to_string()); - let inner_join_handler = inner_runtime.spawn( - GLOBAL_TASK, - TrackedFuture::create_with_mem_stat(Some(memory_tracker.clone()), async move { - let memory = vec![0_u8; 3 * 1024 * 1024]; - tokio::time::sleep(Duration::from_millis(100)).await; - out_tx.send(()).await.unwrap(); - inner_rx.recv().await.unwrap(); - drop(memory); - - let memory1 = vec![0_u8; 3 * 1024 * 1024]; - tokio::time::sleep(Duration::from_millis(100)).await; - out_tx.send(()).await.unwrap(); - inner_rx.recv().await.unwrap(); - - let memory2 = vec![0_u8; 2 * 1024 * 1024]; - tokio::time::sleep(Duration::from_millis(100)).await; - out_tx.send(()).await.unwrap(); - inner_rx.recv().await.unwrap(); - - drop(memory1); - tokio::time::sleep(Duration::from_millis(100)).await; - out_tx.send(()).await.unwrap(); - inner_rx.recv().await.unwrap(); - - drop(memory2); - tokio::time::sleep(Duration::from_millis(100)).await; - out_tx.send(()).await.unwrap(); - inner_rx.recv().await.unwrap(); - }), - ); - - let outer_join_handler = outer_runtime.spawn(GLOBAL_TASK, async move { - for (min_memory_usage, max_memory_usage) in [ - (2 * 1024 * 1024, 4 * 1024 * 1024), - (2 * 1024 * 1024, 4 * 1024 * 1024), - (4 * 1024 * 1024, 6 * 1024 * 1024), - (1024 * 1024, 3 * 1024 * 1024), - (0, 1024 * 1024), - ] { - out_rx.recv().await.unwrap(); - let memory_usage = memory_tracker.get_memory_usage(); - assert!(min_memory_usage <= memory_usage); - assert!(max_memory_usage > memory_usage); - inner_tx.send(()).await.unwrap(); - } - }); - - inner_join_handler.await.unwrap(); - outer_join_handler.await.unwrap(); - - drop(inner_runtime); - drop(outer_runtime); - - // println!("{}", memory_tracker2.get_memory_usage()); - // XXX: maybe memory tracker leak - // assert_eq!(memory_tracker2.get_memory_usage(), 0); - Ok(()) -} diff --git a/src/common/storage/src/runtime_layer.rs b/src/common/storage/src/runtime_layer.rs index 3c50f3400d9f4..23db1f99d4ed9 100644 --- a/src/common/storage/src/runtime_layer.rs +++ b/src/common/storage/src/runtime_layer.rs @@ -25,7 +25,6 @@ use async_trait::async_trait; use bytes::Bytes; use databend_common_base::base::tokio::task::JoinHandle; use databend_common_base::runtime::Runtime; -use databend_common_base::runtime::TrackedFuture; use databend_common_base::runtime::TrySpawn; use databend_common_base::GLOBAL_TASK; use futures::ready; @@ -115,10 +114,8 @@ impl LayeredAccessor for RuntimeAccessor { async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.create_dir(&path, args).await }; - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.create_dir(&path, args).await }) .await .expect("join must 
success") } @@ -128,11 +125,8 @@ impl LayeredAccessor for RuntimeAccessor { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.read(&path, args).await }; - - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.read(&path, args).await }) .await .expect("join must success") .map(|(rp, r)| { @@ -145,10 +139,8 @@ impl LayeredAccessor for RuntimeAccessor { async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.write(&path, args).await }; - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.write(&path, args).await }) .await .expect("join must success") } @@ -157,10 +149,8 @@ impl LayeredAccessor for RuntimeAccessor { async fn stat(&self, path: &str, args: OpStat) -> Result { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.stat(&path, args).await }; - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.stat(&path, args).await }) .await .expect("join must success") } @@ -169,10 +159,8 @@ impl LayeredAccessor for RuntimeAccessor { async fn delete(&self, path: &str, args: OpDelete) -> Result { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.delete(&path, args).await }; - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.delete(&path, args).await }) .await .expect("join must success") } @@ -181,10 +169,8 @@ impl LayeredAccessor for RuntimeAccessor { async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> { let op = self.inner.clone(); let path = path.to_string(); - let future = async move { op.list(&path, args).await }; - let future = TrackedFuture::create(future); self.runtime - .spawn(GLOBAL_TASK, future) + .spawn(GLOBAL_TASK, async move { op.list(&path, args).await }) .await .expect("join must success") } @@ -243,7 +229,7 @@ impl oio::Read for RuntimeIO { buffer.set_len(buf.len()) } - let future = async move { + self.state = State::Read(self.runtime.spawn(GLOBAL_TASK, async move { let mut buffer = buffer; let res = r.read(&mut buffer).await; match res { @@ -254,9 +240,7 @@ impl oio::Read for RuntimeIO { } Err(err) => (r, Err(err)), } - }; - let future = TrackedFuture::create(future); - self.state = State::Read(self.runtime.spawn(GLOBAL_TASK, future)); + })); self.poll_read(cx, buf) } @@ -295,12 +279,11 @@ impl oio::Read for RuntimeIO { match &mut self.state { State::Idle(r) => { let mut r = r.take().expect("Idle must have a valid reader"); - let future = async move { + + self.state = State::Seek(self.runtime.spawn(GLOBAL_TASK, async move { let res = r.seek(pos).await; (r, res) - }; - let future = TrackedFuture::create(future); - self.state = State::Seek(self.runtime.spawn(GLOBAL_TASK, future)); + })); self.poll_seek(cx, pos) } @@ -329,12 +312,10 @@ impl oio::Read for RuntimeIO { match &mut self.state { State::Idle(r) => { let mut r = r.take().expect("Idle must have a valid reader"); - let future = async move { + self.state = State::Next(self.runtime.spawn(GLOBAL_TASK, async move { let res = r.next().await; (r, res) - }; - let future = TrackedFuture::create(future); - self.state = State::Next(self.runtime.spawn(GLOBAL_TASK, future)); + })); 
self.poll_next(cx) } diff --git a/src/query/service/src/pipelines/executor/executor_graph.rs b/src/query/service/src/pipelines/executor/executor_graph.rs index ced056a04935a..ac3b2511d4c66 100644 --- a/src/query/service/src/pipelines/executor/executor_graph.rs +++ b/src/query/service/src/pipelines/executor/executor_graph.rs @@ -19,7 +19,6 @@ use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; -use databend_common_base::runtime::TrackedFuture; use databend_common_base::runtime::TrySpawn; use databend_common_exception::ErrorCode; use databend_common_exception::Result; @@ -403,7 +402,7 @@ impl ScheduleQueue { let process_future = proc.processor.async_process(); executor.async_runtime.spawn( query_id.as_ref().clone(), - TrackedFuture::create(ProcessorAsyncTask::create( + ProcessorAsyncTask::create( query_id, wakeup_worker_id, proc.processor.clone(), @@ -413,7 +412,7 @@ impl ScheduleQueue { node_profile, graph, process_future, - )) + ) .in_span(Span::enter_with_local_parent(std::any::type_name::< ProcessorAsyncTask, >())),