Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix detection of closed connections in cluster connection pools #146

Merged
merged 6 commits into from
Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TransformsFromConfig for RedisClusterConfig {

for (node, _, _) in &slots.masters {
match connection_pool
.get_connection(&node, self.connection_count.unwrap_or(1))
.get_connections(&node, self.connection_count.unwrap_or(1))
.await
{
Ok(conn) => {
Expand Down Expand Up @@ -154,7 +154,7 @@ impl RedisCluster {
if let Ok(res) = timeout(
Duration::from_millis(40),
self.connection_pool
.get_connection(host, self.connection_count),
.get_connections(host, self.connection_count),
)
.await
{
Expand Down
192 changes: 155 additions & 37 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
use crate::server::CodecReadHalf;
use crate::server::CodecWriteHalf;
use crate::transforms::util::Request;
use crate::{message::Messages, server::Codec};
use anyhow::{anyhow, Result};
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fmt::Formatter;
use std::iter::FromIterator;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use derivative::Derivative;
use futures::StreamExt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{Decoder, FramedRead, FramedWrite};
use tracing::{debug, info};
use tracing::{debug, info, trace};

#[derive(Clone)]
use crate::server::CodecReadHalf;
use crate::server::CodecWriteHalf;
use crate::transforms::util::Request;
use crate::{message::Messages, server::Codec};

#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct ConnectionPool<C: Codec> {
host_set: Arc<Mutex<HashSet<String>>>,
queue_map: Arc<Mutex<HashMap<String, Vec<UnboundedSender<Request>>>>>,

#[derivative(Debug = "ignore")]
codec: C,
auth_func: fn(&ConnectionPool<C>, &mut UnboundedSender<Request>) -> Result<()>,
}

impl<C: Codec> fmt::Debug for ConnectionPool<C> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionPool")
.field("host_set", &self.host_set)
.field("queue_map", &self.queue_map)
.finish()
}
#[derivative(Debug = "ignore")]
auth_func: fn(&ConnectionPool<C>, &mut UnboundedSender<Request>) -> Result<()>,
}

impl<C: Codec + 'static> ConnectionPool<C> {
Expand Down Expand Up @@ -60,7 +57,7 @@ impl<C: Codec + 'static> ConnectionPool<C> {
/// Try to grab an existing connection. If it's closed (e.g. the listener on the other side
/// has closed due to a TCP error), we'll try to reconnect and return the new connection while
/// updating the connection map. Errors are returned when a connection can't be established.
pub async fn get_connection(
pub async fn get_connections(
&self,
host: &String,
connection_count: i32,
Expand All @@ -71,48 +68,83 @@ impl<C: Codec + 'static> ConnectionPool<C> {
return Ok(x.clone());
}
}
let connection = self.connect(&host, connection_count).await?;
queue_map.insert(host.clone(), connection.clone());
Ok(connection)
let connections = self.new_connections(&host, connection_count).await?;
queue_map.insert(host.clone(), connections.clone());
Ok(connections)
}

pub async fn connect(
pub async fn new_connections(
&self,
host: &String,
connection_count: i32,
) -> Result<Vec<UnboundedSender<Request>>>
where
<C as Decoder>::Error: std::marker::Send,
{
let mut connection_pool: Vec<UnboundedSender<Request>> = Vec::new();
let mut connections: Vec<UnboundedSender<Request>> = Vec::new();

for _i in 0..connection_count {
let socket: TcpStream = TcpStream::connect(host).await?;
let (read, write) = socket.into_split();
let (mut out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let stream = TcpStream::connect(host).await?;
let mut out_tx = spawn_from_stream(&self.codec, stream);

tokio::spawn(tx_process(write, out_rx, return_tx, self.codec.clone()));

tokio::spawn(rx_process(read, return_rx, self.codec.clone()));
rukai marked this conversation as resolved.
Show resolved Hide resolved
match (self.auth_func)(&self, &mut out_tx) {
Ok(_) => {
connection_pool.push(out_tx);
connections.push(out_tx);
}
Err(e) => {
info!("Could not authenticate to upstream TCP service - {}", e);
}
}
}

if connection_pool.len() == 0 {
if connections.len() == 0 {
Err(anyhow!("Couldn't connect to upstream TCP service"))
} else {
Ok(connection_pool)
Ok(connections)
}
}
}

pub fn spawn_from_stream<C: Codec + 'static>(
codec: &C,
stream: TcpStream,
) -> UnboundedSender<Request> {
let (read, write) = stream.into_split();
let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::<Request>();
let (closed_tx, closed_rx) = tokio::sync::oneshot::channel();

let codec_clone = codec.clone();

tokio::spawn(async move {
tokio::select! {
result = tx_process(write, out_rx, return_tx, codec_clone) => if let Err(e) = result {
trace!("connection write-closed with error: {:?}", e);
} else {
trace!("connection write-closed gracefully");
},
_ = closed_rx => {
trace!("connection write-closed by remote upstream");
},
}
});

let codec_clone = codec.clone();

tokio::spawn(async move {
if let Err(e) = rx_process(read, return_rx, codec_clone).await {
trace!("connection read-closed with error: {:?}", e);
} else {
trace!("connection read-closed gracefully");
}

// Signal the writer to also exit, which then closes `out_tx` - what we consider as the connection.
closed_tx.send(())
});

out_tx
}

rukai marked this conversation as resolved.
Show resolved Hide resolved
async fn tx_process<C: CodecWriteHalf>(
write: OwnedWriteHalf,
out_rx: UnboundedReceiver<Request>,
Expand All @@ -127,8 +159,7 @@ async fn tx_process<C: CodecWriteHalf>(
return_tx.send(x)?;
ret
});
rx_stream.forward(in_w).await?;
Ok(())
rx_stream.forward(in_w).await
}

async fn rx_process<C: CodecReadHalf>(
Expand Down Expand Up @@ -164,3 +195,90 @@ async fn rx_process<C: CodecReadHalf>(
}
Ok(())
}

#[cfg(test)]
mod test {
    use std::mem;
    use std::time::Duration;

    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;
    use tokio::time::timeout;

    use crate::protocols::redis_codec::RedisCodec;
    use crate::transforms::util::cluster_connection_pool::spawn_from_stream;

    /// Installs a stdout tracing subscriber filtered at `INFO`.
    ///
    /// Shared by every test in this module. `try_init` fails when a global subscriber is
    /// already installed (e.g. by another test in the same process); that failure is
    /// deliberately ignored.
    fn init_tracing() {
        let (log_writer, log_guard) = tracing_appender::non_blocking(std::io::stdout());
        // Leak the guard so it is not dropped when this helper returns.
        mem::forget(log_guard);

        tracing_subscriber::fmt()
            .with_writer(log_writer)
            .with_env_filter("INFO")
            .with_filter_reloading()
            .try_init()
            .ok();
    }

    /// The remote accepts the connection and drops it straight away; the local sender must
    /// report itself closed shortly afterwards.
    #[tokio::test]
    async fn test_remote_shutdown() {
        init_tracing();

        // Port 0 lets the OS pick a free port.
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let remote = tokio::spawn(async move {
            // Accept connection and immediately close.
            listener.accept().await.is_ok()
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let codec = RedisCodec::new(true, 3);
        let sender = spawn_from_stream(&codec, stream);

        assert!(remote.await.unwrap());

        assert!(
            // NOTE: Typically within 1-10ms.
            timeout(Duration::from_millis(100), sender.closed())
                .await
                .is_ok(),
            "local did not detect remote shutdown"
        );
    }

    /// The local sender is dropped straight away; the remote must observe EOF on its socket
    /// shortly afterwards.
    #[tokio::test]
    async fn test_local_shutdown() {
        init_tracing();

        // Port 0 lets the OS pick a free port.
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let remote = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            // Discard bytes until EOF.
            let mut buffer = [0; 1];
            while socket.read(&mut buffer[..]).await.unwrap() > 0 {}
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let codec = RedisCodec::new(true, 3);

        // Drop sender immediately.
        let _ = spawn_from_stream(&codec, stream);

        assert!(
            // NOTE: Typically within 1-10ms.
            timeout(Duration::from_millis(100), remote).await.is_ok(),
            "remote did not detect local shutdown"
        );
    }
}