From 5c9f168fdafd6aeaf0591098fa1176eff814766a Mon Sep 17 00:00:00 2001 From: Levi Morrison Date: Thu, 13 Nov 2025 17:57:49 -0700 Subject: [PATCH 1/2] feat(profiling): parallel set and string set --- Cargo.lock | 2 + libdd-profiling/Cargo.toml | 2 + .../src/profiles/collections/arc.rs | 311 +++++++++++++++ .../src/profiles/collections/error.rs | 2 + .../src/profiles/collections/mod.rs | 4 + .../src/profiles/collections/parallel/mod.rs | 12 + .../src/profiles/collections/parallel/set.rs | 180 +++++++++ .../profiles/collections/parallel/sharded.rs | 188 +++++++++ .../collections/parallel/slice_set.rs | 369 ++++++++++++++++++ .../collections/parallel/string_set.rs | 232 +++++++++++ .../src/profiles/collections/set.rs | 8 +- 11 files changed, 1306 insertions(+), 4 deletions(-) create mode 100644 libdd-profiling/src/profiles/collections/arc.rs create mode 100644 libdd-profiling/src/profiles/collections/parallel/mod.rs create mode 100644 libdd-profiling/src/profiles/collections/parallel/set.rs create mode 100644 libdd-profiling/src/profiles/collections/parallel/sharded.rs create mode 100644 libdd-profiling/src/profiles/collections/parallel/slice_set.rs create mode 100644 libdd-profiling/src/profiles/collections/parallel/string_set.rs diff --git a/Cargo.lock b/Cargo.lock index 9c6b210fe7..46b4cc026b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2776,6 +2776,7 @@ dependencies = [ "bytes", "chrono", "criterion", + "crossbeam-utils", "futures", "hashbrown 0.16.0", "http", @@ -2788,6 +2789,7 @@ dependencies = [ "libdd-profiling-protobuf", "lz4_flex", "mime", + "parking_lot", "proptest", "prost", "rustc-hash 1.1.0", diff --git a/libdd-profiling/Cargo.toml b/libdd-profiling/Cargo.toml index c9d1938ad4..bc8eb70928 100644 --- a/libdd-profiling/Cargo.toml +++ b/libdd-profiling/Cargo.toml @@ -27,6 +27,7 @@ bitmaps = "3.2.0" byteorder = { version = "1.5", features = ["std"] } bytes = "1.1" chrono = {version = "0.4", default-features = false, features = ["std", "clock"]} +crossbeam-utils = { version = "0.8.21" } libdd-alloc = { version = "1.0.0", path = "../libdd-alloc" } libdd-profiling-protobuf = { version = "1.0.0", path = "../libdd-profiling-protobuf", features = ["prost_impls"] } libdd-common = { version = "1.0.0", path = "../libdd-common" } @@ -39,6 +40,7 @@ hyper-multipart-rfc7578 = "0.9.0" indexmap = "2.11" lz4_flex = { version = "0.9", default-features = false, features = ["std", "safe-encode", "frame"] } mime = "0.3.16" +parking_lot = { version = "0.12", default-features = false } prost = "0.13.5" rustc-hash = { version = "1.1", default-features = false } serde = {version = "1.0", features = ["derive"]} diff --git a/libdd-profiling/src/profiles/collections/arc.rs b/libdd-profiling/src/profiles/collections/arc.rs new file mode 100644 index 0000000000..a0a7c8f445 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/arc.rs @@ -0,0 +1,311 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 +// This is heavily inspired by the standard library's `Arc` implementation, +// which is dual-licensed as Apache-2.0 or MIT. + +use allocator_api2::alloc::{AllocError, Allocator, Global}; +use allocator_api2::boxed::Box; +use core::sync::atomic::{fence, AtomicUsize, Ordering}; +use core::{alloc::Layout, fmt, mem::ManuallyDrop, ptr}; +use core::{marker::PhantomData, ops::Deref, ptr::NonNull}; +use crossbeam_utils::CachePadded; + +/// A thread-safe reference-counting pointer with only strong references. 
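+///
+/// A minimal usage sketch (illustrative only; `?` stands in for handling
+/// `AllocError` and `ArcOverflow`):
+///
+/// ```ignore
+/// let arc = Arc::try_new(42u32)?;       // fallible allocation
+/// let twin = arc.try_clone()?;          // fallible refcount increment
+/// assert_eq!(*twin, 42);
+/// drop(arc);
+/// assert_eq!(Arc::try_unwrap(twin).ok(), Some(42)); // sole owner again
+/// ```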
+/// +/// This type is similar to `std::sync::Arc` but intentionally omits APIs that +/// can panic or abort the process. In particular: +/// - There are no weak references. +/// - Cloning uses [`Arc::try_clone`], which returns an error on reference-count overflow instead of +/// aborting the process. +/// - Construction uses fallible allocation via [`Arc::try_new`]. +/// +/// Deref gives shared access to the inner value; mutation should use interior +/// mutability primitives as with `std::sync::Arc`. +#[repr(C)] +#[derive(Debug)] +pub struct Arc { + ptr: NonNull>, + alloc: A, + phantom: PhantomData>, +} + +// repr(C) prevents field reordering that could affect raw-pointer helpers. +#[repr(C)] +struct ArcInner { + refcount: CachePadded, + data: CachePadded, +} + +impl ArcInner { + fn from_ptr<'a>(ptr: *const T) -> &'a Self { + let data = ptr.cast::(); + let data_offset = Arc::::data_offset(); + let bytes_ptr = unsafe { data.sub(data_offset) }; + let arc_ptr = bytes_ptr as *mut ArcInner; + unsafe { &*arc_ptr } + } + + fn try_clone(&self) -> Result<(), ArcOverflow> { + if self.refcount.fetch_add(1, Ordering::Relaxed) > MAX_REFCOUNT { + self.refcount.fetch_sub(1, Ordering::Relaxed); + return Err(ArcOverflow); + } + Ok(()) + } +} + +impl Arc { + pub fn try_new(data: T) -> Result, AllocError> { + Self::try_new_in(data, Global) + } + + /// Tries to increment the reference count using only a pointer to the + /// inner `T`. This does not create an `Arc` instance. + /// + /// # Safety + /// - `ptr` must be a valid pointer to the `T` inside an `Arc` allocation produced by this + /// module. Passing any other pointer is undefined behavior. + /// - There must be at least one existing reference alive when called. + pub unsafe fn try_increment_count(ptr: *const T) -> Result<(), ArcOverflow> { + let inner = ArcInner::from_ptr(ptr); + inner.try_clone() + } +} + +impl Arc { + /// Constructs a new `Arc` in the provided allocator, returning an + /// error if allocation fails. + pub fn try_new_in(data: T, alloc: A) -> Result, AllocError> { + let inner = ArcInner { + refcount: CachePadded::new(AtomicUsize::new(1)), + data: CachePadded::new(data), + }; + let boxed = Box::try_new_in(inner, alloc)?; + let (ptr, alloc) = Box::into_non_null(boxed); + Ok(Arc { + ptr, + alloc, + phantom: PhantomData, + }) + } + + /// Returns the inner value, if the `Arc` has exactly one reference. + /// + /// Otherwise, an [`Err`] is returned with the same `Arc` that was passed + /// in. + /// + /// It is strongly recommended to use [`Arc::into_inner`] instead if you + /// don't keep the `Arc` in the [`Err`] case. + pub fn try_unwrap(this: Self) -> Result { + // Attempt to take unique ownership by transitioning strong: 1 -> 0 + let inner_ref = unsafe { this.ptr.as_ref() }; + if inner_ref + .refcount + .compare_exchange(1, 0, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + // We have unique ownership; move out T and deallocate without dropping T. 
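+            // `ManuallyDrop` keeps `Arc::drop` from running (the refcount is
+            // already 0), `ptr::read` moves the allocator out without a
+            // double-drop, and `Box::into_inner` below frees the heap block
+            // while handing the `ArcInner` back by value so `data` can be
+            // moved out without dropping `T`.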
+ let this = ManuallyDrop::new(this); + let ptr = this.ptr.as_ptr(); + let alloc: A = unsafe { ptr::read(&this.alloc) }; + // Reconstruct a Box to ArcInner and convert into inner to avoid double-drop of T + let boxed: Box, A> = unsafe { Box::from_raw_in(ptr, alloc) }; + let ArcInner { refcount: _, data } = Box::into_inner(boxed); + // We moved out `data` above, so do not use `data` here; free already done via + // into_inner + Ok(CachePadded::into_inner(data)) + } else { + Err(this) + } + } + + pub fn into_inner(this: Self) -> Option { + // Prevent running Drop; we will manage the refcount and allocation manually. + let this = ManuallyDrop::new(this); + let inner = unsafe { this.ptr.as_ref() }; + if inner.refcount.fetch_sub(1, Ordering::Release) != 1 { + return None; + } + fence(Ordering::Acquire); + + // We are the last strong reference. Move out T and free the allocation + // without dropping T. + let ptr = this.ptr.as_ptr(); + let alloc: A = unsafe { ptr::read(&this.alloc) }; + let boxed: Box, A> = unsafe { Box::from_raw_in(ptr, alloc) }; + let ArcInner { refcount: _, data } = Box::into_inner(boxed); + Some(CachePadded::into_inner(data)) + } + + /// Returns a raw non-null pointer to the inner value. The pointer is valid + /// as long as there is at least one strong reference alive. + #[inline] + pub fn as_ptr(&self) -> NonNull { + let ptr = NonNull::as_ptr(self.ptr); + // SAFETY: `ptr` points to a valid `ArcInner` allocation. Taking the + // address of the `data` field preserves provenance unlike going + // through Deref. + let data = unsafe { ptr::addr_of_mut!((*ptr).data) }; + // SAFETY: data field address is derived from a valid NonNull. + unsafe { NonNull::new_unchecked(data as *mut T) } + } + + /// Converts the Arc into a non-null pointer to the inner value, without + /// decreasing the reference count. + /// + /// The caller must later call `Arc::from_raw` with the same pointer exactly + /// once to avoid leaking the allocation. + #[inline] + #[must_use = "losing the pointer will leak memory"] + pub fn into_raw(this: Self) -> NonNull { + let this = ManuallyDrop::new(this); + // Reuse as_ptr logic without dropping `this`. + Arc::as_ptr(&this) + } +} + +// SAFETY: `Arc` is Send and Sync iff `T` is Send and Sync. +unsafe impl Send for Arc {} +unsafe impl Sync for Arc {} + +impl Arc { + #[inline] + fn inner(&self) -> &ArcInner { + // SAFETY: `ptr` is a valid, live allocation managed by this Arc + unsafe { self.ptr.as_ref() } + } +} + +/// Error returned when the reference count would overflow. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ArcOverflow; + +impl fmt::Display for ArcOverflow { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("arc: reference count overflow") + } +} + +impl core::error::Error for ArcOverflow {} + +/// A limit on the amount of references that may be made to an `Arc`. +const MAX_REFCOUNT: usize = isize::MAX as usize; + +impl Arc { + /// Fallible clone that increments the strong reference count. + /// + /// Returns an error if the reference count would exceed `isize::MAX`. + pub fn try_clone(&self) -> Result { + let inner = self.inner(); + inner.try_clone()?; + Ok(Arc { + ptr: self.ptr, + alloc: self.alloc.clone(), + phantom: PhantomData, + }) + } +} + +impl Drop for Arc { + fn drop(&mut self) { + let inner = self.inner(); + if inner.refcount.fetch_sub(1, Ordering::Release) == 1 { + // Synchronize with other threads that might have modified the data + // before dropping the last strong reference. 
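+            // The `Release` ordering on the decrement publishes each owner's
+            // writes, and the `Acquire` fence pairs with it so this thread
+            // observes them before dropping `T` and freeing the allocation.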
+ // Raymond Chen wrote a little blog article about it: + // https://devblogs.microsoft.com/oldnewthing/20251015-00/?p=111686 + fence(Ordering::Acquire); + // SAFETY: this was the last strong reference; reclaim allocation + let ptr = self.ptr.as_ptr(); + // Move out allocator for deallocation, but prevent double-drop of `alloc` + let alloc: A = unsafe { ptr::read(&self.alloc) }; + unsafe { drop(Box::, A>::from_raw_in(ptr, alloc)) }; + } + } +} + +impl Deref for Arc { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: The allocation outlives `self` while any strong refs exist. + unsafe { &self.ptr.as_ref().data } + } +} + +impl Arc { + #[inline] + fn data_offset() -> usize { + // Layout of ArcInner is repr(C): [CachePadded, CachePadded] + let header = Layout::new::>(); + match header.extend(Layout::new::>()) { + Ok((_layout, offset)) => offset, + Err(_) => { + // Fallback: compute padding manually to avoid unwrap. This should + // not fail in practice for valid types. + let align = Layout::new::>().align(); + let size = header.size(); + let padding = (align - (size % align)) % align; + size + padding + } + } + } + + /// Recreates an `Arc` from a raw pointer produced by `Arc::into_raw`. + /// + /// # Safety + /// - `ptr` must have been returned by a previous call to `Arc::::into_raw`. + /// - if `ptr` has been cast, it needs to be to a compatible repr. + /// - It must not be used to create multiple owning `Arc`s without corresponding `into_raw` + /// calls; otherwise the refcount will be decremented multiple times. + #[inline] + pub unsafe fn from_raw_in(ptr: NonNull, alloc: A) -> Self { + let data = ptr.as_ptr() as *const u8; + let arc_ptr_u8 = data.sub(Self::data_offset()); + let arc_ptr = arc_ptr_u8 as *mut ArcInner; + Arc { + ptr: NonNull::new_unchecked(arc_ptr), + alloc, + phantom: PhantomData, + } + } +} + +impl Arc { + /// Recreates an `Arc` in the `Global` allocator from a raw pointer + /// produced by `Arc::into_raw`. + /// + /// # Safety + /// See [`Arc::from_raw_in`] for requirements. + #[inline] + pub unsafe fn from_raw(ptr: NonNull) -> Self { + Arc::from_raw_in(ptr, Global) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn try_new_and_unwrap_unique() { + let arc = Arc::try_new(123u32).unwrap(); + let v = Arc::try_unwrap(arc).ok().unwrap(); + assert_eq!(v, 123); + } + + #[test] + fn try_unwrap_fails_when_shared() { + let arc = Arc::try_new(5usize).unwrap(); + let arc2 = arc.try_clone().unwrap(); + assert!(Arc::try_unwrap(arc).is_err()); + assert_eq!(*arc2, 5); + } + + #[test] + fn deref_access() { + let arc = Arc::try_new("abc").unwrap(); + assert_eq!(arc.len(), 3); + assert_eq!(*arc, "abc"); + } +} diff --git a/libdd-profiling/src/profiles/collections/error.rs b/libdd-profiling/src/profiles/collections/error.rs index c35b772a26..57d3da1bef 100644 --- a/libdd-profiling/src/profiles/collections/error.rs +++ b/libdd-profiling/src/profiles/collections/error.rs @@ -8,6 +8,8 @@ pub enum SetError { InvalidArgument, #[error("set error: out of memory")] OutOfMemory, + #[error("set error: reference count overflow")] + ReferenceCountOverflow, } impl From for SetError { diff --git a/libdd-profiling/src/profiles/collections/mod.rs b/libdd-profiling/src/profiles/collections/mod.rs index 514402f7c4..bfce216e5d 100644 --- a/libdd-profiling/src/profiles/collections/mod.rs +++ b/libdd-profiling/src/profiles/collections/mod.rs @@ -1,7 +1,9 @@ // Copyright 2025-Present Datadog, Inc. 
https://www.datadoghq.com/ // SPDX-License-Identifier: Apache-2.0 +mod arc; mod error; +mod parallel; mod set; mod slice_set; mod string_set; @@ -9,7 +11,9 @@ mod thin_str; pub type SetHasher = core::hash::BuildHasherDefault; +pub use arc::*; pub use error::*; +pub use parallel::*; pub use set::*; pub use slice_set::*; pub use string_set::*; diff --git a/libdd-profiling/src/profiles/collections/parallel/mod.rs b/libdd-profiling/src/profiles/collections/parallel/mod.rs new file mode 100644 index 0000000000..4c32464df8 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/parallel/mod.rs @@ -0,0 +1,12 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +mod set; +mod sharded; +mod slice_set; +mod string_set; + +pub use set::*; +pub use sharded::*; +pub use slice_set::*; +pub use string_set::*; diff --git a/libdd-profiling/src/profiles/collections/parallel/set.rs b/libdd-profiling/src/profiles/collections/parallel/set.rs new file mode 100644 index 0000000000..b4e2e4e544 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/parallel/set.rs @@ -0,0 +1,180 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use crate::profiles::collections::parallel::sharded::Sharded; +use crate::profiles::collections::{Arc, Set, SetError, SetId, SET_MIN_CAPACITY}; +use core::hash; +use libdd_alloc::Global; +use std::ffi::c_void; +use std::hash::BuildHasher; +use std::ptr; + +#[derive(Debug)] +#[repr(C)] +pub struct ParallelSet { + pub(crate) storage: Arc, N>>, +} + +impl ParallelSet { + const fn is_power_of_two_gt1() -> bool { + N.is_power_of_two() && N > 1 + } + + pub fn try_new() -> Result { + if !Self::is_power_of_two_gt1() { + return Err(SetError::InvalidArgument); + } + let storage = Sharded::, N>::try_new_with_min_capacity(SET_MIN_CAPACITY)?; + let storage = Arc::try_new(storage)?; + Ok(Self { storage }) + } + + #[inline] + fn storage(&self) -> &Sharded, N> { + &self.storage + } + + pub fn try_clone(&self) -> Result { + let storage = self + .storage + .try_clone() + .map_err(|_| SetError::ReferenceCountOverflow)?; + Ok(Self { storage }) + } + + #[inline] + fn select_shard(hash: u64) -> usize { + (hash as usize) & (N - 1) + } + + pub fn try_insert(&self, value: T) -> Result, SetError> { + let hash = crate::profiles::collections::SetHasher::default().hash_one(&value); + let idx = Self::select_shard(hash); + let lock = &self.storage().shards[idx]; + + let read_len = { + let guard = lock.read(); + // SAFETY: `hash` was computed using this set's hasher over `&value`. + if let Some(id) = unsafe { guard.find_with_hash(hash, &value) } { + return Ok(id); + } + guard.len() + }; + + let mut guard = lock.write(); + if guard.len() != read_len { + // SAFETY: `hash` was computed using this set's hasher over `&value`. + if let Some(id) = unsafe { guard.find_with_hash(hash, &value) } { + return Ok(id); + } + } + + // SAFETY: `hash` was computed using this set's hasher over `&value`, + // and uniqueness has been enforced by the preceding read/write checks. + unsafe { guard.insert_unique_uncontended_with_hash(hash, value) } + .map_err(|_| SetError::OutOfMemory) + } + + /// Returns the `SetId` for `value` if it exists in the parallel set, without inserting. + /// Intended for tests and debugging; typical usage should prefer `try_insert`. 
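+    ///
+    /// A rough sketch of the lookup flow (illustrative; the shard count `4`
+    /// and the element type are arbitrary):
+    ///
+    /// ```ignore
+    /// let set: ParallelSet<u64, 4> = ParallelSet::try_new()?;
+    /// assert!(set.find(&7).is_none());  // nothing inserted yet
+    /// let id = set.try_insert(7)?;      // takes the shard's write lock if needed
+    /// assert!(set.find(&7).is_some());  // lookups only take the read lock
+    /// // SAFETY: `id` came from this set.
+    /// assert_eq!(unsafe { *set.get(id) }, 7);
+    /// ```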
+ pub fn find(&self, value: &T) -> Option> { + let hash = crate::profiles::collections::SetHasher::default().hash_one(value); + let idx = Self::select_shard(hash); + let lock = &self.storage().shards[idx]; + let guard = lock.read(); + // SAFETY: `hash` computed using this set's hasher over `&value`. + unsafe { guard.find_with_hash(hash, value) } + } + + /// Returns a shared reference to the value for a given `SetId`. + /// + /// # Safety + /// - The `id` must have been obtained from this exact `ParallelSet` (and shard) instance, and + /// thus point to a live `T` stored in its arena. + /// # Safety + /// - `id` must come from this exact `ParallelSet` instance (same shard) and still refer to a + /// live element in its arena. + /// - The returned reference is immutable; do not concurrently mutate the same element via + /// interior mutability. + pub unsafe fn get(&self, id: SetId) -> &T { + // We do not need to lock to read the value; storage is arena-backed and + // values are immutable once inserted. Caller guarantees `id` belongs here. + unsafe { id.0.as_ref() } + } + + pub fn into_raw(self) -> ptr::NonNull { + Arc::into_raw(self.storage).cast() + } + + /// # Safety + /// - `this` must be produced by `into_raw` for a `ParallelSet` with matching `T`, `N`, + /// and allocator. + /// - After calling, do not use the original raw pointer again. + pub unsafe fn from_raw(this: ptr::NonNull) -> Self { + let storage = unsafe { Arc::from_raw_in(this.cast(), Global) }; + Self { storage } + } +} + +// SAFETY: uses `RwLock>` to synchronize access. All reads/writes in +// this wrapper go through those locks. All non-mut methods of +// `ParallelSetStorage` and `Set` are safe to call under a read-lock, and all +// mut methods are safe to call under a write-lock. +unsafe impl Send for ParallelSet {} +unsafe impl Sync for ParallelSet {} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + use std::collections::HashSet as StdHashSet; + + proptest! { + #![proptest_config(ProptestConfig { + cases: if cfg!(miri) { 4 } else { 64 }, + .. ProptestConfig::default() + })] + + #[test] + fn proptest_parallel_set_matches_std_hashset( + values in proptest::collection::vec(any::(), 0..if cfg!(miri) { 32 } else { 512 }) + ) { + type PSet = ParallelSet; + let set = PSet::try_new().unwrap(); + let mut shadow = StdHashSet::::new(); + + for v in &values { + shadow.insert(*v); + let _ = set.try_insert(*v).unwrap(); + } + + // Compare lengths + let len_pset = { + let s = set.storage(); + let mut acc = 0usize; + for shard in &s.shards { acc += shard.read().len(); } + acc + }; + prop_assert_eq!(len_pset, shadow.len()); + + // Each shadow value must be present and equal + for &v in &shadow { + let id = set.find(&v); + prop_assert!(id.is_some()); + let id = id.unwrap(); + // SAFETY: id just obtained from this set + let fetched = unsafe { set.get(id) }; + prop_assert_eq!(*fetched, v); + } + } + } + + #[test] + fn auto_traits_send_sync() { + fn require_send() {} + fn require_sync() {} + type PSet = ParallelSet; + require_send::(); + require_sync::(); + } +} diff --git a/libdd-profiling/src/profiles/collections/parallel/sharded.rs b/libdd-profiling/src/profiles/collections/parallel/sharded.rs new file mode 100644 index 0000000000..c5a08a3f30 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/parallel/sharded.rs @@ -0,0 +1,188 @@ +// Copyright 2025-Present Datadog, Inc. 
https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use crate::profiles::collections::{ + Set, SetError, SetHasher as Hasher, SetId, SliceSet, ThinSlice, +}; +use core::hash::Hash; +use core::mem::MaybeUninit; +use crossbeam_utils::CachePadded; +use parking_lot::RwLock; + +/// Operations a set must provide for so that a sharded set can be built on +/// top of it. +/// +/// # Safety +/// +/// Implementors must ensure that all methods which take `&self` are safe to +/// call under a read-lock, and all `&mut self` methods are safe to call under +/// a write-lock, and are safe for `Send` and `Sync`. +pub unsafe trait SetOps { + type Lookup<'a>: Copy + where + Self: 'a; + + /// Owned payload used for insertion. For some containers (e.g. slice-backed + /// sets) this can be a borrowed view like `&'a [T]` because the container + /// copies data into its own arena during insertion. + type Owned<'a> + where + Self: 'a; + + type Id: Copy; + + fn try_with_capacity(capacity: usize) -> Result + where + Self: Sized; + + fn len(&self) -> usize; + + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// # Safety + /// Same safety contract as the underlying container's find_with_hash. + unsafe fn find_with_hash(&self, hash: u64, key: Self::Lookup<'_>) -> Option; + + /// # Safety + /// Same safety contract as the underlying container's insert_unique_uncontended_with_hash. + unsafe fn insert_unique_uncontended_with_hash( + &mut self, + hash: u64, + key: Self::Owned<'_>, + ) -> Result; +} + +#[derive(Debug)] +pub struct Sharded { + pub(crate) shards: [CachePadded>; N], +} + +impl Sharded { + #[inline] + pub const fn is_power_of_two_gt1() -> bool { + N.is_power_of_two() && N > 1 + } + + #[inline] + pub fn select_shard(hash: u64) -> usize { + (hash as usize) & (N - 1) + } + + pub fn try_new_with_min_capacity(min_capacity: usize) -> Result { + if !Self::is_power_of_two_gt1() { + return Err(SetError::InvalidArgument); + } + let mut shards_uninit: [MaybeUninit>>; N] = + unsafe { MaybeUninit::uninit().assume_init() }; + let mut i = 0usize; + while i < N { + match S::try_with_capacity(min_capacity) { + Ok(inner) => { + shards_uninit[i].write(CachePadded::new(RwLock::new(inner))); + i += 1; + } + Err(e) => { + for j in (0..i).rev() { + unsafe { shards_uninit[j].assume_init_drop() }; + } + return Err(e); + } + } + } + let shards: [CachePadded>; N] = + unsafe { core::mem::transmute_copy(&shards_uninit) }; + // If N=0, then we error at the very top of the function, so we know + // there's at least one. + Ok(Self { shards }) + } + + pub fn try_insert_common<'a>( + &self, + lookup: S::Lookup<'a>, + owned: S::Owned<'a>, + ) -> Result + where + S::Lookup<'a>: Hash + PartialEq, + { + use std::hash::BuildHasher; + let hash = Hasher::default().hash_one(lookup); + let idx = Self::select_shard(hash); + let lock = &self.shards[idx]; + + let read_len = { + let guard = lock.read(); + if let Some(id) = unsafe { guard.find_with_hash(hash, lookup) } { + return Ok(id); + } + guard.len() + }; + + let mut guard = lock.write(); + if guard.len() != read_len { + if let Some(id) = unsafe { guard.find_with_hash(hash, lookup) } { + return Ok(id); + } + } + + unsafe { guard.insert_unique_uncontended_with_hash(hash, owned) } + } +} + +// SAFETY: relies on safety requirements of `SetOps`. 
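+// Each shard lives behind its own `RwLock`, and `try_insert_common` re-checks
+// under the write lock (the shard length doubles as an insert-only version
+// counter) before inserting, so sharing a `&Sharded` never bypasses the locks.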
+unsafe impl Send for Sharded {} +unsafe impl Sync for Sharded {} + +unsafe impl SetOps for Set { + type Lookup<'a> = &'a T; + type Owned<'a> = T; + type Id = SetId; + + fn try_with_capacity(capacity: usize) -> Result { + Set::try_with_capacity(capacity) + } + + fn len(&self) -> usize { + self.len() + } + + unsafe fn find_with_hash(&self, hash: u64, key: Self::Lookup<'_>) -> Option { + self.find_with_hash(hash, key) + } + + unsafe fn insert_unique_uncontended_with_hash( + &mut self, + hash: u64, + value: Self::Owned<'_>, + ) -> Result { + self.insert_unique_uncontended_with_hash(hash, value) + } +} + +unsafe impl SetOps for SliceSet { + type Lookup<'a> = &'a [T]; + type Owned<'a> = &'a [T]; + type Id = ThinSlice<'static, T>; + + fn try_with_capacity(capacity: usize) -> Result { + SliceSet::try_with_capacity(capacity) + } + + fn len(&self) -> usize { + self.len() + } + + unsafe fn find_with_hash(&self, hash: u64, key: Self::Lookup<'_>) -> Option { + unsafe { self.find_with_hash(hash, key) } + } + + unsafe fn insert_unique_uncontended_with_hash( + &mut self, + hash: u64, + key: Self::Owned<'_>, + ) -> Result { + unsafe { self.insert_unique_uncontended_with_hash(hash, key) } + } +} diff --git a/libdd-profiling/src/profiles/collections/parallel/slice_set.rs b/libdd-profiling/src/profiles/collections/parallel/slice_set.rs new file mode 100644 index 0000000000..0fedce7915 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/parallel/slice_set.rs @@ -0,0 +1,369 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use crate::profiles::collections::{ + Arc, ArcOverflow, SetError, SetHasher as Hasher, SliceSet, ThinSlice, +}; +use std::hash::{self, BuildHasher}; +use std::ops::Deref; + +/// Number of shards used by the parallel slice set and (by extension) +/// the string-specific parallel set. Kept as a constant so tests and +/// related code can refer to the same value. +pub const N_SHARDS: usize = 16; + +/// The initial capacities for Rust's hash map (and set) currently go +/// like this: 3, 7, 14, 28. We want to avoid some of the smaller sizes so +/// that there's less frequent re-allocation, which is the most expensive +/// part of the set's operations. +const HASH_TABLE_MIN_CAPACITY: usize = 28; + +pub type ParallelSliceStorage = super::Sharded, N_SHARDS>; + +/// A slice set which can have parallel read and write operations. It works +/// by sharding the set and using read-write locks local to each shard. Items +/// cannot be removed from the set; the implementation relies on this for a +/// variety of optimizations. +/// +/// This is a fairly naive implementation. Unfortunately, dashmap and other +/// off-the-shelf implementations I looked at don't have adequate APIs for +/// avoiding panics, including handling allocation failures. Since we're +/// rolling our own, we can get some benefits like having wait-free lookups +/// when fetching the value associated to an ID. +/// +/// Also unfortunately, even parking_lot's RwLock doesn't handle allocation +/// failures. But I'm not going to go _that_ far to avoid allocation failures +/// today. We're very unlikely to run out of memory while adding a waiter to +/// its queue, because the amount of memory used is bounded by the number of +/// threads, which is small. +#[repr(transparent)] +pub struct ParallelSliceSet { + pub(crate) arc: Arc>, +} + +// SAFETY: uses `RwLock>` to synchronize access. All reads/writes +// in this wrapper go through those locks. 
All non-mut methods of +// `ParallelSliceStorage` and `Set` are safe to call under a read-lock, and all +// mut methods are safe to call under a write-lock. +unsafe impl Send for ParallelSliceSet {} +unsafe impl Sync for ParallelSliceSet {} + +impl Deref for ParallelSliceSet { + type Target = ParallelSliceStorage; + fn deref(&self) -> &Self::Target { + &self.arc + } +} + +impl ParallelSliceSet { + pub fn try_clone(&self) -> Result, ArcOverflow> { + let ptr = self.arc.try_clone().map_err(|_| ArcOverflow)?; + Ok(ParallelSliceSet { arc: ptr }) + } + + pub const fn select_shard(hash: u64) -> usize { + // Use lower bits for shard selection to avoid interfering with + // Swiss tables' internal SIMD comparisons that use upper 7 bits. + // Using 4 bits provides resilience against hash function deficiencies + // and optimal scaling for low thread counts. + (hash & 0b1111) as usize + } + + /// Tries to create a new parallel slice set. + pub fn try_new() -> Result { + let storage = ParallelSliceStorage::try_new_with_min_capacity(HASH_TABLE_MIN_CAPACITY)?; + let ptr = Arc::try_new(storage)?; + Ok(Self { arc: ptr }) + } + + /// # Safety + /// The slice must not have been inserted yet, as it skips checking if + /// the slice is already present. + pub unsafe fn insert_unique_uncontended( + &self, + slice: &[T], + ) -> Result, SetError> + where + T: hash::Hash, + { + let hash = Hasher::default().hash_one(slice); + let shard_idx = Self::select_shard(hash); + let lock = &self.shards[shard_idx]; + let mut guard = lock.write(); + guard.insert_unique_uncontended(slice) + } + + /// Adds the slice to the slice set if it isn't present already, and + /// returns a handle to the slice that can be used to retrieve it later. + pub fn try_insert(&self, slice: &[T]) -> Result, SetError> + where + T: hash::Hash + PartialEq, + { + // Hash once and reuse it for all operations. + // Do this without holding any locks. + let hash = Hasher::default().hash_one(slice); + let shard_idx = Self::select_shard(hash); + let lock = &self.shards[shard_idx]; + + let read_len = { + let guard = lock.read(); + // SAFETY: the slice's hash is correct, we use the same hasher as + // SliceSet uses. + if let Some(id) = unsafe { guard.deref().find_with_hash(hash, slice) } { + return Ok(id); + } + guard.len() + }; + + let mut write_guard = lock.write(); + let write_len = write_guard.slices.len(); + // This is an ABA defense. It's simple because we only support insert. + if write_len != read_len { + // SAFETY: the slice's hash is correct, we use the same hasher as + // SliceSet uses. + if let Some(id) = unsafe { write_guard.find_with_hash(hash, slice) } { + return Ok(id); + } + } + + // SAFETY: we just checked above that the slice isn't in the set. + let id = unsafe { write_guard.insert_unique_uncontended_with_hash(hash, slice)? }; + Ok(id) + } +} + +#[cfg(test)] +mod tests { + use crate::profiles::collections::string_set::{StringRef, UnsyncStringSet}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + /// A test struct representing a function with file and function names. + /// This tests that the generic slice infrastructure works with composite types. 
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + struct Function { + file_name: StringRef, + function_name: StringRef, + } + + impl Function { + fn new(file_name: StringRef, function_name: StringRef) -> Self { + Self { + file_name, + function_name, + } + } + } + + #[test] + fn test_function_deduplication() { + // Create string set for the string data + let mut string_set = UnsyncStringSet::try_new().unwrap(); + + // Create some strings + let file1 = string_set.try_insert("src/main.rs").unwrap(); + let file2 = string_set.try_insert("src/lib.rs").unwrap(); + let func1 = string_set.try_insert("main").unwrap(); + let func2 = string_set.try_insert("process_data").unwrap(); + + // Create function objects + let fn1 = Function::new(file1, func1); // main in src/main.rs + let fn2 = Function::new(file1, func2); // process_data in src/main.rs + let fn3 = Function::new(file2, func1); // main in src/lib.rs + let fn4 = Function::new(file1, func1); // main in src/main.rs (duplicate of fn1) + + // Test that the functions are equal/different as expected + assert_eq!(fn1, fn4, "Same function should be equal"); + assert_ne!(fn1, fn2, "Different functions should not be equal"); + assert_ne!(fn1, fn3, "Different functions should not be equal"); + assert_ne!(fn2, fn3, "Different functions should not be equal"); + + // Test that we can distinguish them by their components + unsafe { + assert_eq!(string_set.get_string(fn1.file_name), "src/main.rs"); + assert_eq!(string_set.get_string(fn1.function_name), "main"); + assert_eq!(string_set.get_string(fn2.function_name), "process_data"); + assert_eq!(string_set.get_string(fn3.file_name), "src/lib.rs"); + } + } + + #[test] + fn auto_traits_send_sync() { + fn require_send() {} + fn require_sync() {} + require_send::>(); + require_sync::>(); + } + + #[test] + fn test_function_hashing() { + let mut string_set = UnsyncStringSet::try_new().unwrap(); + + let file1 = string_set.try_insert("src/main.rs").unwrap(); + let func1 = string_set.try_insert("main").unwrap(); + let func2 = string_set.try_insert("process_data").unwrap(); + + let fn1 = Function::new(file1, func1); + let fn2 = Function::new(file1, func2); + let fn1_copy = Function::new(file1, func1); + + // Test hash consistency + let hash1 = { + let mut hasher = DefaultHasher::new(); + fn1.hash(&mut hasher); + hasher.finish() + }; + + let hash1_copy = { + let mut hasher = DefaultHasher::new(); + fn1_copy.hash(&mut hasher); + hasher.finish() + }; + + let hash2 = { + let mut hasher = DefaultHasher::new(); + fn2.hash(&mut hasher); + hasher.finish() + }; + + // Same function should have same hash + assert_eq!(hash1, hash1_copy, "Same function should hash consistently"); + + // Different functions should have different hashes (with high probability) + assert_ne!( + hash1, hash2, + "Different functions should have different hashes" + ); + } + + #[test] + fn test_function_composition() { + let mut string_set = UnsyncStringSet::try_new().unwrap(); + + let file1 = string_set.try_insert("src/utils.rs").unwrap(); + let func1 = string_set.try_insert("calculate_hash").unwrap(); + + let function = Function::new(file1, func1); + + // Test that we can access the components + assert_eq!(function.file_name, file1); + assert_eq!(function.function_name, func1); + + // Test that the string data is preserved + unsafe { + assert_eq!(string_set.get_string(function.file_name), "src/utils.rs"); + assert_eq!( + string_set.get_string(function.function_name), + "calculate_hash" + ); + } + } + + #[test] + fn test_many_functions() { + let mut 
string_set = UnsyncStringSet::try_new().unwrap(); + + // Create a variety of file and function names + let files = [ + "src/main.rs", + "src/lib.rs", + "src/utils.rs", + "src/parser.rs", + "src/codegen.rs", + ]; + let functions = [ + "main", "new", "process", "parse", "generate", "validate", "cleanup", "init", + ]; + + let mut file_ids = Vec::new(); + let mut func_ids = Vec::new(); + + // Create string IDs + for &file in &files { + file_ids.push(string_set.try_insert(file).unwrap()); + } + for &func in &functions { + func_ids.push(string_set.try_insert(func).unwrap()); + } + + let mut functions_created = Vec::new(); + + // Create many function combinations + for &file_id in &file_ids { + for &func_id in &func_ids { + let function = Function::new(file_id, func_id); + functions_created.push(function); + } + } + + // Should have files.len() * functions.len() unique functions + assert_eq!(functions_created.len(), files.len() * functions.len()); + + // Test that all functions are different (no duplicates) + for i in 0..functions_created.len() { + for j in i + 1..functions_created.len() { + assert_ne!( + functions_created[i], functions_created[j], + "All function combinations should be different" + ); + } + } + + // Test that we can retrieve the original strings + for function in &functions_created { + unsafe { + let file_str = string_set.get_string(function.file_name); + let func_str = string_set.get_string(function.function_name); + + // Verify the strings are in our original arrays + assert!( + files.contains(&file_str), + "File name should be from our test set" + ); + assert!( + functions.contains(&func_str), + "Function name should be from our test set" + ); + } + } + } + + #[test] + fn test_function_edge_cases() { + let mut string_set = UnsyncStringSet::try_new().unwrap(); + + // Test with empty strings + let empty_file = string_set.try_insert("").unwrap(); + let empty_func = string_set.try_insert("").unwrap(); + let normal_file = string_set.try_insert("normal.rs").unwrap(); + let normal_func = string_set.try_insert("normal_function").unwrap(); + + let fn1 = Function::new(empty_file, empty_func); // Both empty + let fn2 = Function::new(empty_file, normal_func); // Empty file, normal function + let fn3 = Function::new(normal_file, empty_func); // Normal file, empty function + let fn4 = Function::new(normal_file, normal_func); // Both normal + + // All should be different + let functions = [fn1, fn2, fn3, fn4]; + for i in 0..functions.len() { + for j in i + 1..functions.len() { + assert_ne!( + functions[i], functions[j], + "Functions with different components should not be equal" + ); + } + } + + // Test that we can retrieve the correct strings + unsafe { + assert_eq!(string_set.get_string(fn1.file_name), ""); + assert_eq!(string_set.get_string(fn1.function_name), ""); + assert_eq!(string_set.get_string(fn2.file_name), ""); + assert_eq!(string_set.get_string(fn2.function_name), "normal_function"); + assert_eq!(string_set.get_string(fn3.file_name), "normal.rs"); + assert_eq!(string_set.get_string(fn3.function_name), ""); + assert_eq!(string_set.get_string(fn4.file_name), "normal.rs"); + assert_eq!(string_set.get_string(fn4.function_name), "normal_function"); + } + } +} diff --git a/libdd-profiling/src/profiles/collections/parallel/string_set.rs b/libdd-profiling/src/profiles/collections/parallel/string_set.rs new file mode 100644 index 0000000000..edee55dc77 --- /dev/null +++ b/libdd-profiling/src/profiles/collections/parallel/string_set.rs @@ -0,0 +1,232 @@ +// Copyright 2025-Present 
Datadog, Inc. https://www.datadoghq.com/
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::profiles::collections::parallel::slice_set::{ParallelSliceSet, ParallelSliceStorage};
+use crate::profiles::collections::{Arc, ArcOverflow, SetError, StringRef, WELL_KNOWN_STRING_REFS};
+use core::ptr;
+use std::ffi::c_void;
+use std::ops::Deref;
+
+/// A string set which can have parallel read and write operations.
+/// This is a newtype wrapper around `ParallelSliceSet<u8>` that adds
+/// string-specific functionality like well-known strings.
+#[repr(transparent)]
+pub struct ParallelStringSet {
+    pub(crate) inner: ParallelSliceSet<u8>,
+}
+
+impl Deref for ParallelStringSet {
+    type Target = ParallelSliceSet<u8>;
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl ParallelStringSet {
+    /// Consumes the `ParallelStringSet`, returning a non-null pointer to the
+    /// inner storage. This storage should not be mutated; it only exists to
+    /// be passed across FFI boundaries, which is why its type has been erased.
+    #[inline]
+    pub fn into_raw(self) -> ptr::NonNull<c_void> {
+        Arc::into_raw(self.inner.arc).cast()
+    }
+
+    /// Recreates a `ParallelStringSet` from a raw pointer produced by
+    /// [`ParallelStringSet::into_raw`].
+    ///
+    /// # Safety
+    ///
+    /// The pointer must have been produced by [`ParallelStringSet::into_raw`]
+    /// and be returned unchanged.
+    #[inline]
+    pub unsafe fn from_raw(raw: ptr::NonNull<c_void>) -> Self {
+        let arc = Arc::from_raw(raw.cast::<ParallelSliceStorage<u8>>());
+        Self {
+            inner: ParallelSliceSet { arc },
+        }
+    }
+
+    pub fn try_clone(&self) -> Result<ParallelStringSet, ArcOverflow> {
+        Ok(ParallelStringSet {
+            inner: self.inner.try_clone()?,
+        })
+    }
+
+    /// Tries to create a new parallel string set that contains the well-known
+    /// strings, including the empty string.
+    pub fn try_new() -> Result<Self, SetError> {
+        let inner = ParallelSliceSet::try_new()?;
+        let set = Self { inner };
+
+        for id in WELL_KNOWN_STRING_REFS.iter() {
+            // SAFETY: the well-known strings are unique, and we're in the
+            // constructor where other threads don't have access to it yet.
+            _ = unsafe { set.insert_unique_uncontended(id.0.deref())? };
+        }
+        Ok(set)
+    }
+
+    /// # Safety
+    /// The string must not have been inserted yet, as it skips checking if
+    /// the string is already present.
+    pub unsafe fn insert_unique_uncontended(&self, str: &str) -> Result<StringRef, SetError> {
+        let thin_slice = self.inner.insert_unique_uncontended(str.as_bytes())?;
+        Ok(StringRef(thin_slice.into()))
+    }
+
+    /// Adds the string to the string set if it isn't present already, and
+    /// returns a handle to the string that can be used to retrieve it later.
+    pub fn try_insert(&self, str: &str) -> Result<StringRef, SetError> {
+        let thin_slice = self.inner.try_insert(str.as_bytes())?;
+        Ok(StringRef(thin_slice.into()))
+    }
+
+    /// Selects which shard of the underlying slice set a hash maps to
+    /// (`0..N_SHARDS`).
+    pub fn select_shard(hash: u64) -> usize {
+        ParallelSliceSet::<u8>::select_shard(hash)
+    }
+
+    /// # Safety
+    /// The caller must ensure that the `StringRef` was produced by this set
+    /// and is therefore valid for it.
+    pub unsafe fn get(&self, id: StringRef) -> &str {
+        // SAFETY: the caller guarantees `id` refers to a string interned in
+        // this set, so the bytes are valid UTF-8 stored in its arena.
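+        // The transmute only rebinds the lifetime: the arena never frees or
+        // mutates interned bytes, so they outlive `&self`, but that link is
+        // invisible to the borrow checker because `id` is passed by value.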
+ unsafe { core::mem::transmute::<&str, &str>(id.0.deref()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::profiles::collections::parallel::slice_set::N_SHARDS; + use crate::profiles::collections::SetHasher as Hasher; + use std::hash::BuildHasher; + + #[test] + fn test_well_known_strings() { + let strs: [&str; WELL_KNOWN_STRING_REFS.len()] = [ + "", + "end_timestamp_ns", + "local root span id", + "trace endpoint", + "span id", + ]; + for (expected, id) in strs.iter().copied().zip(WELL_KNOWN_STRING_REFS) { + let actual: &str = id.0.deref(); + assert_eq!(expected, actual); + } + + let mut selected = [0; WELL_KNOWN_STRING_REFS.len()]; + for (id, dst) in WELL_KNOWN_STRING_REFS.iter().zip(selected.iter_mut()) { + *dst = ParallelStringSet::select_shard(Hasher::default().hash_one(id.0.deref())); + } + } + + #[test] + fn test_parallel_set() { + let set = ParallelStringSet::try_new().unwrap(); + // SAFETY: these are all well-known strings. + unsafe { + let str = set.get(StringRef::EMPTY); + assert_eq!(str, ""); + + let str = set.get(StringRef::END_TIMESTAMP_NS); + assert_eq!(str, "end_timestamp_ns"); + + let str = set.get(StringRef::LOCAL_ROOT_SPAN_ID); + assert_eq!(str, "local root span id"); + + let str = set.get(StringRef::TRACE_ENDPOINT); + assert_eq!(str, "trace endpoint"); + + let str = set.get(StringRef::SPAN_ID); + assert_eq!(str, "span id"); + }; + + let id = set.try_insert("").unwrap(); + assert_eq!(&*id.0, &*StringRef::EMPTY.0); + + let id = set.try_insert("end_timestamp_ns").unwrap(); + assert_eq!(&*id.0, &*StringRef::END_TIMESTAMP_NS.0); + + let id = set.try_insert("local root span id").unwrap(); + assert_eq!(&*id.0, &*StringRef::LOCAL_ROOT_SPAN_ID.0); + + let id = set.try_insert("trace endpoint").unwrap(); + assert_eq!(&*id.0, &*StringRef::TRACE_ENDPOINT.0); + + let id = set.try_insert("span id").unwrap(); + assert_eq!(&*id.0, &*StringRef::SPAN_ID.0); + } + + #[test] + fn test_hash_distribution() { + let test_strings: Vec = (0..100).map(|i| format!("test_string_{}", i)).collect(); + + let mut shard_counts = [0; N_SHARDS]; + + for s in &test_strings { + let hash = Hasher::default().hash_one(s); + let shard = ParallelStringSet::select_shard(hash); + shard_counts[shard] += 1; + } + + // Verify that distribution is not completely degenerate + // (both shards should get at least some strings) + assert!(shard_counts[0] > 0, "Shard 0 got no strings"); + assert!(shard_counts[1] > 0, "Shard 1 got no strings"); + + // Print distribution for manual inspection + println!("Shard distribution: {:?}", shard_counts); + } + + #[test] + fn test_parallel_set_shard_selection() { + let set = ParallelStringSet::try_new().unwrap(); + + // Test with realistic strings that would appear in profiling + let test_strings = [ + // .NET method signatures + "System.String.Concat(System.Object)", + "Microsoft.Extensions.DependencyInjection.ServiceProvider.GetService(System.Type)", + "System.Text.Json.JsonSerializer.Deserialize(System.String)", + "MyNamespace.MyClass.MyMethod(Int32 id, String name)", + // File paths and URLs + "/usr/lib/x86_64-linux-gnu/libc.so.6", + "/var/run/datadog/apm.socket", + "https://api.datadoghq.com/api/v1/traces", + "/home/user/.local/share/applications/myapp.desktop", + "C:\\Program Files\\MyApp\\bin\\myapp.exe", + // Short common strings + "f", + "g", + ]; + + let mut ids = Vec::new(); + for &test_str in &test_strings { + let id = set.try_insert(test_str).unwrap(); + ids.push((test_str, id)); + } + + // Verify all strings can be retrieved correctly + for (original_str, 
id) in ids { + unsafe { + assert_eq!(set.get(id), original_str); + } + } + + // Test that inserting the same strings again returns the same IDs + for &test_str in &test_strings { + let id1 = set.try_insert(test_str).unwrap(); + let id2 = set.try_insert(test_str).unwrap(); + assert_eq!(&*id1.0, &*id2.0); + } + } + + #[test] + fn auto_traits_send_sync() { + fn require_send() {} + fn require_sync() {} + require_send::(); + require_sync::(); + } +} diff --git a/libdd-profiling/src/profiles/collections/set.rs b/libdd-profiling/src/profiles/collections/set.rs index ef378e9da0..6504f0c0f0 100644 --- a/libdd-profiling/src/profiles/collections/set.rs +++ b/libdd-profiling/src/profiles/collections/set.rs @@ -61,7 +61,7 @@ pub struct Set { } impl Set { - const SIZE_HINT: usize = 1024 * 1024; + pub const SIZE_HINT: usize = 1024 * 1024; pub fn try_new() -> Result { Self::try_with_capacity(SET_MIN_CAPACITY) @@ -145,7 +145,7 @@ impl Drop for Set { } impl Set { - fn try_with_capacity(capacity: usize) -> Result { + pub(crate) fn try_with_capacity(capacity: usize) -> Result { let arena = ChainAllocator::new_in(Self::SIZE_HINT, VirtualAllocator {}); let mut table = HashTable::new(); @@ -154,7 +154,7 @@ impl Set { Ok(Self { arena, table }) } - unsafe fn find_with_hash(&self, hash: u64, key: &T) -> Option> { + pub(crate) unsafe fn find_with_hash(&self, hash: u64, key: &T) -> Option> { let found = self .table // SAFETY: NonNull inside table points to live, properly aligned Ts. @@ -162,7 +162,7 @@ impl Set { Some(SetId(*found)) } - unsafe fn insert_unique_uncontended_with_hash( + pub(crate) unsafe fn insert_unique_uncontended_with_hash( &mut self, hash: u64, value: T, From 2f6bdcf3740887f8b63bfb60490a8f7fc7592872 Mon Sep 17 00:00:00 2001 From: Levi Morrison Date: Fri, 14 Nov 2025 11:05:20 -0700 Subject: [PATCH 2/2] test: exercise thread safety --- .../collections/parallel/string_set.rs | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/libdd-profiling/src/profiles/collections/parallel/string_set.rs b/libdd-profiling/src/profiles/collections/parallel/string_set.rs index edee55dc77..232cd874f4 100644 --- a/libdd-profiling/src/profiles/collections/parallel/string_set.rs +++ b/libdd-profiling/src/profiles/collections/parallel/string_set.rs @@ -229,4 +229,109 @@ mod tests { require_send::(); require_sync::(); } + + #[test] + fn test_thread_safety() { + use std::sync::{Arc, Barrier}; + use std::thread; + + // Create set and add some strings before sharing across threads + let set = ParallelStringSet::try_new().unwrap(); + let pre_inserted_strings = vec![ + "pre_inserted_1", + "pre_inserted_2", + "pre_inserted_3", + "pre_inserted_4", + ]; + let mut pre_inserted_ids = Vec::new(); + for s in &pre_inserted_strings { + let id = set.try_insert(s).unwrap(); + pre_inserted_ids.push((s.to_string(), id)); + } + + // Share the set across threads using try_clone (which clones the internal Arc) + let num_threads = 4; + let operations_per_thread = 50; + + // Keep a clone of the original set for final verification + let original_set = set.try_clone().unwrap(); + + // Create a barrier to ensure all threads start work simultaneously + let barrier = Arc::new(Barrier::new(num_threads)); + + // Spawn threads that will both read pre-existing strings and insert new ones + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let set = original_set.try_clone().unwrap(); + let pre_ids = pre_inserted_ids.clone(); + let barrier = Arc::clone(&barrier); + thread::spawn(move || { + // Wait for all threads to be 
spawned before starting work + barrier.wait(); + + // Read pre-existing strings (should be safe to read concurrently) + for (expected_str, id) in &pre_ids { + unsafe { + let actual_str = set.get(*id); + assert_eq!( + actual_str, + expected_str.as_str(), + "Pre-inserted string should be readable" + ); + } + } + + // Concurrently insert new strings + for i in 0..operations_per_thread { + let new_str = format!("thread_{}_string_{}", thread_id, i); + let id = set.try_insert(&new_str).unwrap(); + unsafe { + let retrieved = set.get(id); + assert_eq!(retrieved, new_str, "Inserted string should be retrievable"); + } + } + + // Try inserting strings that other threads might have inserted + for i in 0..operations_per_thread { + let shared_str = format!("shared_string_{}", i); + let id1 = set.try_insert(&shared_str).unwrap(); + let id2 = set.try_insert(&shared_str).unwrap(); + // Both should return the same ID (deduplication) + assert_eq!(&*id1.0, &*id2.0, "Duplicate inserts should return same ID"); + unsafe { + assert_eq!(set.get(id1), shared_str); + } + } + }) + }) + .collect(); + + // Wait for all threads to complete + for handle in handles { + handle.join().expect("Thread should not panic"); + } + + // Verify final state: all pre-inserted strings should still be readable + for (expected_str, id) in &pre_inserted_ids { + unsafe { + let actual_str = original_set.get(*id); + assert_eq!( + actual_str, + expected_str.as_str(), + "Pre-inserted strings should remain readable after concurrent operations" + ); + } + } + + // Verify that shared strings inserted by multiple threads are deduplicated + for i in 0..operations_per_thread { + let shared_str = format!("shared_string_{}", i); + let id1 = original_set.try_insert(&shared_str).unwrap(); + let id2 = original_set.try_insert(&shared_str).unwrap(); + assert_eq!( + &*id1.0, &*id2.0, + "Shared strings should be deduplicated correctly" + ); + } + } }
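+
+// A compact end-to-end sketch of the intended cross-thread usage (the thread
+// count and the interned string are arbitrary); it exercises only APIs added
+// in this patch: `try_new`, `try_clone`, `try_insert`, and `get`.
+#[cfg(test)]
+mod usage_sketch {
+    use super::*;
+
+    #[test]
+    fn concurrent_interning_yields_one_handle_per_string() {
+        let set = ParallelStringSet::try_new().unwrap();
+        let baseline = set.try_insert("interned once").unwrap();
+
+        let handles: Vec<_> = (0..4)
+            .map(|_| {
+                let set = set.try_clone().unwrap();
+                std::thread::spawn(move || set.try_insert("interned once").unwrap())
+            })
+            .collect();
+
+        for handle in handles {
+            let id = handle.join().unwrap();
+            // Deduplication: every thread gets a handle to the same bytes.
+            assert_eq!(&*id.0, &*baseline.0);
+            // SAFETY: `id` was produced by (a clone of) this set.
+            assert_eq!(unsafe { set.get(id) }, "interned once");
+        }
+    }
+}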