Skip to content

Commit

Permalink
fix: use slab instead of hashmap + atomic counter
Browse files Browse the repository at this point in the history
Allow managing rollover and uniqueness, as well as being faster to access.
  • Loading branch information
wyfo committed Nov 21, 2024
1 parent 3404e05 commit 08e13ba
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 85 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ serde = { version = "1.0.210", default-features = false, features = [
] } # Default features are disabled due to usage in no_std crates
serde_json = "1.0.128"
serde_yaml = "0.9.34"
slab = "0.4.9"
static_init = "1.0.3"
stabby = "36.1.1"
sha3 = "0.10.8"
Expand Down
3 changes: 1 addition & 2 deletions commons/zenoh-protocol/src/core/wire_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use alloc::{
borrow::Cow,
string::{String, ToString},
};
use core::{convert::TryInto, fmt, sync::atomic::AtomicU16};
use core::{convert::TryInto, fmt};

use zenoh_keyexpr::{keyexpr, OwnedKeyExpr};
use zenoh_result::{bail, ZResult};
Expand All @@ -28,7 +28,6 @@ use crate::network::Mapping;
pub type ExprId = u16;
pub type ExprLen = u16;

pub type AtomicExprId = AtomicU16;
pub const EMPTY_EXPR_ID: ExprId = 0;

/// A zenoh **resource** is represented by a pair composed by a **key** and a
Expand Down
2 changes: 1 addition & 1 deletion commons/zenoh-protocol/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use declare::{
pub use interest::Interest;
pub use oam::Oam;
pub use push::Push;
pub use request::{AtomicRequestId, Request, RequestId};
pub use request::{Request, RequestId};
pub use response::{Response, ResponseFinal};

use crate::core::{CongestionControl, Priority, Reliability};
Expand Down
3 changes: 0 additions & 3 deletions commons/zenoh-protocol/src/network/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
// Contributors:
// ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
use core::sync::atomic::AtomicU32;

use crate::{core::WireExpr, zenoh::RequestBody};

/// The resolution of a RequestId
pub type RequestId = u32;
pub type AtomicRequestId = AtomicU32;

pub mod flag {
pub const N: u8 = 1 << 5; // 0x20 Named if N==1 then the key expr has name/suffix
Expand Down
1 change: 1 addition & 0 deletions zenoh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ phf = { workspace = true }
rand = { workspace = true, features = ["default"] }
serde = { workspace = true, features = ["default"] }
serde_json = { workspace = true }
slab = { workspace = true }
socket2 = { workspace = true }
uhlc = { workspace = true, features = ["default"] }
vec_map = { workspace = true }
Expand Down
3 changes: 2 additions & 1 deletion zenoh/src/api/key_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,8 @@ impl Wait for KeyExprUndeclaration<'_> {
};
tracing::trace!("undeclare_keyexpr({:?})", expr_id);
let mut state = zwrite!(session.0.state);
state.local_resources.remove(&expr_id);
assert_ne!(expr_id, 0, "0 is not a valid keyexpr id");
state.local_resources.remove(expr_id as usize);

let primitives = state.primitives()?;
drop(state);
Expand Down
118 changes: 58 additions & 60 deletions zenoh/src/api/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH},
};

