diff --git a/shotover-proxy/src/server.rs b/shotover-proxy/src/server.rs index eb76fae0f..6a774c050 100644 --- a/shotover-proxy/src/server.rs +++ b/shotover-proxy/src/server.rs @@ -2,8 +2,8 @@ use crate::message::Messages; use crate::tls::TlsAcceptor; use crate::transforms::chain::TransformChain; use crate::transforms::Wrapper; -use anyhow::{anyhow, Result}; -use futures::StreamExt; +use anyhow::{anyhow, Context, Result}; +use futures::{SinkExt, StreamExt}; use metrics::{register_gauge, Gauge}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; @@ -13,7 +13,6 @@ use tokio::sync::{mpsc, watch, Semaphore}; use tokio::time; use tokio::time::timeout; use tokio::time::Duration; -use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::Instrument; @@ -231,8 +230,8 @@ impl TcpCodecListener { // Receive shutdown notifications. shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()), + terminate_tasks: None, tls: self.tls.clone(), - timeout: self.timeout, }; @@ -356,6 +355,7 @@ pub struct Handler { /// which point the connection is terminated. shutdown: Shutdown, + terminate_tasks: Option>, tls: Option, /// Timeout in seconds after which to kill an idle connection. No timeout means connections will never be timed out. @@ -371,28 +371,54 @@ fn spawn_read_write_tasks< rx: R, tx: W, in_tx: UnboundedSender, - out_rx: UnboundedReceiver, + mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, + mut terminate_tasks_rx: watch::Receiver<()>, ) { let mut reader = FramedRead::new(rx, codec.clone()); - let writer = FramedWrite::new(tx, codec); - + let mut writer = FramedWrite::new(tx, codec); + + // Shutdown flows + // + // main task shuts down due to transform error: + // 1. The main task terminates, sending terminate_tasks_tx and dropping the first out_tx + // 2. 
The reader task detects change on terminate_tasks_rx and terminates, the last out_tx instance is dropped + // 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates + // + // client closes connection: + // 1. The reader task detects that the client has closed the connection via reader returning None and terminates, dropping in_tx and the first out_tx + // 2. The main task detects that in_tx is dropped by in_rx returning None and terminates, dropping the last out_tx + // 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates + + // reader task tokio::spawn( async move { - while let Some(message) = reader.next().await { - match message { - Ok(message) => { - let remaining_messages = - process_return_to_sender_messages(message, &out_tx); - if !remaining_messages.is_empty() { - if let Err(error) = in_tx.send(remaining_messages) { - warn!("failed to pass on received message: {}", error); - return; + loop { + tokio::select! 
{ + result = reader.next() => { + if let Some(message) = result { + match message { + Ok(message) => { + let remaining_messages = + process_return_to_sender_messages(message, &out_tx); + if !remaining_messages.is_empty() { + if let Err(error) = in_tx.send(remaining_messages) { + warn!("failed to pass on received message: {}", error); + return; + } + } + } + Err(error) => { + warn!("failed to receive or decode message: {:?}", error); + return; + } } + } else { + debug!("client has closed the connection"); + return; } } - Err(error) => { - warn!("failed to receive or decode message: {:?}", error); + _ = terminate_tasks_rx.changed() => { return; } } @@ -401,11 +427,27 @@ fn spawn_read_write_tasks< .in_current_span(), ); + // sender task tokio::spawn( async move { - let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok); - if let Err(err) = rx_stream.forward(writer).await { - error!("failed to send or encode message: {:?}", err); + loop { + if let Some(message) = out_rx.recv().await { + if let Err(err) = writer.send(message).await { + error!("failed to send or encode message: {:?}", err); + } + } else { + // Main task has ended. + // First flush out any remaining messages. + // Then end the task thus closing the connection by dropping the write half + while let Ok(message) = out_rx.try_recv() { + if let Err(err) = writer.send(message).await { + error!( + "while flushing messages: failed to send or encode message: {err:?}", + ); + } + } + break; + } } } .in_current_span(), @@ -430,16 +472,35 @@ impl Handler { // new request frame. 
let mut idle_time_seconds: u64 = 1; + let (terminate_tx, terminate_rx) = watch::channel::<()>(()); + self.terminate_tasks = Some(terminate_tx); + let (in_tx, mut in_rx) = mpsc::unbounded_channel::(); let (out_tx, out_rx) = mpsc::unbounded_channel::(); if let Some(tls) = &self.tls { let tls_stream = tls.accept(stream).await?; let (rx, tx) = tokio::io::split(tls_stream); - spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone()); + spawn_read_write_tasks( + self.codec.clone(), + rx, + tx, + in_tx, + out_rx, + out_tx.clone(), + terminate_rx, + ); } else { let (rx, tx) = stream.into_split(); - spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone()); + spawn_read_write_tasks( + self.codec.clone(), + rx, + tx, + in_tx, + out_rx, + out_tx.clone(), + terminate_rx, + ); }; while !self.shutdown.is_shutdown() { @@ -448,7 +509,7 @@ impl Handler { let mut reverse_chain = false; let messages = tokio::select! { - res = timeout(Duration::from_secs(idle_time_seconds) , in_rx.recv()) => { + res = timeout(Duration::from_secs(idle_time_seconds), in_rx.recv()) => { match res { Ok(maybe_message) => { idle_time_seconds = 1; @@ -491,27 +552,18 @@ impl Handler { self.chain.name.clone(), ); - let chain_result = if reverse_chain { + let modified_messages = if reverse_chain { self.chain.process_request_rev(wrapper).await } else { self.chain .process_request(wrapper, self.client_details.clone()) .await - }; - - match chain_result { - Ok(modified_messages) => { - debug!("sending message: {:?}", modified_messages); - // send the result of the process up stream - out_tx.send(modified_messages)?; - } - Err(e) => { - error!( - "{:?}", - e.context("chain failed to send and/or receive messages") - ); - } } + .context("chain failed to send and/or receive messages")?; + + debug!("sending message: {:?}", modified_messages); + // send the result of the process up stream + out_tx.send(modified_messages)?; } Ok(()) } @@ -531,6 +583,10 @@ impl Drop for 
Handler { // semaphore. self.limit_connections.add_permits(1); + + if let Some(terminate_tasks) = &self.terminate_tasks { + terminate_tasks.send(()).ok(); + } } } /// Listens for the server shutdown signal. diff --git a/shotover-proxy/src/transforms/redis/sink_single.rs b/shotover-proxy/src/transforms/redis/sink_single.rs index e1c4ba97e..a9de7ee29 100644 --- a/shotover-proxy/src/transforms/redis/sink_single.rs +++ b/shotover-proxy/src/transforms/redis/sink_single.rs @@ -93,10 +93,7 @@ impl Transform for RedisSinkSingle { // self.outbound is gauranteed to be Some by the previous block let outbound_framed_codec = self.outbound.as_mut().unwrap(); - outbound_framed_codec - .send(message_wrapper.messages) - .await - .ok(); + outbound_framed_codec.send(message_wrapper.messages).await?; match outbound_framed_codec.next().fuse().await { Some(mut a) => { diff --git a/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs index d20bfc161..178fccee6 100644 --- a/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs @@ -8,6 +8,7 @@ use redis::{AsyncCommands, Commands, ErrorKind, RedisError, Value}; use serial_test::serial; use shotover_proxy::tls::TlsConfig; use std::collections::{HashMap, HashSet}; +use std::io::{Read, Write}; use std::path::Path; use std::thread::sleep; use std::time::Duration; @@ -1212,6 +1213,55 @@ async fn test_dr_auth(shotover_manager: &ShotoverManager) { ); } +/// A driver variant of this test case is provided so that we can ensure that +/// at least one driver handles this as we expect. 
+async fn test_trigger_transform_failure_driver(connection: &mut Connection) { + assert_eq!( + redis::cmd("SET") + .arg("foo") + .arg(42) + .query_async::<_, ()>(connection) + .await + .unwrap_err() + .to_string(), + "unexpected end of file".to_string() + ); +} + +/// A raw variant of this test case is provided so that we can make a strong assertion about the way shotover handles this case. +/// +/// CAREFUL: This lacks any kind of check that shotover is ready, +/// so make sure shotover_manager.redis_connection is run on 6379 before calling this. +fn test_trigger_transform_failure_raw() { + // Send invalid redis command + // To correctly handle this shotover should close the connection + let mut connection = std::net::TcpStream::connect("127.0.0.1:6379").unwrap(); + connection.write_all(b"*1\r\n$4\r\nping\r\n").unwrap(); + connection + .set_read_timeout(Some(Duration::from_secs(10))) + .unwrap(); + // If the connection was closed by shotover then we will successfully read 0 bytes. + // If the connection was not closed by shotover then read will block for 10 seconds until the timeout is hit and then the unwrap will panic. + let amount = connection.read(&mut [0; 1]).unwrap(); + assert_eq!(amount, 0); +} + +/// CAREFUL: This lacks any kind of check that shotover is ready, +/// so make sure shotover_manager.redis_connection is run on 6379 before calling this. +fn test_invalid_frame() { + // Send invalid redis command + // To correctly handle this shotover should close the connection + let mut connection = std::net::TcpStream::connect("127.0.0.1:6379").unwrap(); + connection.write_all(b"invalid_redis_frame\r\n").unwrap(); + connection + .set_read_timeout(Some(Duration::from_secs(10))) + .unwrap(); + // If the connection was closed by shotover then we will successfully read 0 bytes. + // If the connection was not closed by shotover then read will block for 10 seconds until the timeout is hit and then the unwrap will panic.
+ let amount = connection.read(&mut [0; 1]).unwrap(); + assert_eq!(amount, 0); +} + #[tokio::test(flavor = "multi_thread")] #[serial] async fn test_passthrough() { @@ -1223,6 +1273,19 @@ async fn test_passthrough() { Flusher::new_single_connection(shotover_manager.redis_connection_async(6379).await).await; run_all(&mut connection, &mut flusher).await; + test_invalid_frame(); +} + +#[tokio::test(flavor = "multi_thread")] +#[serial] +async fn test_passthrough_redis_down() { + let shotover_manager = + ShotoverManager::from_topology_file("example-configs/redis-passthrough/topology.yaml"); + let mut connection = shotover_manager.redis_connection_async(6379).await; + + test_trigger_transform_failure_driver(&mut connection).await; + test_trigger_transform_failure_raw(); + test_invalid_frame(); } #[tokio::test(flavor = "multi_thread")]