Skip to content

Commit

Permalink
Fix detection of closed connections
Browse files Browse the repository at this point in the history
When the remote peer closes a connection, the rx task stops, but the tx
task keeps running until it is next used to send a message, and that send
is guaranteed to fail. The fix is to add a 'closed' signal that tells the
tx task to stop once the rx task finishes.
  • Loading branch information
XA21X committed Aug 26, 2021
1 parent e033bf7 commit 6821cd4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TransformsFromConfig for RedisClusterConfig {

for (node, _, _) in &slots.masters {
match connection_pool
.get_connection(&node, self.connection_count.unwrap_or(1))
.get_connections(&node, self.connection_count.unwrap_or(1))
.await
{
Ok(conn) => {
Expand Down Expand Up @@ -154,7 +154,7 @@ impl RedisCluster {
if let Ok(res) = timeout(
Duration::from_millis(40),
self.connection_pool
.get_connection(host, self.connection_count),
.get_connections(host, self.connection_count),
)
.await
{
Expand Down
105 changes: 68 additions & 37 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
use crate::server::CodecReadHalf;
use crate::server::CodecWriteHalf;
use crate::transforms::util::Request;
use crate::{message::Messages, server::Codec};
use anyhow::{anyhow, Result};
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fmt::Formatter;
use std::iter::FromIterator;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use derivative::Derivative;
use futures::StreamExt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{Decoder, FramedRead, FramedWrite};
use tracing::{debug, info};
use tracing::{debug, info, trace};

#[derive(Clone)]
use crate::server::CodecReadHalf;
use crate::server::CodecWriteHalf;
use crate::transforms::util::Request;
use crate::{message::Messages, server::Codec};

#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct ConnectionPool<C: Codec> {
host_set: Arc<Mutex<HashSet<String>>>,
queue_map: Arc<Mutex<HashMap<String, Vec<UnboundedSender<Request>>>>>,

#[derivative(Debug = "ignore")]
codec: C,
auth_func: fn(&ConnectionPool<C>, &mut UnboundedSender<Request>) -> Result<()>,
}

impl<C: Codec> fmt::Debug for ConnectionPool<C> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionPool")
.field("host_set", &self.host_set)
.field("queue_map", &self.queue_map)
.finish()
}
#[derivative(Debug = "ignore")]
auth_func: fn(&ConnectionPool<C>, &mut UnboundedSender<Request>) -> Result<()>,
}

impl<C: Codec + 'static> ConnectionPool<C> {
Expand Down Expand Up @@ -60,7 +57,7 @@ impl<C: Codec + 'static> ConnectionPool<C> {
/// Try to grab an existing connection. If it's closed (e.g. the listener on the other side
/// has closed due to a TCP error), we'll try to reconnect and return the new connection while
/// updating the connection map. Errors are returned when a connection can't be established.
pub async fn get_connection(
pub async fn get_connections(
&self,
host: &String,
connection_count: i32,
Expand All @@ -71,48 +68,83 @@ impl<C: Codec + 'static> ConnectionPool<C> {
return Ok(x.clone());
}
}
let connection = self.connect(&host, connection_count).await?;
queue_map.insert(host.clone(), connection.clone());
Ok(connection)
let connections = self.new_connections(&host, connection_count).await?;
queue_map.insert(host.clone(), connections.clone());
Ok(connections)
}

pub async fn connect(
pub async fn new_connections(
&self,
host: &String,
connection_count: i32,
) -> Result<Vec<UnboundedSender<Request>>>
where
<C as Decoder>::Error: std::marker::Send,
{
let mut connection_pool: Vec<UnboundedSender<Request>> = Vec::new();
let mut connections: Vec<UnboundedSender<Request>> = Vec::new();

for _i in 0..connection_count {
let socket: TcpStream = TcpStream::connect(host).await?;
let (read, write) = socket.into_split();
let (mut out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();

tokio::spawn(tx_process(write, out_rx, return_tx, self.codec.clone()));
let stream = TcpStream::connect(host).await?;
let mut out_tx = spawn_from_stream(&self.codec, stream);

tokio::spawn(rx_process(read, return_rx, self.codec.clone()));
match (self.auth_func)(&self, &mut out_tx) {
Ok(_) => {
connection_pool.push(out_tx);
connections.push(out_tx);
}
Err(e) => {
info!("Could not authenticate to upstream TCP service - {}", e);
}
}
}

if connection_pool.len() == 0 {
if connections.len() == 0 {
Err(anyhow!("Couldn't connect to upstream TCP service"))
} else {
Ok(connection_pool)
Ok(connections)
}
}
}

pub fn spawn_from_stream<C: Codec + 'static>(
codec: &C,
stream: TcpStream,
) -> UnboundedSender<Request> {
let (read, write) = stream.into_split();
let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (closed_tx, closed_rx) = tokio::sync::oneshot::channel();

let codec_clone = codec.clone();

tokio::spawn(async move {
tokio::select! {
result = tx_process(write, out_rx, return_tx, codec_clone) => if let Err(e) = result {
trace!("connection write-closed with error: {:?}", e);
} else {
trace!("connection write-closed gracefully");
},
_ = closed_rx => {
trace!("connection write-closed by remote upstream");
},
}
});

let codec_clone = codec.clone();

tokio::spawn(async move {
if let Err(e) = rx_process(read, return_rx, codec_clone).await {
trace!("connection read-closed with error: {:?}", e);
} else {
trace!("connection read-closed gracefully");
}

// Signal the writer to also exit, which then closes `out_tx` - what we consider as the connection.
closed_tx.send(())
});

out_tx
}

async fn tx_process<C: CodecWriteHalf>(
write: OwnedWriteHalf,
out_rx: UnboundedReceiver<Request>,
Expand All @@ -127,8 +159,7 @@ async fn tx_process<C: CodecWriteHalf>(
return_tx.send(x)?;
ret
});
rx_stream.forward(in_w).await?;
Ok(())
rx_stream.forward(in_w).await
}

async fn rx_process<C: CodecReadHalf>(
Expand Down

0 comments on commit 6821cd4

Please sign in to comment.