Skip to content

Commit

Permalink
fix: rollback session modification
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 22, 2024
1 parent 3290782 commit a2b1583
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 65 deletions.
3 changes: 2 additions & 1 deletion commons/zenoh-protocol/src/core/wire_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use alloc::{
borrow::Cow,
string::{String, ToString},
};
use core::{convert::TryInto, fmt};
use core::{convert::TryInto, fmt, sync::atomic::AtomicU16};

use zenoh_keyexpr::{keyexpr, OwnedKeyExpr};
use zenoh_result::{bail, ZResult};
Expand All @@ -28,6 +28,7 @@ use crate::network::Mapping;
pub type ExprId = u16;
pub type ExprLen = u16;

pub type AtomicExprId = AtomicU16;
pub const EMPTY_EXPR_ID: ExprId = 0;

/// A zenoh **resource** is represented by a pair composed by a **key** and a
Expand Down
2 changes: 1 addition & 1 deletion commons/zenoh-protocol/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use declare::{
pub use interest::Interest;
pub use oam::Oam;
pub use push::Push;
pub use request::{Request, RequestId};
pub use request::{AtomicRequestId, Request, RequestId};
pub use response::{Response, ResponseFinal};

use crate::core::{CongestionControl, Priority, Reliability};
Expand Down
3 changes: 3 additions & 0 deletions commons/zenoh-protocol/src/network/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
// Contributors:
// ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
use core::sync::atomic::AtomicU32;

use crate::{core::WireExpr, zenoh::RequestBody};

/// The resolution of a RequestId
pub type RequestId = u32;
pub type AtomicRequestId = AtomicU32;

pub mod flag {
pub const N: u8 = 1 << 5; // 0x20 Named if N==1 then the key expr has name/suffix
Expand Down
3 changes: 1 addition & 2 deletions zenoh/src/api/key_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,7 @@ impl Wait for KeyExprUndeclaration<'_> {
};
tracing::trace!("undeclare_keyexpr({:?})", expr_id);
let mut state = zwrite!(session.0.state);
assert_ne!(expr_id, 0, "0 is not a valid keyexpr id");
state.local_resources.remove(expr_id as usize);
state.local_resources.remove(&expr_id);

let primitives = state.primitives()?;
drop(state);
Expand Down
121 changes: 60 additions & 61 deletions zenoh/src/api/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH},
};

