From 6821cd47517c9edf833c1494a36d8f41cdbbd317 Mon Sep 17 00:00:00 2001 From: Kuangda He Date: Mon, 23 Aug 2021 14:15:21 +1000 Subject: [PATCH] Fix detection of closed connections When the remote peer closes a connection, this causes the rx task to stop, but the tx side keeps running until it is used to send a message, which is guaranteed to fail. The fix is to add a 'closed' signal to tell the tx task to stop after the rx task finishes. --- .../redis_transforms/redis_cluster.rs | 4 +- .../util/cluster_connection_pool.rs | 105 ++++++++++++------ 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs b/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs index 8de660cc1..74e9c6e41 100644 --- a/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs +++ b/shotover-proxy/src/transforms/redis_transforms/redis_cluster.rs @@ -59,7 +59,7 @@ impl TransformsFromConfig for RedisClusterConfig { for (node, _, _) in &slots.masters { match connection_pool - .get_connection(&node, self.connection_count.unwrap_or(1)) + .get_connections(&node, self.connection_count.unwrap_or(1)) .await { Ok(conn) => { @@ -154,7 +154,7 @@ impl RedisCluster { if let Ok(res) = timeout( Duration::from_millis(40), self.connection_pool - .get_connection(host, self.connection_count), + .get_connections(host, self.connection_count), ) .await { diff --git a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs index b4e50af65..b5296fdea 100644 --- a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs +++ b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs @@ -1,37 +1,34 @@ -use crate::server::CodecReadHalf; -use crate::server::CodecWriteHalf; -use crate::transforms::util::Request; -use crate::{message::Messages, server::Codec}; -use anyhow::{anyhow, Result}; -use futures::StreamExt; use std::collections::{HashMap, HashSet}; -use std::fmt; -use std::fmt::Formatter; use std::iter::FromIterator; use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use derivative::Derivative; +use futures::StreamExt; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::Mutex; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::codec::{Decoder, FramedRead, FramedWrite}; -use tracing::{debug, info}; +use tracing::{debug, info, trace}; -#[derive(Clone)] +use crate::server::CodecReadHalf; +use crate::server::CodecWriteHalf; +use crate::transforms::util::Request; +use crate::{message::Messages, server::Codec}; + +#[derive(Clone, Derivative)] +#[derivative(Debug)] pub struct ConnectionPool { host_set: Arc>>, queue_map: Arc>>>>, + + #[derivative(Debug = "ignore")] codec: C, - auth_func: fn(&ConnectionPool, &mut UnboundedSender) -> Result<()>, -} -impl fmt::Debug for ConnectionPool { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("ConnectionPool") - .field("host_set", &self.host_set) - .field("queue_map", &self.queue_map) - .finish() - } + #[derivative(Debug = "ignore")] + auth_func: fn(&ConnectionPool, &mut UnboundedSender) -> Result<()>, } impl ConnectionPool { @@ -60,7 +57,7 @@ impl ConnectionPool { /// Try and grab an existing connection, if it's closed (e.g. the listener on the other side /// has closed due to a TCP error), we'll try to reconnect and return the new connection while /// updating the connection map. Errors are returned when a connection can't be established. - pub async fn get_connection( + pub async fn get_connections( &self, host: &String, connection_count: i32, @@ -71,12 +68,12 @@ impl ConnectionPool { return Ok(x.clone()); } } - let connection = self.connect(&host, connection_count).await?; - queue_map.insert(host.clone(), connection.clone()); - Ok(connection) + let connections = self.new_connections(&host, connection_count).await?; + queue_map.insert(host.clone(), connections.clone()); + Ok(connections) } - pub async fn connect( + pub async fn new_connections( &self, host: &String, connection_count: i32, @@ -84,20 +81,15 @@ impl ConnectionPool { where ::Error: std::marker::Send, { - let mut connection_pool: Vec> = Vec::new(); + let mut connections: Vec> = Vec::new(); for _i in 0..connection_count { - let socket: TcpStream = TcpStream::connect(host).await?; - let (read, write) = socket.into_split(); - let (mut out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::(); - - tokio::spawn(tx_process(write, out_rx, return_tx, self.codec.clone())); + let stream = TcpStream::connect(host).await?; + let mut out_tx = spawn_from_stream(&self.codec, stream); - tokio::spawn(rx_process(read, return_rx, self.codec.clone())); match (self.auth_func)(&self, &mut out_tx) { Ok(_) => { - connection_pool.push(out_tx); + connections.push(out_tx); } Err(e) => { info!("Could not authenticate to upstream TCP service - {}", e); @@ -105,14 +97,54 @@ impl ConnectionPool { } } - if connection_pool.len() == 0 { + if connections.len() == 0 { Err(anyhow!("Couldn't connect to upstream TCP service")) } else { - Ok(connection_pool) + Ok(connections) } } } +pub fn spawn_from_stream( + codec: &C, + stream: TcpStream, +) -> UnboundedSender { + let (read, write) = stream.into_split(); + let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (closed_tx, closed_rx) = tokio::sync::oneshot::channel(); + + let codec_clone = codec.clone(); + + tokio::spawn(async move { + tokio::select! { + result = tx_process(write, out_rx, return_tx, codec_clone) => if let Err(e) = result { + trace!("connection write-closed with error: {:?}", e); + } else { + trace!("connection write-closed gracefully"); + }, + _ = closed_rx => { + trace!("connection write-closed by remote upstream"); + }, + } + }); + + let codec_clone = codec.clone(); + + tokio::spawn(async move { + if let Err(e) = rx_process(read, return_rx, codec_clone).await { + trace!("connection read-closed with error: {:?}", e); + } else { + trace!("connection read-closed gracefully"); + } + + // Signal the writer to also exit, which then closes `out_tx` - what we consider as the connection. + closed_tx.send(()) + }); + + out_tx +} + async fn tx_process( write: OwnedWriteHalf, out_rx: UnboundedReceiver, @@ -127,8 +159,7 @@ async fn tx_process( return_tx.send(x)?; ret }); - rx_stream.forward(in_w).await?; - Ok(()) + rx_stream.forward(in_w).await } async fn rx_process(