Skip to content

Commit

Permalink
Merge pull request #59 from h3poteto/feat/autoreconnect
Browse files Browse the repository at this point in the history
Auto-reconnect WebSocket with tokio-tungstenite
  • Loading branch information
h3poteto authored Dec 27, 2022
2 parents 64d3499 + 8a261cf commit faee8f7
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 119 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ url = "2.2.2"
oauth2 = { version = "4.2" }
sha1 = { version = "0.10" }
hex = { version = "0.4" }
tungstenite = { version = "0.18", features = ["native-tls"] }
tokio-tungstenite = { version ="0.18", features = ["native-tls"] }
urlencoding = { version = "2.1" }
log = "0.4"
thiserror = "1"
futures-util = "0.3"

[dev-dependencies]
env_logger = "0.10"
2 changes: 1 addition & 1 deletion examples/mastodon_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn streaming(url: &str, access_token: String) {
Some(access_token),
None,
);
let streaming = client.public_streaming(url.to_string());
let streaming = client.user_streaming(url.to_string());

streaming.listen(Box::new(|message| match message {
Message::Update(mes) => {
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub enum Error {
/// WebSocketError from [`tungstenite::error::Error`].
/// This error will be raised when tungstenite WebSocket raises an error.
#[error(transparent)]
WebSocketError(#[from] tungstenite::error::Error),
WebSocketError(#[from] tokio_tungstenite::tungstenite::error::Error),
/// JsonError from [`serde_json::Error`].
/// This error will be raised when failed to parse some json.
#[error(transparent)]
Expand Down
102 changes: 44 additions & 58 deletions src/mastodon/web_socket.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use std::fmt;
use std::ops::Add;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

use super::entities;
use crate::error::{Error, Kind};
use crate::streaming::{Message, Streaming};
use chrono::Utc;
use serde::Deserialize;
use tungstenite::protocol::frame::coding::CloseCode;
use tungstenite::protocol::CloseFrame;
use tungstenite::{connect, Message as WebSocketMessage};

use futures_util::{SinkExt, StreamExt};
use tokio::runtime::Runtime;
use tokio_tungstenite::{
connect_async, tungstenite::protocol::frame::coding::CloseCode,
tungstenite::protocol::Message as WebSocketMessage,
};
use url::Url;

const RECONNECT_INTERVAL: u64 = 1000;
const READ_MESSAGE_TIMEOUT_SECONDS: i64 = 60;
const RECONNECT_INTERVAL: u64 = 5000;
const READ_MESSAGE_TIMEOUT_SECONDS: u64 = 60;

#[derive(Debug, Clone)]
pub struct WebSocket {
Expand Down Expand Up @@ -110,15 +110,19 @@ impl WebSocket {

fn connect(&self, url: &str, callback: Box<dyn Fn(Message)>) {
loop {
match self.do_connect(url, &callback) {
match Runtime::new()
.unwrap()
.block_on(self.do_connect(url, &callback))
{
Ok(()) => {
log::info!("connection for {} is closed", url);
return;
}
Err(err) => match err.kind {
InnerKind::ConnectionError
| InnerKind::SocketReadError
| InnerKind::UnusualSocketCloseError => {
| InnerKind::UnusualSocketCloseError
| InnerKind::TimeoutError => {
thread::sleep(Duration::from_millis(RECONNECT_INTERVAL));
log::info!("Reconnecting to {}", url);
continue;
Expand All @@ -128,11 +132,16 @@ impl WebSocket {
}
}

fn do_connect(&self, url: &str, callback: &Box<dyn Fn(Message)>) -> Result<(), InnerError> {
let (socket, response) = connect(Url::parse(url).unwrap()).map_err(|e| {
log::error!("Failed to connect: {}", e);
InnerError::new(InnerKind::ConnectionError)
})?;
async fn do_connect(
&self,
url: &str,
callback: &Box<dyn Fn(Message)>,
) -> Result<(), InnerError> {
let (mut socket, response) =
connect_async(Url::parse(url).unwrap()).await.map_err(|e| {
log::error!("Failed to connect: {}", e);
InnerError::new(InnerKind::ConnectionError)
})?;

log::debug!("Connected to {}", url);
log::debug!("Response HTTP code: {}", response.status());
Expand All @@ -141,60 +150,35 @@ impl WebSocket {
log::debug!("* {}", header);
}

let last_received = Arc::new(Mutex::new(Utc::now()));
let last_received_check = Arc::clone(&last_received);
let socket = Arc::new(Mutex::new(socket));
let socket_check = Arc::clone(&socket);

let stop = Arc::new(AtomicBool::new(false));
let stop_check = Arc::clone(&stop);

thread::spawn(move || loop {
thread::sleep(Duration::from_secs(10));

if stop_check.load(Ordering::Relaxed) {
return;
}

let ts = last_received_check.lock().unwrap();
log::debug!("last received: {}", ts);
let diff = Utc::now() - ts.add(chrono::Duration::seconds(READ_MESSAGE_TIMEOUT_SECONDS));
if diff > chrono::Duration::seconds(0) {
log::warn!("closing connection because timeout");
socket_check
.lock()
.unwrap()
.close(Some(CloseFrame {
code: CloseCode::Again,
reason: std::borrow::Cow::Borrowed("Timeout"),
}))
.unwrap();
return;
}
});

loop {
let msg = socket.lock().unwrap().read_message().map_err(|e| {
let res = tokio::time::timeout(
Duration::from_secs(READ_MESSAGE_TIMEOUT_SECONDS),
socket.next(),
)
.await
.map_err(|e| {
log::error!("Timeout reading message: {}", e);
InnerError::new(InnerKind::TimeoutError)
})?;
let Some(r) = res else {
log::warn!("Response is empty");
continue;
};
let msg = r.map_err(|e| {
log::error!("Failed to read message: {}", e);
stop.store(true, Ordering::Relaxed);
InnerError::new(InnerKind::SocketReadError)
})?;
let mut ts = last_received.lock().unwrap();
*ts = Utc::now();
drop(ts);
if msg.is_ping() {
let _ = socket
.lock()
.unwrap()
.write_message(WebSocketMessage::Pong(Vec::<u8>::new()))
.send(WebSocketMessage::Pong(Vec::<u8>::new()))
.await
.map_err(|e| {
log::error!("{:#?}", e);
e
});
}
if msg.is_close() {
stop.store(true, Ordering::Relaxed);
let _ = socket.lock().unwrap().close(None).map_err(|e| {
let _ = socket.close(None).await.map_err(|e| {
log::error!("{:#?}", e);
e
});
Expand Down Expand Up @@ -248,6 +232,8 @@ enum InnerKind {
SocketReadError,
#[error("unusual socket close error")]
UnusualSocketCloseError,
#[error("timeout error")]
TimeoutError,
}

impl InnerError {
Expand Down
102 changes: 44 additions & 58 deletions src/pleroma/web_socket.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use std::fmt;
use std::ops::Add;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

use super::entities;
use crate::error::{Error, Kind};
use crate::streaming::{Message, Streaming};
use chrono::Utc;
use serde::Deserialize;
use tungstenite::protocol::frame::coding::CloseCode;
use tungstenite::protocol::CloseFrame;
use tungstenite::{connect, Message as WebSocketMessage};

use futures_util::{SinkExt, StreamExt};
use tokio::runtime::Runtime;
use tokio_tungstenite::{
connect_async, tungstenite::protocol::frame::coding::CloseCode,
tungstenite::protocol::Message as WebSocketMessage,
};
use url::Url;

const RECONNECT_INTERVAL: u64 = 1000;
const READ_MESSAGE_TIMEOUT_SECONDS: i64 = 60;
const RECONNECT_INTERVAL: u64 = 5000;
const READ_MESSAGE_TIMEOUT_SECONDS: u64 = 60;

#[derive(Debug, Clone)]
pub struct WebSocket {
Expand Down Expand Up @@ -110,15 +110,19 @@ impl WebSocket {

fn connect(&self, url: &str, callback: Box<dyn Fn(Message)>) {
loop {
match self.do_connect(url, &callback) {
match Runtime::new()
.unwrap()
.block_on(self.do_connect(url, &callback))
{
Ok(()) => {
log::info!("connection for {} is closed", url);
return;
}
Err(err) => match err.kind {
InnerKind::ConnectionError
| InnerKind::SocketReadError
| InnerKind::UnusualSocketCloseError => {
| InnerKind::UnusualSocketCloseError
| InnerKind::TimeoutError => {
thread::sleep(Duration::from_millis(RECONNECT_INTERVAL));
log::info!("Reconnecting to {}", url);
continue;
Expand All @@ -128,11 +132,16 @@ impl WebSocket {
}
}

fn do_connect(&self, url: &str, callback: &Box<dyn Fn(Message)>) -> Result<(), InnerError> {
let (socket, response) = connect(Url::parse(url).unwrap()).map_err(|e| {
log::error!("Failed to connect: {}", e);
InnerError::new(InnerKind::ConnectionError)
})?;
async fn do_connect(
&self,
url: &str,
callback: &Box<dyn Fn(Message)>,
) -> Result<(), InnerError> {
let (mut socket, response) =
connect_async(Url::parse(url).unwrap()).await.map_err(|e| {
log::error!("Failed to connect: {}", e);
InnerError::new(InnerKind::ConnectionError)
})?;

log::debug!("Connected to {}", url);
log::debug!("Response HTTP code: {}", response.status());
Expand All @@ -141,60 +150,35 @@ impl WebSocket {
log::debug!("* {}", header);
}

let last_received = Arc::new(Mutex::new(Utc::now()));
let last_received_check = Arc::clone(&last_received);
let socket = Arc::new(Mutex::new(socket));
let socket_check = Arc::clone(&socket);

let stop = Arc::new(AtomicBool::new(false));
let stop_check = Arc::clone(&stop);

thread::spawn(move || loop {
thread::sleep(Duration::from_secs(10));

if stop_check.load(Ordering::Relaxed) {
return;
}

let ts = last_received_check.lock().unwrap();
log::debug!("last received: {}", ts);
let diff = Utc::now() - ts.add(chrono::Duration::seconds(READ_MESSAGE_TIMEOUT_SECONDS));
if diff > chrono::Duration::seconds(0) {
log::warn!("closing connection because timeout");
socket_check
.lock()
.unwrap()
.close(Some(CloseFrame {
code: CloseCode::Again,
reason: std::borrow::Cow::Borrowed("Timeout"),
}))
.unwrap();
return;
}
});

loop {
let msg = socket.lock().unwrap().read_message().map_err(|e| {
let res = tokio::time::timeout(
Duration::from_secs(READ_MESSAGE_TIMEOUT_SECONDS),
socket.next(),
)
.await
.map_err(|e| {
log::error!("Timeout reading message: {}", e);
InnerError::new(InnerKind::TimeoutError)
})?;
let Some(r) = res else {
log::warn!("Response is empty");
continue;
};
let msg = r.map_err(|e| {
log::error!("Failed to read message: {}", e);
stop.store(true, Ordering::Relaxed);
InnerError::new(InnerKind::SocketReadError)
})?;
let mut ts = last_received.lock().unwrap();
*ts = Utc::now();
drop(ts);
if msg.is_ping() {
let _ = socket
.lock()
.unwrap()
.write_message(WebSocketMessage::Pong(Vec::<u8>::new()))
.send(WebSocketMessage::Pong(Vec::<u8>::new()))
.await
.map_err(|e| {
log::error!("{:#?}", e);
e
});
}
if msg.is_close() {
stop.store(true, Ordering::Relaxed);
let _ = socket.lock().unwrap().close(None).map_err(|e| {
let _ = socket.close(None).await.map_err(|e| {
log::error!("{:#?}", e);
e
});
Expand Down Expand Up @@ -248,6 +232,8 @@ enum InnerKind {
SocketReadError,
#[error("unusual socket close error")]
UnusualSocketCloseError,
#[error("timeout error")]
TimeoutError,
}

impl InnerError {
Expand Down

0 comments on commit faee8f7

Please sign in to comment.