Skip to content

Commit

Permalink
Merge pull request #13 from dbcfd/sync-lock
Browse files Browse the repository at this point in the history
Switch to async synchronization primitives
  • Loading branch information
mnetship authored Sep 6, 2020
2 parents 94576b5 + 2dfbe43 commit ddb481c
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 137 deletions.
36 changes: 20 additions & 16 deletions src/nats_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::ops::{Subscribe, Message, Publish};
use futures::{StreamExt};
use crate::nats_client::{NatsClient, NatsClientOptions, NatsClientInner, NatsSid, ReconnectHandler, NatsClientState};

use std::sync::{Arc, RwLock};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::RatsioError;

use futures::lock::Mutex;
Expand All @@ -24,7 +25,7 @@ impl NatsClient {
opts,
server_info: RwLock::new(None),
subscriptions: Arc::new(Mutex::new(HashMap::default())),
on_reconnect: std::sync::Mutex::new(None),
on_reconnect: tokio::sync::Mutex::new(None),
state: RwLock::new(NatsClientState::Connecting),
last_ping: RwLock::new(NatsClientInner::time_in_millis()),
reconnect_version: RwLock::new(version),
Expand All @@ -43,13 +44,17 @@ impl NatsClient {
let arc_client = Arc::new(client);
let reconn_client = arc_client.clone();


if let Ok(mut client_ref) = arc_client.inner.client_ref.write() {
{
let mut client_ref = arc_client.inner.client_ref.write().await;
*client_ref = Some(arc_client.clone());
}

if let Ok(mut reconnect) = arc_client.inner.on_reconnect.lock() {
*reconnect = Some(Box::new(move || { reconn_client.on_reconnect() }));
{
let mut reconnect = arc_client.inner.on_reconnect.lock().await;
let reconnect_f = async move {
reconn_client.on_reconnect().await
};
*reconnect = Some(Box::pin(reconnect_f));
}

//heartbeat monitor
Expand Down Expand Up @@ -142,19 +147,18 @@ impl NatsClient {
self.inner.stop().await
}

pub fn add_reconnect_handler(&self, handler: ReconnectHandler) -> Result<(), RatsioError> {
if let Ok(mut handlers) = self.reconnect_handlers.write() {
handlers.push(handler);
}
pub async fn add_reconnect_handler(&self, handler: ReconnectHandler) -> Result<(), RatsioError> {
let mut handlers = self.reconnect_handlers.write().await;
handlers.push(handler);

Ok(())
}

pub (in crate::nats_client) fn on_reconnect(&self) -> () {
if let Ok(handlers) = self.reconnect_handlers.read() {
let handlers: &Vec<ReconnectHandler> = handlers.as_ref();
for handler in handlers {
handler(self)
}
pub (in crate::nats_client) async fn on_reconnect(&self) -> () {
let handlers = self.reconnect_handlers.read().await;
let handlers: &Vec<ReconnectHandler> = handlers.as_ref();
for handler in handlers {
handler(self)
}
}
}
113 changes: 45 additions & 68 deletions src/nats_client/client_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,8 @@ impl NatsClientInner {
let stream_self = self_arc.clone();
let _ = tokio::spawn(async move {
while let Some(item) = stream.next().await {
let current_version = if let Ok(version) = stream_self.reconnect_version.read() {
*version
} else {
1
};
if current_version != version {
let current_version = stream_self.reconnect_version.read().await;
if *current_version != version {
break;
}
stream_self.process_nats_event(item).await
Expand All @@ -108,9 +104,8 @@ impl NatsClientInner {
nkey: None,
});
self_arc.send_command(connect).await?;
if let Ok(mut state_guard) = self_arc.state.write() {
*state_guard = NatsClientState::Connected;
}
let mut state_guard = self_arc.state.write().await;
*state_guard = NatsClientState::Connected;
Ok(())
}

Expand All @@ -123,15 +118,14 @@ impl NatsClientInner {
}

pub(in crate::nats_client) async fn process_nats_event(&self, item: Op) {
self.ping_pong_reset();
self.ping_pong_reset().await;
match item {
Op::CLOSE => {
let _ = self.stop().await;
}
Op::INFO(server_info) => {
if let Ok(mut info) = self.server_info.write() {
*info = Some(server_info)
}
let mut info = self.server_info.write().await;
*info = Some(server_info)
}
Op::PING => {
match self.send_command(Op::PONG).await {
Expand All @@ -156,10 +150,9 @@ impl NatsClientInner {
}
}

pub(in crate::nats_client) fn ping_pong_reset(&self) {
if let Ok(mut last_ping) = self.last_ping.write() {
*last_ping = Self::time_in_millis();
}
pub(in crate::nats_client) async fn ping_pong_reset(&self) {
let mut last_ping = self.last_ping.write().await;
*last_ping = Self::time_in_millis();
}

pub(in crate::nats_client) async fn subscribe(
Expand Down Expand Up @@ -227,12 +220,11 @@ impl NatsClientInner {
}

pub(in crate::nats_client) async fn stop(&self) -> Result<(), RatsioError> {
if let Ok(mut state_guard) = self.state.write() {
*state_guard = NatsClientState::Shutdown;
}
if let Ok(mut reconnect) = self.on_reconnect.lock() {
*reconnect = None;
}
let mut state_guard = self.state.write().await;
*state_guard = NatsClientState::Shutdown;

let mut reconnect = self.on_reconnect.lock().await;
*reconnect = None;

//Close all subscritions.
let mut subscriptions = self.subscriptions.lock().await;
Expand All @@ -245,60 +237,48 @@ impl NatsClientInner {
let _ = self.send_command(cmd).await;
}
subscriptions.clear();
if let Ok(mut client_ref) = self.client_ref.write() {
*client_ref = None
}
let mut client_ref = self.client_ref.write().await;
*client_ref = None;

Ok(())
}

pub async fn reconnect(&self) -> Result<(), RatsioError> {
if let Ok(mut state_guard) = self.state.write() {
if *state_guard == NatsClientState::Disconnected {
*state_guard = NatsClientState::Reconnecting;
} else {
return Ok(());
}
let mut state_guard = self.state.write().await;
if *state_guard == NatsClientState::Disconnected {
*state_guard = NatsClientState::Reconnecting;
} else {
return Err(RatsioError::CannotReconnectToServer);
return Ok(());
}

match self.do_reconnect().await {
Ok(_) => {
if let Ok(mut state_guard) = self.state.write() {
*state_guard = NatsClientState::Connected;
}
let mut state_guard = self.state.write().await;
*state_guard = NatsClientState::Connected;
Ok(())
}
Err(err) => {
error!("Error trying to reconnect to NATS {:?}", err);
if let Ok(mut state_guard) = self.state.write() {
*state_guard = NatsClientState::Disconnected;
}
let mut state_guard = self.state.write().await;
*state_guard = NatsClientState::Disconnected;
Err(err)
}
}
}

async fn do_reconnect(&self) -> Result<(), RatsioError> {
let client_ref = if let Ok(client_ref_guard) = self.client_ref.read() {
if let Some(client_ref) = client_ref_guard.as_ref() {
client_ref.clone()
} else {
return Err(RatsioError::CannotReconnectToServer);
}
let client_ref_guard = self.client_ref.read().await;
let client_ref = if let Some(client_ref) = client_ref_guard.as_ref() {
client_ref.clone()
} else {
return Err(RatsioError::InternalServerError);
return Err(RatsioError::CannotReconnectToServer);
};
let tcp_stream = Self::try_connect(self.opts.clone(), &self.opts.cluster_uris.0, true).await?;
let (sink, stream) = NatsTcpStream::new(tcp_stream).await.split();
*self.conn_sink.lock().await = sink;
let new_version = if let Ok(mut version) = self.reconnect_version.write() {
let new_version = *version + 1;
*version = new_version;
new_version
} else {
return Err(RatsioError::CannotReconnectToServer);
};
let mut version = self.reconnect_version.write().await;
let new_version = *version + 1;
*version = new_version;
info!("Reconnecting to NATS servers 4 - new version {}", new_version);
let _ = NatsClientInner::start(client_ref.inner.clone(), new_version, stream).await?;
if self.opts.subscribe_on_reconnect {
Expand All @@ -314,7 +294,7 @@ impl NatsClientInner {
}
}
}
client_ref.on_reconnect();
client_ref.on_reconnect().await;
Ok(())
}

Expand All @@ -328,10 +308,9 @@ impl NatsClientInner {
let ping_max_out = u128::from(self.opts.ping_max_out);
loop {
let _ = Delay::new(Duration::from_millis((ping_interval / 2) as u64)).await;
if let Ok(state_guard) = self.state.read() {
if *state_guard == NatsClientState::Shutdown {
break;
}
let state_guard = self.state.read().await;
if *state_guard == NatsClientState::Shutdown {
break;
}

let mut reconnect_required = false;
Expand All @@ -345,21 +324,19 @@ impl NatsClientInner {
if !reconnect_required {
let _ = Delay::new(Duration::from_millis((ping_interval / 2) as u64)).await;
let now = Self::time_in_millis();
if let Ok(last_ping) = self.last_ping.read() {
if now - *last_ping > ping_interval {
error!("Missed ping interval")
}
if (now - *last_ping) > (ping_max_out * ping_interval) {
reconnect_required = true;
}
let last_ping = self.last_ping.read().await;
if now - *last_ping > ping_interval {
error!("Missed ping interval")
}
if (now - *last_ping) > (ping_max_out * ping_interval) {
reconnect_required = true;
}
}

if reconnect_required {
error!("Missed too many pings, reconnect is required.");
if let Ok(mut state_guard) = self.state.write() {
*state_guard = NatsClientState::Disconnected
}
let mut state_guard = self.state.write().await;
*state_guard = NatsClientState::Disconnected;
let _ = self.reconnect().await;
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/nats_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod client_inner;

use crate::net::nats_tcp_stream::NatsTcpStream;
use crate::ops::{ServerInfo, Op, Message, Subscribe};
use std::sync::{Arc, RwLock};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::RwLock;

use std::fmt::Debug;
use futures::stream::{ SplitSink};
Expand Down Expand Up @@ -116,7 +119,7 @@ pub struct NatsClientInner {
/// Server info
server_info: RwLock<Option<ServerInfo>>,
subscriptions: Arc<Mutex<HashMap<String, (UnboundedSender<ClosableMessage>, Subscribe)>>>,
on_reconnect: std::sync::Mutex<Option<Box<dyn Fn() -> () + Send + Sync>>>,
on_reconnect: tokio::sync::Mutex<Option<Pin<Box<dyn Future<Output=()> + Send + Sync>>>>,
state: RwLock<NatsClientState>,
last_ping: RwLock<u128>,
client_ref: RwLock<Option<Arc<NatsClient>>>,
Expand Down
Loading

0 comments on commit ddb481c

Please sign in to comment.