diff --git a/Cargo.lock b/Cargo.lock index 317af86ff..7a00e0de1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5467,6 +5467,7 @@ dependencies = [ "rustc_version 0.4.1", "serde", "serde_json", + "slab", "socket2 0.5.7", "tokio", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index b5853f489..c33219f46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,6 +158,7 @@ serde = { version = "1.0.210", default-features = false, features = [ ] } # Default features are disabled due to usage in no_std crates serde_json = "1.0.128" serde_yaml = "0.9.34" +slab = "0.4.9" static_init = "1.0.3" stabby = "36.1.1" sha3 = "0.10.8" diff --git a/commons/zenoh-protocol/src/core/wire_expr.rs b/commons/zenoh-protocol/src/core/wire_expr.rs index 7a70f1006..3681863ca 100644 --- a/commons/zenoh-protocol/src/core/wire_expr.rs +++ b/commons/zenoh-protocol/src/core/wire_expr.rs @@ -17,7 +17,7 @@ use alloc::{ borrow::Cow, string::{String, ToString}, }; -use core::{convert::TryInto, fmt, sync::atomic::AtomicU16}; +use core::{convert::TryInto, fmt}; use zenoh_keyexpr::{keyexpr, OwnedKeyExpr}; use zenoh_result::{bail, ZResult}; @@ -28,7 +28,6 @@ use crate::network::Mapping; pub type ExprId = u16; pub type ExprLen = u16; -pub type AtomicExprId = AtomicU16; pub const EMPTY_EXPR_ID: ExprId = 0; /// A zenoh **resource** is represented by a pair composed by a **key** and a diff --git a/commons/zenoh-protocol/src/network/mod.rs b/commons/zenoh-protocol/src/network/mod.rs index ed23b0337..4e38ceda2 100644 --- a/commons/zenoh-protocol/src/network/mod.rs +++ b/commons/zenoh-protocol/src/network/mod.rs @@ -27,7 +27,7 @@ pub use declare::{ pub use interest::Interest; pub use oam::Oam; pub use push::Push; -pub use request::{AtomicRequestId, Request, RequestId}; +pub use request::{Request, RequestId}; pub use response::{Response, ResponseFinal}; use crate::core::{CongestionControl, Priority, Reliability}; diff --git a/commons/zenoh-protocol/src/network/request.rs b/commons/zenoh-protocol/src/network/request.rs index 3fd9eb221..019e68095 100644 --- a/commons/zenoh-protocol/src/network/request.rs +++ b/commons/zenoh-protocol/src/network/request.rs @@ -11,13 +11,10 @@ // Contributors: // ZettaScale Zenoh Team, // -use core::sync::atomic::AtomicU32; - use crate::{core::WireExpr, zenoh::RequestBody}; /// The resolution of a RequestId pub type RequestId = u32; -pub type AtomicRequestId = AtomicU32; pub mod flag { pub const N: u8 = 1 << 5; // 0x20 Named if N==1 then the key expr has name/suffix diff --git a/zenoh/Cargo.toml b/zenoh/Cargo.toml index 94f6d6eb4..1bc323095 100644 --- a/zenoh/Cargo.toml +++ b/zenoh/Cargo.toml @@ -86,6 +86,7 @@ phf = { workspace = true } rand = { workspace = true, features = ["default"] } serde = { workspace = true, features = ["default"] } serde_json = { workspace = true } +slab = { workspace = true } socket2 = { workspace = true } uhlc = { workspace = true, features = ["default"] } vec_map = { workspace = true } diff --git a/zenoh/src/api/key_expr.rs b/zenoh/src/api/key_expr.rs index 2a3c775bf..86e188022 100644 --- a/zenoh/src/api/key_expr.rs +++ b/zenoh/src/api/key_expr.rs @@ -618,7 +618,8 @@ impl Wait for KeyExprUndeclaration<'_> { }; tracing::trace!("undeclare_keyexpr({:?})", expr_id); let mut state = zwrite!(session.0.state); - state.local_resources.remove(&expr_id); + assert_ne!(expr_id, 0, "0 is not a valid keyexpr id"); + state.local_resources.remove(expr_id as usize); let primitives = state.primitives()?; drop(state); diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index 0c01bffdb..423ac7d2c 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -25,6 +25,7 @@ use std::{ time::{Duration, SystemTime, UNIX_EPOCH}, }; +use slab::Slab; use tracing::{error, info, trace, warn}; use uhlc::Timestamp; #[cfg(feature = "internal")] @@ -42,8 +43,7 @@ use zenoh_protocol::network::{ use zenoh_protocol::{ core::{ key_expr::{keyexpr, OwnedKeyExpr}, - AtomicExprId, CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, - EMPTY_EXPR_ID, + CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, EMPTY_EXPR_ID, }, network::{ self, @@ -53,8 +53,8 @@ use zenoh_protocol::{ UndeclareSubscriber, }, interest::{InterestMode, InterestOptions}, - push, request, AtomicRequestId, DeclareFinal, Interest, Mapping, Push, Request, RequestId, - Response, ResponseFinal, + push, request, DeclareFinal, Interest, Mapping, Push, Request, RequestId, Response, + ResponseFinal, }, zenoh::{ query::{self, ext::QueryBodyType}, @@ -121,11 +121,7 @@ zconfigurable! { pub(crate) struct SessionState { pub(crate) primitives: Option>, // @TODO replace with MaybeUninit ?? - pub(crate) expr_id_counter: AtomicExprId, // @TODO: manage rollover and uniqueness - pub(crate) qid_counter: AtomicRequestId, - #[cfg(feature = "unstable")] - pub(crate) liveliness_qid_counter: AtomicRequestId, - pub(crate) local_resources: HashMap, + pub(crate) local_resources: Slab, pub(crate) remote_resources: HashMap, #[cfg(feature = "unstable")] pub(crate) remote_subscribers: HashMap>, @@ -140,9 +136,9 @@ pub(crate) struct SessionState { pub(crate) tokens: HashMap>, #[cfg(feature = "unstable")] pub(crate) matching_listeners: HashMap>, - pub(crate) queries: HashMap, + pub(crate) queries: Slab, #[cfg(feature = "unstable")] - pub(crate) liveliness_queries: HashMap, + pub(crate) liveliness_queries: Slab, pub(crate) aggregated_subscribers: Vec, pub(crate) aggregated_publishers: Vec, } @@ -152,13 +148,12 @@ impl SessionState { aggregated_subscribers: Vec, aggregated_publishers: Vec, ) -> SessionState { + // Note: local_resources start at 1 because 0 is reserved for NO_RESOURCE + let mut local_resources = Slab::new(); + local_resources.insert(Resource::Prefix { prefix: "".into() }); SessionState { primitives: None, - expr_id_counter: AtomicExprId::new(1), // Note: start at 1 because 0 is reserved for NO_RESOURCE - qid_counter: AtomicRequestId::new(0), - #[cfg(feature = "unstable")] - liveliness_qid_counter: AtomicRequestId::new(0), - local_resources: HashMap::new(), + local_resources, remote_resources: HashMap::new(), #[cfg(feature = "unstable")] remote_subscribers: HashMap::new(), @@ -173,9 +168,9 @@ impl SessionState { tokens: HashMap::new(), #[cfg(feature = "unstable")] matching_listeners: HashMap::new(), - queries: HashMap::new(), + queries: Slab::new(), #[cfg(feature = "unstable")] - liveliness_queries: HashMap::new(), + liveliness_queries: Slab::new(), aggregated_subscribers, aggregated_publishers, } @@ -193,13 +188,13 @@ impl SessionState { #[inline] fn get_local_res(&self, id: &ExprId) -> Option<&Resource> { - self.local_resources.get(id) + self.local_resources.get(*id as usize).filter(|_| *id != 0) } #[inline] fn get_remote_res(&self, id: &ExprId, mapping: Mapping) -> Option<&Resource> { match mapping { - Mapping::Receiver => self.local_resources.get(id), + Mapping::Receiver => self.get_local_res(id), Mapping::Sender => self.remote_resources.get(id), } } @@ -1129,11 +1124,11 @@ impl SessionInner { match state .local_resources .iter() - .find(|(_expr_id, res)| res.name() == prefix) + .skip(1) // skip NO_RESOURCE + .find(|(_, res)| res.name() == prefix) { - Some((expr_id, _res)) => Ok(*expr_id), + Some((expr_id, _res)) => Ok(expr_id as ExprId), None => { - let expr_id = state.expr_id_counter.fetch_add(1, Ordering::SeqCst); let mut res = Resource::new(Box::from(prefix)); if let Resource::Node(res_node) = &mut res { for kind in [ @@ -1147,7 +1142,10 @@ impl SessionInner { } } } - state.local_resources.insert(expr_id, res); + if state.local_resources.vacant_key() > ExprId::MAX as usize { + bail!("too many keyexprs declared"); + } + let expr_id = state.local_resources.insert(res) as ExprId; drop(state); primitives.send_declare(Declare { interest_id: None, @@ -1332,8 +1330,9 @@ impl SessionInner { .insert(sub_state.id, sub_state.clone()); for res in state .local_resources - .values_mut() - .filter_map(Resource::as_node_mut) + .iter_mut() + .skip(1) // skip NO_RESOURCE + .filter_map(|(_, res)|res.as_node_mut()) { if key_expr.intersects(&res.key_expr) { res.subscribers_mut(SubscriberKind::Subscriber) @@ -1411,8 +1410,9 @@ impl SessionInner { trace!("undeclare_subscriber({:?})", sub_state); for res in state .local_resources - .values_mut() - .filter_map(Resource::as_node_mut) + .iter_mut() + .skip(1) // skip NO_RESOURCE + .filter_map(|(_, res)|res.as_node_mut()) { res.subscribers_mut(kind) .retain(|sub| sub.id != sub_state.id); @@ -1604,8 +1604,9 @@ impl SessionInner { for res in state .local_resources - .values_mut() - .filter_map(Resource::as_node_mut) + .iter_mut() + .skip(1) // skip NO_RESOURCE + .filter_map(|(_, res)|res.as_node_mut()) { if key_expr.intersects(&res.key_expr) { res.subscribers_mut(SubscriberKind::LivelinessSubscriber) @@ -2065,12 +2066,22 @@ impl SessionInner { ConsolidationMode::Auto => ConsolidationMode::Latest, mode => mode, }; - let qid = state.qid_counter.fetch_add(1, Ordering::SeqCst); let nb_final = match destination { Locality::Any => 2, _ => 1, }; + let wexpr = key_expr.to_wire(self).to_owned(); + let qid = state.queries.insert(QueryState { + nb_final, + key_expr: key_expr.clone().into_owned(), + parameters: parameters.clone().into_owned(), + reception_mode: consolidation, + replies: (consolidation != ConsolidationMode::None).then(HashMap::new), + callback, + }) as RequestId; + tracing::trace!("Register query {} (nb_final = {})", qid, nb_final); + let token = self.task_controller.get_cancellation_token(); self.task_controller .spawn_with_rt(zenoh_runtime::ZRuntime::Net, { @@ -2081,7 +2092,7 @@ impl SessionInner { tokio::select! { _ = tokio::time::sleep(timeout) => { let mut state = zwrite!(session.state); - if let Some(query) = state.queries.remove(&qid) { + if let Some(query) = state.queries.try_remove(qid as usize) { std::mem::drop(state); tracing::debug!("Timeout on query {}! Send error and close.", qid); if query.reception_mode == ConsolidationMode::Latest { @@ -2101,20 +2112,6 @@ impl SessionInner { } }); - tracing::trace!("Register query {} (nb_final = {})", qid, nb_final); - let wexpr = key_expr.to_wire(self).to_owned(); - state.queries.insert( - qid, - QueryState { - nb_final, - key_expr: key_expr.clone().into_owned(), - parameters: parameters.clone().into_owned(), - reception_mode: consolidation, - replies: (consolidation != ConsolidationMode::None).then(HashMap::new), - callback, - }, - ); - let primitives = state.primitives()?; drop(state); @@ -2176,7 +2173,13 @@ impl SessionInner { ) -> ZResult<()> { tracing::trace!("liveliness.get({}, {:?})", key_expr, timeout); let mut state = zwrite!(self.state); - let id = state.liveliness_qid_counter.fetch_add(1, Ordering::SeqCst); + + let wexpr = key_expr.to_wire(self).to_owned(); + let id = state + .liveliness_queries + .insert(LivelinessQueryState { callback }) as InterestId; + tracing::trace!("Register liveliness query {}", id); + let token = self.task_controller.get_cancellation_token(); self.task_controller .spawn_with_rt(zenoh_runtime::ZRuntime::Net, { @@ -2186,7 +2189,7 @@ impl SessionInner { tokio::select! { _ = tokio::time::sleep(timeout) => { let mut state = zwrite!(session.state); - if let Some(query) = state.liveliness_queries.remove(&id) { + if let Some(query) = state.liveliness_queries.try_remove(id as usize) { std::mem::drop(state); tracing::debug!("Timeout on liveliness query {}! Send error and close.", id); query.callback.call(Reply { @@ -2201,12 +2204,6 @@ impl SessionInner { } }); - tracing::trace!("Register liveliness query {}", id); - let wexpr = key_expr.to_wire(self).to_owned(); - state - .liveliness_queries - .insert(id, LivelinessQueryState { callback }); - let primitives = state.primitives()?; drop(state); @@ -2402,7 +2399,8 @@ impl Primitives for WeakSession { { Ok(key_expr) => { if let Some(interest_id) = msg.interest_id { - if let Some(query) = state.liveliness_queries.get(&interest_id) { + if let Some(query) = state.liveliness_queries.get(interest_id as usize) + { let reply = Reply { result: Ok(Sample { key_expr, @@ -2516,7 +2514,7 @@ impl Primitives for WeakSession { #[cfg(feature = "unstable")] if let Some(interest_id) = msg.interest_id { let mut state = zwrite!(self.state); - let _ = state.liveliness_queries.remove(&interest_id); + let _ = state.liveliness_queries.try_remove(interest_id as usize); } } } @@ -2592,7 +2590,7 @@ impl Primitives for WeakSession { if state.primitives.is_none() { return; // Session closing or closed } - match state.queries.get_mut(&msg.rid) { + match state.queries.get_mut(msg.rid as usize) { Some(query) => { let callback = query.callback.clone(); std::mem::drop(state); @@ -2623,7 +2621,7 @@ impl Primitives for WeakSession { return; } }; - match state.queries.get_mut(&msg.rid) { + match state.queries.get_mut(msg.rid as usize) { Some(query) => { let c = zcondfeat!("unstable", !query.parameters.reply_key_expr_any(), true); @@ -2794,11 +2792,11 @@ impl Primitives for WeakSession { if state.primitives.is_none() { return; // Session closing or closed } - match state.queries.get_mut(&msg.rid) { + match state.queries.get_mut(msg.rid as usize) { Some(query) => { query.nb_final -= 1; if query.nb_final == 0 { - let query = state.queries.remove(&msg.rid).unwrap(); + let query = state.queries.try_remove(msg.rid as usize).unwrap(); std::mem::drop(state); if query.reception_mode == ConsolidationMode::Latest { for (_, reply) in query.replies.unwrap().into_iter() { diff --git a/zenoh/src/net/routing/dispatcher/face.rs b/zenoh/src/net/routing/dispatcher/face.rs index 6e1db6bbf..0b3078a80 100644 --- a/zenoh/src/net/routing/dispatcher/face.rs +++ b/zenoh/src/net/routing/dispatcher/face.rs @@ -19,12 +19,13 @@ use std::{ time::Duration, }; +use slab::Slab; use tokio_util::sync::CancellationToken; use zenoh_protocol::{ core::{ExprId, Reliability, WhatAmI, ZenohIdProto}, network::{ interest::{InterestId, InterestMode, InterestOptions}, - Mapping, Push, Request, RequestId, Response, ResponseFinal, + Mapping, Push, Request, Response, ResponseFinal, }, zenoh::RequestBody, }; @@ -70,8 +71,7 @@ pub struct FaceState { HashMap, CancellationToken)>, pub(crate) local_mappings: HashMap>, pub(crate) remote_mappings: HashMap>, - pub(crate) next_qid: RequestId, - pub(crate) pending_queries: HashMap, CancellationToken)>, + pub(crate) pending_queries: Slab<(Arc, CancellationToken)>, pub(crate) mcast_group: Option, pub(crate) in_interceptors: Option>, pub(crate) hat: Box, @@ -102,8 +102,7 @@ impl FaceState { pending_current_interests: HashMap::new(), local_mappings: HashMap::new(), remote_mappings: HashMap::new(), - next_qid: 0, - pending_queries: HashMap::new(), + pending_queries: Slab::new(), mcast_group, in_interceptors, hat, diff --git a/zenoh/src/net/routing/dispatcher/queries.rs b/zenoh/src/net/routing/dispatcher/queries.rs index f8a9f1f12..e62a912c4 100644 --- a/zenoh/src/net/routing/dispatcher/queries.rs +++ b/zenoh/src/net/routing/dispatcher/queries.rs @@ -290,13 +290,11 @@ pub(crate) fn update_matches_query_routes(tables: &Tables, res: &Arc) #[inline] fn insert_pending_query(outface: &mut Arc, query: Arc) -> RequestId { let outface_mut = get_mut_unchecked(outface); - outface_mut.next_qid += 1; - let qid = outface_mut.next_qid; - outface_mut.pending_queries.insert( - qid, - (query, outface_mut.task_controller.get_cancellation_token()), - ); - qid + outface_mut + .pending_queries + .insert((query, outface_mut.task_controller.get_cancellation_token())) + .try_into() + .expect("too many pending queries") } #[inline] @@ -381,7 +379,7 @@ impl QueryCleanup { qid, timeout, }; - if let Some((_, cancellation_token)) = face.pending_queries.get(&qid) { + if let Some((_, cancellation_token)) = face.pending_queries.get(qid as usize) { let c_cancellation_token = cancellation_token.clone(); face.task_controller .spawn_with_rt(zenoh_runtime::ZRuntime::Net, async move { @@ -422,7 +420,7 @@ impl Timed for QueryCleanup { let queries_lock = zwrite!(self.tables.queries_lock); if let Some(query) = get_mut_unchecked(&mut face) .pending_queries - .remove(&self.qid) + .try_remove(self.qid as usize) { drop(queries_lock); tracing::warn!( @@ -682,7 +680,7 @@ pub(crate) fn route_send_response( inc_res_stats!(face, rx, admin, body) } - match face.pending_queries.get(&qid) { + match face.pending_queries.get(qid as usize) { Some((query, _)) => { drop(queries_lock); @@ -717,7 +715,10 @@ pub(crate) fn route_send_response_final( qid: RequestId, ) { let queries_lock = zwrite!(tables_ref.queries_lock); - match get_mut_unchecked(face).pending_queries.remove(&qid) { + match get_mut_unchecked(face) + .pending_queries + .try_remove(qid as usize) + { Some(query) => { drop(queries_lock); tracing::debug!( @@ -735,7 +736,7 @@ pub(crate) fn route_send_response_final( pub(crate) fn finalize_pending_queries(tables_ref: &TablesLock, face: &mut Arc) { let queries_lock = zwrite!(tables_ref.queries_lock); - for (_, query) in get_mut_unchecked(face).pending_queries.drain() { + for query in get_mut_unchecked(face).pending_queries.drain() { finalize_pending_query(query); } drop(queries_lock); diff --git a/zenoh/src/net/runtime/mod.rs b/zenoh/src/net/runtime/mod.rs index 301698eea..225b7e64a 100644 --- a/zenoh/src/net/runtime/mod.rs +++ b/zenoh/src/net/runtime/mod.rs @@ -70,7 +70,7 @@ use crate::{ pub(crate) struct RuntimeState { zid: ZenohId, whatami: WhatAmI, - next_id: AtomicU32, + next_id: AtomicU32, // @TODO: manage rollover and uniqueness router: Arc, config: Notifier, manager: TransportManager,