use slab::Slab;
use tracing::{error, info, trace, warn};
use uhlc::Timestamp;
#[cfg(feature = "internal")]
Expand All @@ -42,8 +43,7 @@ use zenoh_protocol::network::{
use zenoh_protocol::{
core::{
key_expr::{keyexpr, OwnedKeyExpr},
AtomicExprId, CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr,
EMPTY_EXPR_ID,
CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, EMPTY_EXPR_ID,
},
network::{
self,
Expand All @@ -53,8 +53,8 @@ use zenoh_protocol::{
UndeclareSubscriber,
},
interest::{InterestMode, InterestOptions},
push, request, AtomicRequestId, DeclareFinal, Interest, Mapping, Push, Request, RequestId,
Response, ResponseFinal,
push, request, DeclareFinal, Interest, Mapping, Push, Request, RequestId, Response,
ResponseFinal,
},
zenoh::{
query::{self, ext::QueryBodyType},
Expand Down Expand Up @@ -121,11 +121,7 @@ zconfigurable! {

pub(crate) struct SessionState {
pub(crate) primitives: Option<Arc<Face>>, // @TODO replace with MaybeUninit ??
pub(crate) expr_id_counter: AtomicExprId, // @TODO: manage rollover and uniqueness
pub(crate) qid_counter: AtomicRequestId,
#[cfg(feature = "unstable")]
pub(crate) liveliness_qid_counter: AtomicRequestId,
pub(crate) local_resources: HashMap<ExprId, Resource>,
pub(crate) local_resources: Slab<Resource>,
pub(crate) remote_resources: HashMap<ExprId, Resource>,
#[cfg(feature = "unstable")]
pub(crate) remote_subscribers: HashMap<SubscriberId, KeyExpr<'static>>,
Expand All @@ -140,9 +136,9 @@ pub(crate) struct SessionState {
pub(crate) tokens: HashMap<Id, Arc<LivelinessTokenState>>,
#[cfg(feature = "unstable")]
pub(crate) matching_listeners: HashMap<Id, Arc<MatchingListenerState>>,
pub(crate) queries: HashMap<RequestId, QueryState>,
pub(crate) queries: Slab<QueryState>,
#[cfg(feature = "unstable")]
pub(crate) liveliness_queries: HashMap<InterestId, LivelinessQueryState>,
pub(crate) liveliness_queries: Slab<LivelinessQueryState>,
pub(crate) aggregated_subscribers: Vec<OwnedKeyExpr>,
pub(crate) aggregated_publishers: Vec<OwnedKeyExpr>,
}
Expand All @@ -152,13 +148,12 @@ impl SessionState {
aggregated_subscribers: Vec<OwnedKeyExpr>,
aggregated_publishers: Vec<OwnedKeyExpr>,
) -> SessionState {
// Note: local_resources start at 1 because 0 is reserved for NO_RESOURCE
let mut local_resources = Slab::new();
local_resources.insert(Resource::Prefix { prefix: "".into() });
SessionState {
primitives: None,
expr_id_counter: AtomicExprId::new(1), // Note: start at 1 because 0 is reserved for NO_RESOURCE
qid_counter: AtomicRequestId::new(0),
#[cfg(feature = "unstable")]
liveliness_qid_counter: AtomicRequestId::new(0),
local_resources: HashMap::new(),
local_resources,
remote_resources: HashMap::new(),
#[cfg(feature = "unstable")]
remote_subscribers: HashMap::new(),
Expand All @@ -173,9 +168,9 @@ impl SessionState {
tokens: HashMap::new(),
#[cfg(feature = "unstable")]
matching_listeners: HashMap::new(),
queries: HashMap::new(),
queries: Slab::new(),
#[cfg(feature = "unstable")]
liveliness_queries: HashMap::new(),
liveliness_queries: Slab::new(),
aggregated_subscribers,
aggregated_publishers,
}
Expand All @@ -193,13 +188,13 @@ impl SessionState {

#[inline]
fn get_local_res(&self, id: &ExprId) -> Option<&Resource> {
self.local_resources.get(id)
self.local_resources.get(*id as usize).filter(|_| *id != 0)
}

#[inline]
fn get_remote_res(&self, id: &ExprId, mapping: Mapping) -> Option<&Resource> {
match mapping {
Mapping::Receiver => self.local_resources.get(id),
Mapping::Receiver => self.get_local_res(id),
Mapping::Sender => self.remote_resources.get(id),
}
}
Expand Down Expand Up @@ -1129,11 +1124,11 @@ impl SessionInner {
match state
.local_resources
.iter()
.find(|(_expr_id, res)| res.name() == prefix)
.skip(1) // skip NO_RESOURCE
.find(|(_, res)| res.name() == prefix)
{
Some((expr_id, _res)) => Ok(*expr_id),
Some((expr_id, _res)) => Ok(expr_id as ExprId),
None => {
let expr_id = state.expr_id_counter.fetch_add(1, Ordering::SeqCst);
let mut res = Resource::new(Box::from(prefix));
if let Resource::Node(res_node) = &mut res {
for kind in [
Expand All @@ -1147,7 +1142,10 @@ impl SessionInner {
}
}
}
state.local_resources.insert(expr_id, res);
if state.local_resources.vacant_key() > ExprId::MAX as usize {
bail!("too many keyexprs declared");
}
let expr_id = state.local_resources.insert(res) as ExprId;
drop(state);
primitives.send_declare(Declare {
interest_id: None,
Expand Down Expand Up @@ -1332,8 +1330,9 @@ impl SessionInner {
.insert(sub_state.id, sub_state.clone());
for res in state
.local_resources
.values_mut()
.filter_map(Resource::as_node_mut)
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
{
if key_expr.intersects(&res.key_expr) {
res.subscribers_mut(SubscriberKind::Subscriber)
Expand Down Expand Up @@ -1411,8 +1410,9 @@ impl SessionInner {
trace!("undeclare_subscriber({:?})", sub_state);
for res in state
.local_resources
.values_mut()
.filter_map(Resource::as_node_mut)
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
{
res.subscribers_mut(kind)
.retain(|sub| sub.id != sub_state.id);
Expand Down Expand Up @@ -1604,8 +1604,9 @@ impl SessionInner {

for res in state
.local_resources
.values_mut()
.filter_map(Resource::as_node_mut)
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
{
if key_expr.intersects(&res.key_expr) {
res.subscribers_mut(SubscriberKind::LivelinessSubscriber)
Expand Down Expand Up @@ -2065,12 +2066,22 @@ impl SessionInner {
ConsolidationMode::Auto => ConsolidationMode::Latest,
mode => mode,
};
let qid = state.qid_counter.fetch_add(1, Ordering::SeqCst);
let nb_final = match destination {
Locality::Any => 2,
_ => 1,
};

let wexpr = key_expr.to_wire(self).to_owned();
let qid = state.queries.insert(QueryState {
nb_final,
key_expr: key_expr.clone().into_owned(),
parameters: parameters.clone().into_owned(),
reception_mode: consolidation,
replies: (consolidation != ConsolidationMode::None).then(HashMap::new),
callback,
}) as RequestId;
tracing::trace!("Register query {} (nb_final = {})", qid, nb_final);

let token = self.task_controller.get_cancellation_token();
self.task_controller
.spawn_with_rt(zenoh_runtime::ZRuntime::Net, {
Expand All @@ -2081,7 +2092,7 @@ impl SessionInner {
tokio::select! {
_ = tokio::time::sleep(timeout) => {
let mut state = zwrite!(session.state);
if let Some(query) = state.queries.remove(&qid) {
if let Some(query) = state.queries.try_remove(qid as usize) {
std::mem::drop(state);
tracing::debug!("Timeout on query {}! Send error and close.", qid);
if query.reception_mode == ConsolidationMode::Latest {
Expand All @@ -2101,20 +2112,6 @@ impl SessionInner {
}
});

tracing::trace!("Register query {} (nb_final = {})", qid, nb_final);
let wexpr = key_expr.to_wire(self).to_owned();
state.queries.insert(
qid,
QueryState {
nb_final,
key_expr: key_expr.clone().into_owned(),
parameters: parameters.clone().into_owned(),
reception_mode: consolidation,
replies: (consolidation != ConsolidationMode::None).then(HashMap::new),
callback,
},
);

let primitives = state.primitives()?;
drop(state);

Expand Down Expand Up @@ -2176,7 +2173,13 @@ impl SessionInner {
) -> ZResult<()> {
tracing::trace!("liveliness.get({}, {:?})", key_expr, timeout);
let mut state = zwrite!(self.state);
let id = state.liveliness_qid_counter.fetch_add(1, Ordering::SeqCst);

let wexpr = key_expr.to_wire(self).to_owned();
let id = state
.liveliness_queries
.insert(LivelinessQueryState { callback }) as InterestId;
tracing::trace!("Register liveliness query {}", id);

let token = self.task_controller.get_cancellation_token();
self.task_controller
.spawn_with_rt(zenoh_runtime::ZRuntime::Net, {
Expand All @@ -2186,7 +2189,7 @@ impl SessionInner {
tokio::select! {
_ = tokio::time::sleep(timeout) => {
let mut state = zwrite!(session.state);
if let Some(query) = state.liveliness_queries.remove(&id) {
if let Some(query) = state.liveliness_queries.try_remove(id as usize) {
std::mem::drop(state);
tracing::debug!("Timeout on liveliness query {}! Send error and close.", id);
query.callback.call(Reply {
Expand All @@ -2201,12 +2204,6 @@ impl SessionInner {
}
});

tracing::trace!("Register liveliness query {}", id);
let wexpr = key_expr.to_wire(self).to_owned();
state
.liveliness_queries
.insert(id, LivelinessQueryState { callback });

let primitives = state.primitives()?;
drop(state);

Expand Down Expand Up @@ -2402,7 +2399,8 @@ impl Primitives for WeakSession {
{
Ok(key_expr) => {
if let Some(interest_id) = msg.interest_id {
if let Some(query) = state.liveliness_queries.get(&interest_id) {
if let Some(query) = state.liveliness_queries.get(interest_id as usize)
{
let reply = Reply {
result: Ok(Sample {
key_expr,
Expand Down Expand Up @@ -2516,7 +2514,7 @@ impl Primitives for WeakSession {
#[cfg(feature = "unstable")]
if let Some(interest_id) = msg.interest_id {
let mut state = zwrite!(self.state);
let _ = state.liveliness_queries.remove(&interest_id);
let _ = state.liveliness_queries.try_remove(interest_id as usize);
}
}
}
Expand Down Expand Up @@ -2592,7 +2590,7 @@ impl Primitives for WeakSession {
if state.primitives.is_none() {
return; // Session closing or closed
}
match state.queries.get_mut(&msg.rid) {
match state.queries.get_mut(msg.rid as usize) {
Some(query) => {
let callback = query.callback.clone();
std::mem::drop(state);
Expand Down Expand Up @@ -2623,7 +2621,7 @@ impl Primitives for WeakSession {
return;
}
};
match state.queries.get_mut(&msg.rid) {
match state.queries.get_mut(msg.rid as usize) {
Some(query) => {
let c =
zcondfeat!("unstable", !query.parameters.reply_key_expr_any(), true);
Expand Down Expand Up @@ -2794,11 +2792,11 @@ impl Primitives for WeakSession {
if state.primitives.is_none() {
return; // Session closing or closed
}
match state.queries.get_mut(&msg.rid) {
match state.queries.get_mut(msg.rid as usize) {
Some(query) => {
query.nb_final -= 1;
if query.nb_final == 0 {
let query = state.queries.remove(&msg.rid).unwrap();
let query = state.queries.try_remove(msg.rid as usize).unwrap();
std::mem::drop(state);
if query.reception_mode == ConsolidationMode::Latest {
for (_, reply) in query.replies.unwrap().into_iter() {
Expand Down
Loading

0 comments on commit 08e13ba

Please sign in to comment.