use slab::Slab;
use tracing::{error, info, trace, warn};
use uhlc::Timestamp;
#[cfg(feature = "internal")]
Expand All @@ -43,7 +42,8 @@ use zenoh_protocol::network::{
use zenoh_protocol::{
core::{
key_expr::{keyexpr, OwnedKeyExpr},
CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, EMPTY_EXPR_ID,
AtomicExprId, CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr,
EMPTY_EXPR_ID,
},
network::{
self,
Expand All @@ -53,8 +53,8 @@ use zenoh_protocol::{
UndeclareSubscriber,
},
interest::{InterestMode, InterestOptions},
push, request, DeclareFinal, Interest, Mapping, Push, Request, RequestId, Response,
ResponseFinal,
push, request, AtomicRequestId, DeclareFinal, Interest, Mapping, Push, Request, RequestId,
Response, ResponseFinal,
},
zenoh::{
query::{self, ext::QueryBodyType},
Expand Down Expand Up @@ -121,7 +121,11 @@ zconfigurable! {

pub(crate) struct SessionState {
pub(crate) primitives: Option<Arc<Face>>, // @TODO replace with MaybeUninit ??
pub(crate) local_resources: Slab<Resource>,
pub(crate) expr_id_counter: AtomicExprId, // @TODO: manage rollover and uniqueness
pub(crate) qid_counter: AtomicRequestId,
#[cfg(feature = "unstable")]
pub(crate) liveliness_qid_counter: AtomicRequestId,
pub(crate) local_resources: HashMap<ExprId, Resource>,
pub(crate) remote_resources: HashMap<ExprId, Resource>,
#[cfg(feature = "unstable")]
pub(crate) remote_subscribers: HashMap<SubscriberId, KeyExpr<'static>>,
Expand All @@ -136,9 +140,9 @@ pub(crate) struct SessionState {
pub(crate) tokens: HashMap<Id, Arc<LivelinessTokenState>>,
#[cfg(feature = "unstable")]
pub(crate) matching_listeners: HashMap<Id, Arc<MatchingListenerState>>,
pub(crate) queries: Slab<QueryState>,
pub(crate) queries: HashMap<RequestId, QueryState>,
#[cfg(feature = "unstable")]
pub(crate) liveliness_queries: Slab<LivelinessQueryState>,
pub(crate) liveliness_queries: HashMap<InterestId, LivelinessQueryState>,
pub(crate) aggregated_subscribers: Vec<OwnedKeyExpr>,
pub(crate) aggregated_publishers: Vec<OwnedKeyExpr>,
}
Expand All @@ -148,12 +152,13 @@ impl SessionState {
aggregated_subscribers: Vec<OwnedKeyExpr>,
aggregated_publishers: Vec<OwnedKeyExpr>,
) -> SessionState {
// Note: local_resources start at 1 because 0 is reserved for NO_RESOURCE
let mut local_resources = Slab::new();
local_resources.insert(Resource::Prefix { prefix: "".into() });
SessionState {
primitives: None,
local_resources,
expr_id_counter: AtomicExprId::new(1), // Note: start at 1 because 0 is reserved for NO_RESOURCE
qid_counter: AtomicRequestId::new(0),
#[cfg(feature = "unstable")]
liveliness_qid_counter: AtomicRequestId::new(0),
local_resources: HashMap::new(),
remote_resources: HashMap::new(),
#[cfg(feature = "unstable")]
remote_subscribers: HashMap::new(),
Expand All @@ -168,9 +173,9 @@ impl SessionState {
tokens: HashMap::new(),
#[cfg(feature = "unstable")]
matching_listeners: HashMap::new(),
queries: Slab::new(),
queries: HashMap::new(),
#[cfg(feature = "unstable")]
liveliness_queries: Slab::new(),
liveliness_queries: HashMap::new(),
aggregated_subscribers,
aggregated_publishers,
}
Expand All @@ -188,16 +193,13 @@ impl SessionState {

#[inline]
fn get_local_res(&self, id: &ExprId) -> Option<&Resource> {
if *id == 0 {
return None;
}
self.local_resources.get(*id as usize)
self.local_resources.get(id)
}

#[inline]
fn get_remote_res(&self, id: &ExprId, mapping: Mapping) -> Option<&Resource> {
match mapping {
Mapping::Receiver => self.get_local_res(id),
Mapping::Receiver => self.local_resources.get(id),
Mapping::Sender => self.remote_resources.get(id),
}
}
Expand Down Expand Up @@ -1127,11 +1129,11 @@ impl SessionInner {
match state
.local_resources
.iter()
.skip(1) // skip NO_RESOURCE
.find(|(_, res)| res.name() == prefix)
.find(|(_expr_id, res)| res.name() == prefix)
{
Some((expr_id, _res)) => Ok(expr_id as ExprId),
Some((expr_id, _res)) => Ok(*expr_id),
None => {
let expr_id = state.expr_id_counter.fetch_add(1, Ordering::SeqCst);
let mut res = Resource::new(Box::from(prefix));
if let Resource::Node(res_node) = &mut res {
for kind in [
Expand All @@ -1145,10 +1147,7 @@ impl SessionInner {
}
}
}
if state.local_resources.vacant_key() > ExprId::MAX as usize {
bail!("too many keyexprs declared");
}
let expr_id = state.local_resources.insert(res) as ExprId;
state.local_resources.insert(expr_id, res);
drop(state);
primitives.send_declare(Declare {
interest_id: None,
Expand Down Expand Up @@ -1333,9 +1332,8 @@ impl SessionInner {
.insert(sub_state.id, sub_state.clone());
for res in state
.local_resources
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
.values_mut()
.filter_map(Resource::as_node_mut)
{
if key_expr.intersects(&res.key_expr) {
res.subscribers_mut(SubscriberKind::Subscriber)
Expand Down Expand Up @@ -1413,9 +1411,8 @@ impl SessionInner {
trace!("undeclare_subscriber({:?})", sub_state);
for res in state
.local_resources
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
.values_mut()
.filter_map(Resource::as_node_mut)
{
res.subscribers_mut(kind)
.retain(|sub| sub.id != sub_state.id);
Expand Down Expand Up @@ -1607,9 +1604,8 @@ impl SessionInner {

for res in state
.local_resources
.iter_mut()
.skip(1) // skip NO_RESOURCE
.filter_map(|(_, res)|res.as_node_mut())
.values_mut()
.filter_map(Resource::as_node_mut)
{
if key_expr.intersects(&res.key_expr) {
res.subscribers_mut(SubscriberKind::LivelinessSubscriber)
Expand Down Expand Up @@ -2069,22 +2065,12 @@ impl SessionInner {
ConsolidationMode::Auto => ConsolidationMode::Latest,
mode => mode,
};
let qid = state.qid_counter.fetch_add(1, Ordering::SeqCst);
let nb_final = match destination {
Locality::Any => 2,
_ => 1,
};

let wexpr = key_expr.to_wire(self).to_owned();
let qid = state.queries.insert(QueryState {
nb_final,
key_expr: key_expr.clone().into_owned(),
parameters: parameters.clone().into_owned(),
reception_mode: consolidation,
replies: (consolidation != ConsolidationMode::None).then(HashMap::new),
callback,
}) as RequestId;
tracing::trace!("Register query {} (nb_final = {})", qid, nb_final);

let token = self.task_controller.get_cancellation_token();
self.task_controller
.spawn_with_rt(zenoh_runtime::ZRuntime::Net, {
Expand All @@ -2095,7 +2081,7 @@ impl SessionInner {
tokio::select! {
_ = tokio::time::sleep(timeout) => {
let mut state = zwrite!(session.state);
if let Some(query) = state.queries.try_remove(qid as usize) {
if let Some(query) = state.queries.remove(&qid) {
std::mem::drop(state);
tracing::debug!("Timeout on query {}! Send error and close.", qid);
if query.reception_mode == ConsolidationMode::Latest {
Expand All @@ -2115,6 +2101,20 @@ impl SessionInner {
}
});

tracing::trace!("Register query {} (nb_final = {})", qid, nb_final);
let wexpr = key_expr.to_wire(self).to_owned();
state.queries.insert(
qid,
QueryState {
nb_final,
key_expr: key_expr.clone().into_owned(),
parameters: parameters.clone().into_owned(),
reception_mode: consolidation,
replies: (consolidation != ConsolidationMode::None).then(HashMap::new),
callback,
},
);

let primitives = state.primitives()?;
drop(state);

Expand Down Expand Up @@ -2176,13 +2176,7 @@ impl SessionInner {
) -> ZResult<()> {
tracing::trace!("liveliness.get({}, {:?})", key_expr, timeout);
let mut state = zwrite!(self.state);

let wexpr = key_expr.to_wire(self).to_owned();
let id = state
.liveliness_queries
.insert(LivelinessQueryState { callback }) as InterestId;
tracing::trace!("Register liveliness query {}", id);

let id = state.liveliness_qid_counter.fetch_add(1, Ordering::SeqCst);
let token = self.task_controller.get_cancellation_token();
self.task_controller
.spawn_with_rt(zenoh_runtime::ZRuntime::Net, {
Expand All @@ -2192,7 +2186,7 @@ impl SessionInner {
tokio::select! {
_ = tokio::time::sleep(timeout) => {
let mut state = zwrite!(session.state);
if let Some(query) = state.liveliness_queries.try_remove(id as usize) {
if let Some(query) = state.liveliness_queries.remove(&id) {
std::mem::drop(state);
tracing::debug!("Timeout on liveliness query {}! Send error and close.", id);
query.callback.call(Reply {
Expand All @@ -2207,6 +2201,12 @@ impl SessionInner {
}
});

tracing::trace!("Register liveliness query {}", id);
let wexpr = key_expr.to_wire(self).to_owned();
state
.liveliness_queries
.insert(id, LivelinessQueryState { callback });

let primitives = state.primitives()?;
drop(state);

Expand Down Expand Up @@ -2402,8 +2402,7 @@ impl Primitives for WeakSession {
{
Ok(key_expr) => {
if let Some(interest_id) = msg.interest_id {
if let Some(query) = state.liveliness_queries.get(interest_id as usize)
{
if let Some(query) = state.liveliness_queries.get(&interest_id) {
let reply = Reply {
result: Ok(Sample {
key_expr,
Expand Down Expand Up @@ -2517,7 +2516,7 @@ impl Primitives for WeakSession {
#[cfg(feature = "unstable")]
if let Some(interest_id) = msg.interest_id {
let mut state = zwrite!(self.state);
let _ = state.liveliness_queries.try_remove(interest_id as usize);
let _ = state.liveliness_queries.remove(&interest_id);
}
}
}
Expand Down Expand Up @@ -2593,7 +2592,7 @@ impl Primitives for WeakSession {
if state.primitives.is_none() {
return; // Session closing or closed
}
match state.queries.get_mut(msg.rid as usize) {
match state.queries.get_mut(&msg.rid) {
Some(query) => {
let callback = query.callback.clone();
std::mem::drop(state);
Expand Down Expand Up @@ -2624,7 +2623,7 @@ impl Primitives for WeakSession {
return;
}
};
match state.queries.get_mut(msg.rid as usize) {
match state.queries.get_mut(&msg.rid) {
Some(query) => {
let c =
zcondfeat!("unstable", !query.parameters.reply_key_expr_any(), true);
Expand Down Expand Up @@ -2795,11 +2794,11 @@ impl Primitives for WeakSession {
if state.primitives.is_none() {
return; // Session closing or closed
}
match state.queries.get_mut(msg.rid as usize) {
match state.queries.get_mut(&msg.rid) {
Some(query) => {
query.nb_final -= 1;
if query.nb_final == 0 {
let query = state.queries.try_remove(msg.rid as usize).unwrap();
let query = state.queries.remove(&msg.rid).unwrap();
std::mem::drop(state);
if query.reception_mode == ConsolidationMode::Latest {
for (_, reply) in query.replies.unwrap().into_iter() {
Expand Down

0 comments on commit a2b1583

Please sign in to comment.