From 42816145ec84937e6706193bc299124b37606075 Mon Sep 17 00:00:00 2001 From: Kuangda He Date: Sun, 29 Aug 2021 08:37:35 +1000 Subject: [PATCH] Fix detection of closed connections in cluster connection pools (#146) * Fix detection of closed connections When the remote peer closes a connection, this causes the rx task to stop, but the tx side keeps running until it is used to send a message, which is guaranteed to fail. The fix is to add a 'closed' signal to tell the tx task to stop after the rx task finishes. * Add shutdown tests * Fix logging * Fix naming * Fix use declaration ordering * Update cluster_connection_pool.rs * Test clean connection shutdown * Ninja: Changed logging level in tests to INFO --- .../redis_transforms/redis_cluster.rs | 4 +- .../util/cluster_connection_pool.rs | 192 ++++++++++++++---- 2 files changed, 157 insertions(+), 39 deletions(-) diff --git a/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs b/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs index 9fb065d44..b9f4c2df3 100644 --- a/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs +++ b/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs @@ -57,7 +57,7 @@ impl TransformsFromConfig for RedisClusterConfig { for node in slots.masters.values() { match connection_pool - .get_connection(&node, self.connection_count.unwrap_or(1)) + .get_connections(&node, self.connection_count.unwrap_or(1)) .await { Ok(conn) => { @@ -134,7 +134,7 @@ impl RedisCluster { if let Ok(res) = timeout( Duration::from_millis(40), self.connection_pool - .get_connection(host, self.connection_count), + .get_connections(host, self.connection_count), ) .await { diff --git a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs index b4e50af65..019ebb2e8 100644 --- a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs +++ b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs @@ -1,37 +1,34 @@ -use crate::server::CodecReadHalf; -use crate::server::CodecWriteHalf; -use crate::transforms::util::Request; -use crate::{message::Messages, server::Codec}; -use anyhow::{anyhow, Result}; -use futures::StreamExt; use std::collections::{HashMap, HashSet}; -use std::fmt; -use std::fmt::Formatter; use std::iter::FromIterator; use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use derivative::Derivative; +use futures::StreamExt; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::Mutex; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::codec::{Decoder, FramedRead, FramedWrite}; -use tracing::{debug, info}; +use tracing::{debug, info, trace}; -#[derive(Clone)] +use crate::server::CodecReadHalf; +use crate::server::CodecWriteHalf; +use crate::transforms::util::Request; +use crate::{message::Messages, server::Codec}; + +#[derive(Clone, Derivative)] +#[derivative(Debug)] pub struct ConnectionPool { host_set: Arc>>, queue_map: Arc>>>>, + + #[derivative(Debug = "ignore")] codec: C, - auth_func: fn(&ConnectionPool, &mut UnboundedSender) -> Result<()>, -} -impl fmt::Debug for ConnectionPool { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("ConnectionPool") - .field("host_set", &self.host_set) - .field("queue_map", &self.queue_map) - .finish() - } + #[derivative(Debug = "ignore")] + auth_func: fn(&ConnectionPool, &mut UnboundedSender) -> Result<()>, } impl ConnectionPool { @@ -60,7 +57,7 @@ impl ConnectionPool { /// Try and grab an existing connection, if it's closed (e.g. the listener on the other side /// has closed due to a TCP error), we'll try to reconnect and return the new connection while /// updating the connection map. Errors are returned when a connection can't be established. - pub async fn get_connection( + pub async fn get_connections( &self, host: &String, connection_count: i32, @@ -71,12 +68,12 @@ impl ConnectionPool { return Ok(x.clone()); } } - let connection = self.connect(&host, connection_count).await?; - queue_map.insert(host.clone(), connection.clone()); - Ok(connection) + let connections = self.new_connections(&host, connection_count).await?; + queue_map.insert(host.clone(), connections.clone()); + Ok(connections) } - pub async fn connect( + pub async fn new_connections( &self, host: &String, connection_count: i32, @@ -84,20 +81,15 @@ impl ConnectionPool { where ::Error: std::marker::Send, { - let mut connection_pool: Vec> = Vec::new(); + let mut connections: Vec> = Vec::new(); for _i in 0..connection_count { - let socket: TcpStream = TcpStream::connect(host).await?; - let (read, write) = socket.into_split(); - let (mut out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::(); + let stream = TcpStream::connect(host).await?; + let mut out_tx = spawn_from_stream(&self.codec, stream); - tokio::spawn(tx_process(write, out_rx, return_tx, self.codec.clone())); - - tokio::spawn(rx_process(read, return_rx, self.codec.clone())); match (self.auth_func)(&self, &mut out_tx) { Ok(_) => { - connection_pool.push(out_tx); + connections.push(out_tx); } Err(e) => { info!("Could not authenticate to upstream TCP service - {}", e); @@ -105,14 +97,54 @@ impl ConnectionPool { } } - if connection_pool.len() == 0 { + if connections.len() == 0 { Err(anyhow!("Couldn't connect to upstream TCP service")) } else { - Ok(connection_pool) + Ok(connections) } } } +pub fn spawn_from_stream( + codec: &C, + stream: TcpStream, +) -> UnboundedSender { + let (read, write) = stream.into_split(); + let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (closed_tx, closed_rx) = tokio::sync::oneshot::channel(); + + let codec_clone = codec.clone(); + + tokio::spawn(async move { + tokio::select! { + result = tx_process(write, out_rx, return_tx, codec_clone) => if let Err(e) = result { + trace!("connection write-closed with error: {:?}", e); + } else { + trace!("connection write-closed gracefully"); + }, + _ = closed_rx => { + trace!("connection write-closed by remote upstream"); + }, + } + }); + + let codec_clone = codec.clone(); + + tokio::spawn(async move { + if let Err(e) = rx_process(read, return_rx, codec_clone).await { + trace!("connection read-closed with error: {:?}", e); + } else { + trace!("connection read-closed gracefully"); + } + + // Signal the writer to also exit, which then closes `out_tx` - what we consider as the connection. + closed_tx.send(()) + }); + + out_tx +} + async fn tx_process( write: OwnedWriteHalf, out_rx: UnboundedReceiver, @@ -127,8 +159,7 @@ async fn tx_process( return_tx.send(x)?; ret }); - rx_stream.forward(in_w).await?; - Ok(()) + rx_stream.forward(in_w).await } async fn rx_process( @@ -164,3 +195,90 @@ async fn rx_process( } Ok(()) } + +#[cfg(test)] +mod test { + use std::mem; + use std::time::Duration; + + use tokio::io::AsyncReadExt; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::time::timeout; + + use crate::protocols::redis_codec::RedisCodec; + use crate::transforms::util::cluster_connection_pool::spawn_from_stream; + + #[tokio::test] + async fn test_remote_shutdown() { + let (log_writer, _log_guard) = tracing_appender::non_blocking(std::io::stdout()); + mem::forget(_log_guard); + + let builder = tracing_subscriber::fmt() + .with_writer(log_writer) + .with_env_filter("INFO") + .with_filter_reloading(); + + let _handle = builder.reload_handle(); + builder.try_init().ok(); + + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + let remote = tokio::spawn(async move { + // Accept connection and immediately close. + listener.accept().await.is_ok() + }); + + let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap(); + let codec = RedisCodec::new(true, 3); + let sender = spawn_from_stream(&codec, stream); + + assert!(remote.await.unwrap()); + + assert!( + // NOTE: Typically within 1-10ms. + timeout(Duration::from_millis(100), sender.closed()) + .await + .is_ok(), + "local did not detect remote shutdown" + ); + } + + #[tokio::test] + async fn test_local_shutdown() { + let (log_writer, _log_guard) = tracing_appender::non_blocking(std::io::stdout()); + mem::forget(_log_guard); + + let builder = tracing_subscriber::fmt() + .with_writer(log_writer) + .with_env_filter("INFO") + .with_filter_reloading(); + + let _handle = builder.reload_handle(); + builder.try_init().ok(); + + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + let remote = tokio::spawn(async move { + let (mut socket, _) = listener.accept().await.unwrap(); + + // Discard bytes until EOF. + let mut buffer = [0; 1]; + while socket.read(&mut buffer[..]).await.unwrap() > 0 {} + }); + + let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap(); + let codec = RedisCodec::new(true, 3); + + // Drop sender immediately. + let _ = spawn_from_stream(&codec, stream); + + assert!( + // NOTE: Typically within 1-10ms. + timeout(Duration::from_millis(100), remote).await.is_ok(), + "remote did not detect local shutdown" + ); + } +}