diff --git a/Cargo.lock b/Cargo.lock index 3be76e4ff..a9110cf17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -175,6 +175,16 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-wait" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55b94919229f2c42292fd71ffa4b75e83193bffdd77b1e858cd55fd2d0b0ea8" +dependencies = [ + "libc", + "windows-sys 0.42.0", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -2132,6 +2142,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "papaya" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d17fbf29d99ed1d2a1fecdb37d08898790965c85fd2634ba4023ab9710089059" +dependencies = [ + "atomic-wait", + "seize", + "serde", +] + [[package]] name = "parking" version = "2.2.1" @@ -2482,6 +2503,7 @@ dependencies = [ "notify", "num_cpus", "once_cell", + "papaya", "parking_lot", "pprof2", "pretty_assertions", @@ -2495,6 +2517,7 @@ dependencies = [ "regex", "schemars", "seahash", + "seize", "serde", "serde_json", "serde_regex", @@ -2945,6 +2968,12 @@ dependencies = [ "libc", ] +[[package]] +name = "seize" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d659fa6f19e82a52ab8d3fff3c380bd8cc16462eaea411395618a38760eb85bc" + [[package]] name = "serde" version = "1.0.215" @@ -3921,6 +3950,21 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -3979,6 +4023,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3991,6 +4041,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4003,6 +4059,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4021,6 +4083,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4033,6 +4101,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -4045,6 +4119,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -4057,6 +4137,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index f7c2bd8cd..911a31551 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -157,6 +157,8 @@ libflate = "2.1.0" form_urlencoded = "1.2.1" enum_dispatch = "0.3.13" gxhash = "3.4.1" +papaya = { version = "0.1.3", features = ["serde"] } +seize = "0.4.5" [dependencies.hyper-util] version = "0.1" diff --git a/benches/cluster_map.rs b/benches/cluster_map.rs index 6a8b33bae..654926ef2 100644 --- a/benches/cluster_map.rs +++ b/benches/cluster_map.rs @@ -17,10 +17,10 @@ mod serde { fn serialize_to_protobuf(cm: &ClusterMap) -> Vec { let mut resources = Vec::new(); - for cluster in cm.iter() { + for (key, cluster) in cm.pin().iter() { resources.push( Resource::Cluster(Cluster { - locality: cluster.key().clone().map(From::from), + locality: key.clone().map(From::from), endpoints: cluster .endpoints .iter() @@ -110,12 +110,7 @@ mod ops { use shared::{gen_cluster_map, GenCluster}; fn compute_hash(gc: &GenCluster) -> usize { - let mut total_endpoints = 0; - - for kv in gc.cm.iter() { - total_endpoints += kv.endpoints.len(); - } - + let total_endpoints = gc.cm.pin().values().map(|v| v.endpoints.len()).sum(); assert_eq!(total_endpoints, gc.total_endpoints); total_endpoints } diff --git a/benches/shared.rs b/benches/shared.rs index e6748c19a..7854ed77a 100644 --- a/benches/shared.rs +++ b/benches/shared.rs @@ -676,7 +676,7 @@ pub fn gen_cluster_map(token_kind: TokenKind) -> GenCluster { // Now actually insert the endpoints, now that the order of keys is established, // annoying, but note we split out iteration versus insertion, otherwise we deadlock - let keys: Vec<_> = cm.iter().map(|kv| kv.key().clone()).collect(); + let keys: Vec<_> = cm.pin().iter().map(|(key, _)| key.clone()).collect(); let mut sets = std::collections::BTreeMap::new(); let mut token_generator = match token_kind { diff --git a/benches/token_router.rs b/benches/token_router.rs index 51dd0effd..db99e3144 100644 --- a/benches/token_router.rs +++ b/benches/token_router.rs @@ -14,8 +14,8 @@ fn token_router(b: Bencher, token_kind: &str) { let cm = std::sync::Arc::new(gc.cm); // Calculate the amount of bytes for all the tokens - for eps in cm.iter() { - for ep in &eps.value().endpoints { + for eps in cm.pin().values() { + for ep in &eps.endpoints { for tok in &ep.metadata.known.tokens { tokens.push(tok.clone()); } diff --git a/crates/test/tests/mesh.rs b/crates/test/tests/mesh.rs index 45dd35fb5..985d4b2e7 100644 --- a/crates/test/tests/mesh.rs +++ b/crates/test/tests/mesh.rs @@ -189,8 +189,9 @@ trace_test!(datacenter_discovery, { #[track_caller] fn assert_config(config: &quilkin::Config, datacenter: &quilkin::config::Datacenter) { let dcs = config.datacenters().read(); - let ipv4_dc = dcs.get(&std::net::Ipv4Addr::LOCALHOST.into()); - let ipv6_dc = dcs.get(&std::net::Ipv6Addr::LOCALHOST.into()); + let pin = dcs.pin(); + let ipv4_dc = pin.get(&std::net::Ipv4Addr::LOCALHOST.into()); + let ipv6_dc = pin.get(&std::net::Ipv6Addr::LOCALHOST.into()); match (ipv4_dc, ipv6_dc) { (Some(dc), None) => assert_eq!(&*dc, datacenter), diff --git a/deny.toml b/deny.toml index 0d4461fb2..09446f89b 100644 --- a/deny.toml +++ b/deny.toml @@ -60,6 +60,7 @@ allow = ["Apache-2.0", "MIT", "ISC", "BSD-3-Clause", "Unicode-3.0"] exceptions = [ { crate = "adler32", allow = ["Zlib"] }, { crate = "foldhash", allow = ["Zlib"] }, + { crate = "atomic-wait", allow = ["BSD-2-Clause"] }, # This license should not really be used for code, but here we are { crate = "notify", allow = ["CC0-1.0"] }, { crate = "ring", allow = ["OpenSSL"] }, diff --git a/src/collections/ttl.rs b/src/collections/ttl.rs index f2da7a1d0..4c3466d79 100644 --- a/src/collections/ttl.rs +++ b/src/collections/ttl.rs @@ -14,9 +14,6 @@ * limitations under the License. */ -use dashmap::mapref::entry::Entry as DashMapEntry; -use dashmap::mapref::one::{Ref, RefMut}; -use dashmap::DashMap; use tracing::warn; use std::hash::Hash; @@ -25,7 +22,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::oneshot::{channel, Receiver, Sender}; -pub use dashmap::try_result::TryResult; +type HashMap = papaya::HashMap; // Clippy isn't recognizing that these imports are used conditionally. #[allow(unused_imports)] @@ -94,7 +91,7 @@ impl std::ops::Deref for Value { /// Map contains the hash map implementation. struct Map { - inner: DashMap>, + inner: HashMap, gxhash::GxBuildHasher>, ttl: Duration, clock: Clock, shutdown_tx: Option>, @@ -134,15 +131,19 @@ where V: Send + Sync + 'static, { pub fn new(ttl: Duration, poll_interval: Duration) -> Self { - Self::initialize(DashMap::new(), ttl, poll_interval) + Self::initialize(<_>::default(), ttl, poll_interval) } #[allow(dead_code)] pub fn with_capacity(ttl: Duration, poll_interval: Duration, capacity: usize) -> Self { - Self::initialize(DashMap::with_capacity(capacity), ttl, poll_interval) + Self::initialize( + HashMap::with_capacity_and_hasher(capacity, <_>::default()), + ttl, + poll_interval, + ) } - fn initialize(inner: DashMap>, ttl: Duration, poll_interval: Duration) -> Self { + fn initialize(inner: HashMap>, ttl: Duration, poll_interval: Duration) -> Self { let (shutdown_tx, shutdown_rx) = channel(); let map = TtlMap(Arc::new(Map { inner, @@ -168,41 +169,38 @@ where } } -#[allow(dead_code)] impl TtlMap where K: Hash + Eq + Send + Sync + 'static, - V: Send + Sync, + V: Send + Sync + Clone, { /// Returns a reference to value corresponding to key. - pub fn get(&self, key: &K) -> Option>> { - let value = self.0.inner.get(key); - if let Some(ref value) = value { - value.update_expiration(self.0.ttl) + pub fn get(&self, key: &K) -> Option { + let pin = self.0.inner.pin(); + let value = pin.get(key); + if let Some(value) = value { + value.update_expiration(self.0.ttl); } - value + value.map(|value| value.value.clone()) } +} +impl TtlMap +where + K: Hash + Eq + Send + Sync + 'static, + V: Send + Sync, +{ /// Returns a reference to value corresponding to key. - pub fn try_get(&self, key: &K) -> TryResult>> { - let value = self.0.inner.try_get(key); - if let TryResult::Present(ref value) = value { - value.update_expiration(self.0.ttl) - } - - value - } - - /// Returns a mutable reference to value corresponding to key. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get_mut(&self, key: &K) -> Option>> { - let value = self.0.inner.get_mut(key); - if let Some(ref value) = value { + pub fn get_by_ref(&self, key: &K, and_then: impl FnOnce(&V) -> F) -> Option { + let pin = self.0.inner.pin(); + let value = pin.get(key); + if let Some(value) = value { value.update_expiration(self.0.ttl); + Some((and_then)(value)) + } else { + None } - - value } /// Returns the number of entries currently in the map. @@ -222,22 +220,32 @@ where /// Returns true if the map contains a value for the specified key. pub fn contains_key(&self, key: &K) -> bool { - self.0.inner.contains_key(key) + self.0.inner.pin().contains_key(key) } /// Inserts a key-value pair into the map. /// The value will be set to expire at the configured TTL after the time of insertion. /// If a previous value existed for this key, that value is returned. - pub fn insert(&self, key: K, value: V) -> Option { + pub fn insert(&self, key: K, value: V) { self.0 .inner - .insert(key, Value::new(value, self.0.ttl, self.0.clock.clone())) - .map(|value| value.value) + .pin() + .insert(key, Value::new(value, self.0.ttl, self.0.clock.clone())); } /// Removes a key-value pair from the map. pub fn remove(&self, key: K) -> bool { - self.0.inner.remove(&key).is_some() + self.0.inner.pin().remove(&key).is_some() + } + + /// Removes a key-value pair from the map. + #[cfg(test)] + pub fn remove_force_drop(&self, key: K) -> bool { + use papaya::Guard; + let guard = self.0.inner.guard(); + let removed = self.0.inner.remove(&key, &guard).is_some(); + guard.flush(); + removed } /// Removes all entries from the map @@ -245,25 +253,6 @@ where pub fn clear(&self) { self.0.inner.clear(); } - - /// Returns an entry for in-place updates of the specified key-value pair. - /// Note: This acquires a write lock on the map's shard that corresponds - /// to the entry. - pub fn entry(&self, key: K) -> Entry> { - let ttl = self.0.ttl; - match self.0.inner.entry(key) { - inner @ DashMapEntry::Occupied(_) => Entry::Occupied(OccupiedEntry { - inner, - ttl, - clock: self.0.clock.clone(), - }), - inner @ DashMapEntry::Vacant(_) => Entry::Vacant(VacantEntry { - inner, - ttl, - clock: self.0.clock.clone(), - }), - } - } } impl std::fmt::Debug @@ -292,87 +281,6 @@ where } } -/// A view into an occupied entry in the map. -pub struct OccupiedEntry<'a, K, V> { - inner: DashMapEntry<'a, K, V>, - ttl: Duration, - clock: Clock, -} - -/// A view into a vacant entry in the map. -pub struct VacantEntry<'a, K, V> { - inner: DashMapEntry<'a, K, V>, - ttl: Duration, - clock: Clock, -} - -/// A view into an entry in the map. -/// It may either be [`VacantEntry`] or [`OccupiedEntry`] -pub enum Entry<'a, K, V> { - Occupied(OccupiedEntry<'a, K, V>), - Vacant(VacantEntry<'a, K, V>), -} - -impl<'a, K, V> OccupiedEntry<'a, K, Value> -where - K: Eq + Hash, -{ - /// Returns a reference to the entry's value. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get(&self) -> &Value { - match &self.inner { - DashMapEntry::Occupied(entry) => { - let value = entry.get(); - value.update_expiration(self.ttl); - value - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } - - #[allow(dead_code)] - /// Returns a mutable reference to the entry's value. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get_mut(&mut self) -> &mut Value { - match &mut self.inner { - DashMapEntry::Occupied(entry) => { - let value = entry.get_mut(); - value.update_expiration(self.ttl); - value - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } - - #[allow(dead_code)] - /// Replace the entry's value with a new value, returning the old value. - /// The value will be set to expire at the configured TTL after the time of insertion. - pub fn insert(&mut self, value: V) -> Value { - match &mut self.inner { - DashMapEntry::Occupied(entry) => { - entry.insert(Value::new(value, self.ttl, self.clock.clone())) - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } -} - -impl<'a, K, V> VacantEntry<'a, K, Value> -where - K: Eq + Hash, -{ - /// Set an entry's value. - /// The value will be set to expire at the configured TTL after the time of insertion. - pub fn insert(self, value: V) -> RefMut<'a, K, Value> { - match self.inner { - DashMapEntry::Vacant(entry) => { - entry.insert(Value::new(value, self.ttl, self.clock.clone())) - } - _ => unreachable!("BUG: entry type should be vacant"), - } - } -} - fn spawn_cleanup_task( map: Arc>, poll_interval: Duration, @@ -410,21 +318,13 @@ where return; }; - // Take a read lock first and check if there is at least 1 item to remove. - let has_expired_keys = map - .inner + let pin = map.inner.pin(); + let expired_keys = pin .iter() - .filter(|entry| entry.value().expiration_secs() <= now_secs) - .take(1) - .next() - .is_some(); - - // If we have work to do then, take a write lock. - if has_expired_keys { - // Go over the whole map in case anything expired - // since acquiring the write lock. - map.inner - .retain(|_, value| value.expiration_secs() > now_secs); + .filter(|(_, value)| value.expiration_secs() <= now_secs); + + for (key, _) in expired_keys { + map.inner.pin().remove(key); } } @@ -521,8 +421,8 @@ mod tests { map.insert(one.clone(), 1); map.insert(two.clone(), 2); - assert_eq!(map.get(&one).unwrap().value, 1); - assert_eq!(map.get(&two).unwrap().value, 2); + assert_eq!(map.get(&one).unwrap(), 1); + assert_eq!(map.get(&two).unwrap(), 2); } #[tokio::test] @@ -536,15 +436,17 @@ mod tests { Duration::from_secs(10), Duration::from_millis(10), ); - map.insert(one.clone(), 1); - let exp1 = map.get(&one).unwrap().expiration_secs(); + map.insert(one.clone(), 1); + let exp1 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); time::advance(Duration::from_secs(2)).await; - let exp2 = map.get(&one).unwrap().expiration_secs(); + let _ = map.get(&one).unwrap(); + let exp2 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); time::advance(Duration::from_secs(3)).await; - let exp3 = map.get(&one).unwrap().expiration_secs(); + let _ = map.get(&one).unwrap(); + let exp3 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); assert!(exp1 < exp2); assert_eq!(2, exp2 - exp1); @@ -569,177 +471,6 @@ mod tests { assert!(map.contains_key(&two)); } - #[tokio::test] - async fn entry_occupied_insert_and_get() { - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - match map.entry(one.clone()) { - Entry::Occupied(mut entry) => { - assert_eq!(entry.get().value, 1); - entry.insert(5); - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_occupied_get_mut() { - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - match map.entry(one.clone()) { - Entry::Occupied(mut entry) => { - entry.get_mut().value = 5; - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_vacant_insert() { - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - - match map.entry(one.clone()) { - Entry::Vacant(entry) => { - let mut e = entry.insert(1); - assert_eq!(e.value, 1); - e.value = 5; - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_occupied_get_expiration() { - // Test that when we get a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let exp2 = match map.entry(one.clone()) { - Entry::Occupied(entry) => entry.get().expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_get_mut_expiration() { - // Test that when we get_mut a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let exp2 = match map.entry(one) { - Entry::Occupied(mut entry) => entry.get_mut().expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_insert_expiration() { - // Test that when we replace a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let old_exp1 = match map.entry(one.clone()) { - Entry::Occupied(mut entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - let exp2 = map.get(&one).unwrap().expiration_secs(); - - assert_eq!(exp1, old_exp1); - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_vacant_expiration() { - // Test that when we insert a value via VacantEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - - let exp1 = match map.entry(one.clone()) { - Entry::Vacant(entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected vacant entry"), - }; - - time::advance(Duration::from_secs(2)).await; - - let exp2 = map.get(&one).unwrap().expiration_secs(); - - // Initial expiration should be set at our configured ttl. - assert_eq!(10, exp1); - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - #[tokio::test] async fn expiration_ttl() { // Test that when we expire entries at our configured ttl. @@ -750,10 +481,9 @@ mod tests { let ttl = Duration::from_secs(12); let map = TtlMap::::new(ttl, Duration::from_millis(10)); - let exp = match map.entry(one) { - Entry::Vacant(entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected vacant entry"), - }; + assert!(map.0.inner.pin().get(&one).is_none()); + map.insert(one.clone(), 9); + let exp = map.0.inner.pin().get(&one).unwrap().expiration_secs(); // Check that it expires at our configured TTL. assert_eq!(12, exp); diff --git a/src/components/proxy/sessions.rs b/src/components/proxy/sessions.rs index 0ee3a57dd..1a92509eb 100644 --- a/src/components/proxy/sessions.rs +++ b/src/components/proxy/sessions.rs @@ -226,7 +226,12 @@ impl SessionPool { ) -> Result<(Option, PendingSends), super::PipelineError> { tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get"); // If we already have a session for the key pairing, return that session. - if let Some(entry) = self.session_map.get(&key) { + if let Some((asn_info, upstream_sender)) = self.session_map.get_by_ref(&key, |value| { + ( + value.asn_info.as_ref().map(MetricsIpNetEntry::from), + value.upstream_sender.clone(), + ) + }) { tracing::trace!("returning existing session"); return Ok(( entry.asn_info.as_ref().map(MetricsIpNetEntry::from), @@ -398,9 +403,9 @@ impl SessionPool { /// Forces removal of session to make testing quicker. #[cfg(test)] async fn drop_session(&self, key: SessionKey) -> bool { - let is_removed = self.session_map.remove(key); + let is_removed = self.session_map.remove_force_drop(key); // Sleep because there's no async drop. - tokio::time::sleep(Duration::from_millis(100)).await; + tokio::time::sleep(Duration::from_millis(200)).await; is_removed } @@ -579,22 +584,6 @@ mod tests { ) } - #[tokio::test] - async fn insert_and_release_single_socket() { - let (pool, _receiver) = new_pool().await; - let key = ( - (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), - (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), - ) - .into(); - - let _session = pool.get(key).unwrap(); - - assert!(pool.drop_session(key).await); - - assert!(pool.has_no_allocated_sockets()); - } - #[tokio::test] async fn insert_and_release_multiple_sockets() { let (pool, _receiver) = new_pool().await; @@ -637,8 +626,12 @@ mod tests { let _socket1 = pool.get(key1).unwrap(); let _socket2 = pool.get(key2).unwrap(); assert_ne!( - pool.session_map.get(&key1).unwrap().socket_port, - pool.session_map.get(&key2).unwrap().socket_port + pool.session_map + .get_by_ref(&key1, |v| v.socket_port) + .unwrap(), + pool.session_map + .get_by_ref(&key2, |v| v.socket_port) + .unwrap() ); assert!(pool.drop_session(key1).await); @@ -663,8 +656,12 @@ mod tests { let _socket2 = pool.get(key2).unwrap(); assert_eq!( - pool.session_map.get(&key1).unwrap().socket_port, - pool.session_map.get(&key2).unwrap().socket_port + pool.session_map + .get_by_ref(&key1, |v| v.socket_port) + .unwrap(), + pool.session_map + .get_by_ref(&key2, |v| v.socket_port) + .unwrap() ); } diff --git a/src/config.rs b/src/config.rs index 7064792c8..e99e3a3af 100644 --- a/src/config.rs +++ b/src/config.rs @@ -297,10 +297,10 @@ impl Config { }); } DatacenterConfig::NonAgent { datacenters } => { - for entry in datacenters.read().iter() { - let host = entry.key().to_string(); - let qcmp_port = entry.qcmp_port; - let version = format!("{}-{qcmp_port}", entry.icao_code); + for (key, value) in datacenters.read().pin().iter() { + let host = key.to_string(); + let qcmp_port = value.qcmp_port; + let version = format!("{}-{qcmp_port}", value.icao_code); if client_state.version_matches(&host, &version) { continue; @@ -309,7 +309,7 @@ impl Config { let resource = crate::xds::Resource::Datacenter( crate::net::cluster::proto::Datacenter { qcmp_port: qcmp_port as _, - icao_code: entry.icao_code.to_string(), + icao_code: value.icao_code.to_string(), host: host.clone(), }, ); @@ -330,7 +330,7 @@ impl Config { let Ok(addr) = key.parse() else { continue; }; - if dc.get(&addr).is_none() { + if dc.pin().get(&addr).is_none() { removed.insert(key.clone()); } } @@ -366,8 +366,8 @@ impl Config { }; if client_state.subscribed.is_empty() { - for cluster in self.clusters.read().iter() { - push(cluster.key(), cluster.value())?; + for (key, value) in self.clusters.read().pin().iter() { + push(key, value)?; } } else { for locality in client_state.subscribed.iter().filter_map(|name| { @@ -377,8 +377,8 @@ impl Config { name.parse().ok().map(Some) } }) { - if let Some(cluster) = self.clusters.read().get(&locality) { - push(cluster.key(), cluster.value())?; + if let Some(value) = self.clusters.read().pin().get(&locality) { + push(&locality, value)?; } } }; @@ -387,7 +387,7 @@ impl Config { // is when ClusterMap::update_unlocated_endpoints is called to move the None // locality endpoints to another one, so we just detect that case manually if client_state.versions.contains_key("") - && self.clusters.read().get(&None).is_none() + && self.clusters.read().pin().get(&None).is_none() { removed.insert("".into()); } @@ -593,16 +593,15 @@ impl Config { #[derive(Default, Debug, Deserialize, Serialize)] pub struct DatacenterMap { - map: dashmap::DashMap, + map: papaya::HashMap, version: AtomicU64, } impl DatacenterMap { #[inline] - pub fn insert(&self, ip: IpAddr, datacenter: Datacenter) -> Option { - let old = self.map.insert(ip, datacenter); + pub fn insert(&self, ip: IpAddr, datacenter: Datacenter) { + self.map.pin().insert(ip, datacenter); self.version.fetch_add(1, Relaxed); - old } #[inline] @@ -621,13 +620,10 @@ impl DatacenterMap { } #[inline] - pub fn get(&self, key: &IpAddr) -> Option> { - self.map.get(key) - } - - #[inline] - pub fn iter(&self) -> dashmap::iter::Iter { - self.map.iter() + pub fn pin( + &self, + ) -> papaya::HashMapRef { + self.map.pin() } } @@ -676,8 +672,8 @@ impl PartialEq for DatacenterMap { return false; } - for a in self.iter() { - match rhs.get(a.key()).filter(|b| *a.value() == **b) { + for (key, value) in self.pin().iter() { + match rhs.pin().get(key).filter(|b| *value == **b) { Some(_) => {} None => return false, } diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index 1c4462b5d..e82a9c1ee 100644 --- a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -20,11 +20,7 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; -use crate::{ - collections::ttl::{Entry, TtlMap}, - filters::prelude::*, - net::endpoint::EndpointAddress, -}; +use crate::{collections::ttl::TtlMap, filters::prelude::*, net::endpoint::EndpointAddress}; use crate::generated::quilkin::filters::local_rate_limit::v1alpha1 as proto; @@ -50,8 +46,8 @@ const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); /// number of packet handling workers). #[derive(Debug)] struct Bucket { - counter: Arc, - window_start_time_secs: Arc, + counter: AtomicUsize, + window_start_time_secs: AtomicU64, } /// A filter that implements rate limiting on packets based on the token-bucket @@ -61,7 +57,7 @@ struct Bucket { /// flow through the filter untouched. pub struct LocalRateLimit { /// Tracks rate limiting state per source address. - state: TtlMap, + state: TtlMap>, /// Filter configuration. config: Config, } @@ -95,10 +91,10 @@ impl LocalRateLimit { } if let Some(bucket) = self.state.get(address) { - let prev_count = bucket.value.counter.fetch_add(1, Ordering::Relaxed); + let prev_count = bucket.counter.fetch_add(1, Ordering::Relaxed); let now_secs = self.state.now_relative_secs(); - let window_start_secs = bucket.value.window_start_time_secs.load(Ordering::Relaxed); + let window_start_secs = bucket.window_start_time_secs.load(Ordering::Relaxed); let elapsed_secs = now_secs - window_start_secs; let start_new_window = elapsed_secs > self.config.period as u64; @@ -115,9 +111,8 @@ impl LocalRateLimit { if start_new_window { // Current time window has ended, so we can reset the counter and // start a new time window instead. - bucket.value.counter.store(1, Ordering::Relaxed); + bucket.counter.store(1, Ordering::Relaxed); bucket - .value .window_start_time_secs .store(now_secs, Ordering::Relaxed); } @@ -125,21 +120,23 @@ impl LocalRateLimit { return true; } - match self.state.entry(address.clone()) { - Entry::Occupied(entry) => { + match self.state.get(address) { + Some(value) => { // It is possible that some other task has added the item since we // checked for it. If so, only increment the counter - no need to // update the window start time since the window has just started. - let bucket = entry.get(); - bucket.value.counter.fetch_add(1, Ordering::Relaxed); + value.counter.fetch_add(1, Ordering::Relaxed); } - Entry::Vacant(entry) => { + None => { // New entry, set both the time stamp and let now_secs = self.state.now_relative_secs(); - entry.insert(Bucket { - counter: Arc::new(AtomicUsize::new(1)), - window_start_time_secs: Arc::new(AtomicU64::new(now_secs)), - }); + self.state.insert( + address.clone(), + Arc::new(Bucket { + counter: AtomicUsize::new(1), + window_start_time_secs: AtomicU64::new(now_secs), + }), + ); } }; diff --git a/src/net/cluster.rs b/src/net/cluster.rs index 92fdc8602..4420f87b5 100644 --- a/src/net/cluster.rs +++ b/src/net/cluster.rs @@ -20,8 +20,8 @@ use std::{ sync::atomic::{AtomicU64, AtomicUsize, Ordering::Relaxed}, }; -use dashmap::DashMap; use once_cell::sync::Lazy; +use papaya::HashMap; use serde::{Deserialize, Serialize}; use crate::net::endpoint::{Endpoint, EndpointAddress, Locality}; @@ -259,15 +259,12 @@ impl EndpointSet { /// Represents a full snapshot of all clusters. pub struct ClusterMap { - map: DashMap, EndpointSet, S>, - token_map: DashMap>, + map: papaya::HashMap, EndpointSet, S>, + token_map: papaya::HashMap, S>, num_endpoints: AtomicUsize, version: AtomicU64, } -type DashMapRef<'inner> = dashmap::mapref::one::Ref<'inner, Option, EndpointSet>; -type DashMapRefMut<'inner> = dashmap::mapref::one::RefMut<'inner, Option, EndpointSet>; - impl ClusterMap { pub fn new() -> Self { Self::default() @@ -293,7 +290,7 @@ where { pub fn benchmarking(capacity: usize, hasher: S) -> Self { Self { - map: DashMap::with_capacity_and_hasher(capacity, hasher), + map: papaya::HashMap::with_capacity_and_hasher(capacity, hasher), ..Self::default() } } @@ -305,8 +302,8 @@ where pub fn apply(&self, locality: Option, cluster: EndpointSet) { let new_len = cluster.len(); - if let Some(mut current) = self.map.get_mut(&locality) { - let current = current.value_mut(); + if let Some(current) = self.map.pin().get(&locality) { + let mut current = current.clone(); let (old_len, token_map_diff) = current.replace(cluster); @@ -316,22 +313,24 @@ where self.num_endpoints.fetch_sub(old_len - new_len, Relaxed); } + self.map.pin().insert(locality, current); self.version.fetch_add(1, Relaxed); for (token_hash, addrs) in token_map_diff { if let Some(addrs) = addrs { - self.token_map.insert(token_hash, addrs); + self.token_map.pin().insert(token_hash, addrs); } else { - self.token_map.remove(&token_hash); + self.token_map.pin().remove(&token_hash); } } } else { for (token_hash, addrs) in &cluster.token_map { self.token_map + .pin() .insert(*token_hash, addrs.iter().cloned().collect()); } - self.map.insert(locality, cluster); + self.map.pin().insert(locality, cluster); self.num_endpoints.fetch_add(new_len, Relaxed); self.version.fetch_add(1, Relaxed); } @@ -347,20 +346,9 @@ where self.map.is_empty() } - pub fn get(&self, key: &Option) -> Option { - self.map.get(key) - } - - pub fn get_mut(&self, key: &Option) -> Option { - self.map.get_mut(key) - } - - pub fn get_default(&self) -> Option { - self.get(&None) - } - - pub fn get_default_mut(&self) -> Option { - self.get_mut(&None) + #[inline] + pub fn pin(&self) -> papaya::HashMapRef, EndpointSet, S, seize::LocalGuard> { + self.map.pin() } #[inline] @@ -370,11 +358,12 @@ where #[inline] pub fn remove_endpoint(&self, needle: &Endpoint) -> bool { - for mut entry in self.map.iter_mut() { - let set = entry.value_mut(); - - if set.endpoints.remove(needle) { - set.update(); + for (key, value) in self.map.pin().iter() { + if value.endpoints.contains(needle) { + let mut value = value.clone(); + value.endpoints.remove(needle); + value.update(); + self.map.pin().insert(key.clone(), value); self.num_endpoints.fetch_sub(1, Relaxed); self.version.fetch_add(1, Relaxed); return true; @@ -386,45 +375,33 @@ where #[inline] pub fn remove_endpoint_if(&self, closure: impl Fn(&Endpoint) -> bool) -> bool { - for mut entry in self.map.iter_mut() { - let set = entry.value_mut(); - if let Some(endpoint) = set + for (key, value) in self.map.pin().iter() { + if let Some(endpoint) = value .endpoints .iter() .find(|endpoint| (closure)(endpoint)) .cloned() { - // This will always be true, but.... - let removed = set.endpoints.remove(&endpoint); - if removed { - set.update(); - self.num_endpoints.fetch_sub(1, Relaxed); - self.version.fetch_add(1, Relaxed); - } - return removed; + let mut value = value.clone(); + value.endpoints.remove(&endpoint); + value.update(); + self.map.pin().insert(key.clone(), value); + self.num_endpoints.fetch_sub(1, Relaxed); + self.version.fetch_add(1, Relaxed); + return true; } } false } - #[inline] - pub fn iter(&self) -> dashmap::iter::Iter, EndpointSet, S> { - self.map.iter() - } - - pub fn entry( - &self, - key: Option, - ) -> dashmap::mapref::entry::Entry, EndpointSet> { - self.map.entry(key) - } - #[inline] pub fn replace(&self, locality: Option, endpoint: Endpoint) -> Option { - if let Some(mut set) = self.map.get_mut(&locality) { + if let Some(set) = self.map.pin().get(&locality) { + let mut set = set.clone(); let replaced = set.endpoints.replace(endpoint); set.update(); + self.map.pin().insert(locality, set); self.version.fetch_add(1, Relaxed); if replaced.is_none() { @@ -442,16 +419,16 @@ where pub fn endpoints(&self) -> Vec { let mut endpoints = Vec::with_capacity(self.num_of_endpoints()); - for set in self.map.iter() { - endpoints.extend(set.value().endpoints.iter().cloned()); + for (_, value) in self.map.pin().iter() { + endpoints.extend(value.endpoints.iter().cloned()); } endpoints } pub fn nth_endpoint(&self, mut index: usize) -> Option { - for set in self.iter() { - let set = &set.value().endpoints; + for (_, value) in self.map.pin().iter() { + let set = &value.endpoints; if index < set.len() { return set.iter().nth(index).cloned(); } else { @@ -465,8 +442,8 @@ where pub fn filter_endpoints(&self, f: impl Fn(&Endpoint) -> bool) -> Vec { let mut endpoints = Vec::new(); - for set in self.iter() { - for endpoint in set.endpoints.iter().filter(|e| (f)(e)) { + for (_, value) in self.map.pin().iter() { + for endpoint in value.endpoints.iter().filter(|e| (f)(e)) { endpoints.push(endpoint.clone()); } } @@ -486,23 +463,20 @@ where #[inline] pub fn update_unlocated_endpoints(&self, locality: Locality) { - if let Some((_, set)) = self.map.remove(&None) { + if let Some(set) = self.map.pin().remove(&None).cloned() { self.version.fetch_add(1, Relaxed); - if let Some(replaced) = self.map.insert(Some(locality), set) { + if let Some(replaced) = self.map.pin().insert(Some(locality), set) { self.num_endpoints.fetch_sub(replaced.len(), Relaxed); } } } #[inline] - pub fn remove_locality(&self, locality: &Option) -> Option { - let ret = self.map.remove(locality).map(|(_k, v)| v); - if let Some(ret) = &ret { + pub fn remove_locality(&self, locality: &Option) { + if let Some(ret) = self.map.pin().remove(locality) { self.version.fetch_add(1, Relaxed); self.num_endpoints.fetch_sub(ret.len(), Relaxed); } - - ret } pub fn addresses_for_token(&self, token: Token, addrs: &mut Vec) { @@ -546,7 +520,7 @@ where { fn default() -> Self { Self { - map: , EndpointSet, S>>::default(), + map: , EndpointSet, S>>::default(), token_map: Default::default(), version: <_>::default(), num_endpoints: <_>::default(), @@ -567,10 +541,12 @@ where S: Default + std::hash::BuildHasher + Clone, { fn eq(&self, rhs: &Self) -> bool { - for a in self.iter() { + for (key, value) in self.map.pin().iter() { match rhs - .get(a.key()) - .filter(|b| a.value().endpoints == b.endpoints) + .map + .pin() + .get(key) + .filter(|b| value.endpoints == b.endpoints) { Some(_) => {} None => return false, @@ -650,10 +626,9 @@ impl Serialize for ClusterMap { S: serde::Serializer, { self.map + .pin() .iter() - .map(|entry| { - EndpointWithLocality::from((entry.key().clone(), entry.value().endpoints.clone())) - }) + .map(|(key, value)| EndpointWithLocality::from((key.clone(), value.endpoints.clone()))) .collect::>() .serialize(ser) } @@ -664,7 +639,7 @@ where S: Default + std::hash::BuildHasher + Clone, { fn from(cmd: ClusterMapDeser) -> Self { - let map = DashMap::from_iter(cmd.endpoints.into_iter().map( + let map = HashMap::from_iter(cmd.endpoints.into_iter().map( |EndpointWithLocality { locality, endpoints, @@ -675,17 +650,19 @@ where } } -impl From, EndpointSet, S>> for ClusterMap +impl From, EndpointSet, S>> for ClusterMap where S: Default + std::hash::BuildHasher + Clone, { - fn from(map: DashMap, EndpointSet, S>) -> Self { - let num_endpoints = AtomicUsize::new(map.iter().map(|kv| kv.value().len()).sum()); + fn from(map: HashMap, EndpointSet, S>) -> Self { + let num_endpoints = AtomicUsize::new(map.pin().iter().map(|(_, value)| value.len()).sum()); - let token_map = DashMap::>::default(); - for es in &map { - for (token_hash, addrs) in &es.value().token_map { - token_map.insert(*token_hash, addrs.iter().cloned().collect()); + let token_map = HashMap::, S>::default(); + for value in map.pin().values() { + for (token_hash, addrs) in &value.token_map { + token_map + .pin() + .insert(*token_hash, addrs.iter().cloned().collect()); } } @@ -726,13 +703,15 @@ mod tests { cluster1.insert(Some(nl1.clone()), [endpoint.clone()].into()); cluster1.insert(Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); assert!(cluster1 + .pin() .get(&Some(nl1.clone())) .unwrap() .contains(&endpoint)); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(de1.clone())).unwrap().len(), 1); assert!(cluster1 + .pin() .get(&Some(de1.clone())) .unwrap() .contains(&endpoint)); @@ -741,16 +720,13 @@ mod tests { cluster1.insert(Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); - assert!(cluster1 - .get(&Some(de1.clone())) - .unwrap() - .contains(&endpoint)); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(de1.clone())).unwrap().len(), 1); + assert!(dbg!(cluster1.pin().get(&Some(de1.clone())).unwrap()).contains(&endpoint)); cluster1.insert(Some(de1.clone()), <_>::default()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert!(cluster1.get(&Some(de1.clone())).unwrap().is_empty()); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); + assert!(cluster1.pin().get(&Some(de1.clone())).unwrap().is_empty()); } } diff --git a/src/net/endpoint/address.rs b/src/net/endpoint/address.rs index 2948b0a24..51a0864a6 100644 --- a/src/net/endpoint/address.rs +++ b/src/net/endpoint/address.rs @@ -71,7 +71,7 @@ impl EndpointAddress { Lazy::new(<_>::default); match CACHE.get(name) { - Some(ip) => **ip, + Some(ip) => ip, None => { let handle = tokio::runtime::Handle::current(); let set = handle diff --git a/src/net/phoenix.rs b/src/net/phoenix.rs index 4e8bc63b6..f3a8a79ea 100644 --- a/src/net/phoenix.rs +++ b/src/net/phoenix.rs @@ -446,9 +446,10 @@ impl Phoenix { let crate::config::DatacenterConfig::NonAgent { datacenters } = &config.datacenter else { unreachable!("this shouldn't be called by an agent") }; - for entry in datacenters.write().iter() { - let addr = (*entry.key(), entry.value().qcmp_port).into(); - self.add_node_if_not_exists(addr, entry.value().icao_code); + + for (key, value) in datacenters.read().pin().iter() { + let addr = (*key, value.qcmp_port).into(); + self.add_node_if_not_exists(addr, value.icao_code); } } }