diff --git a/shotover-proxy/tests/redis_int_tests/mod.rs b/shotover-proxy/tests/redis_int_tests/mod.rs index 4f711e3c4..4a79602a2 100644 --- a/shotover-proxy/tests/redis_int_tests/mod.rs +++ b/shotover-proxy/tests/redis_int_tests/mod.rs @@ -62,6 +62,7 @@ async fn passthrough_redis_down() { shotover .shutdown_and_then_consume_events(&[ + // Error occurs when client sends a message to shotover EventMatcher::new() .with_level(Level::Error) .with_target("shotover::server") @@ -75,6 +76,19 @@ Caused by: 3: Connection refused (os error 111)"#, ) .with_count(Count::Times(2)), + // When the chain is flushed on client connection close we hit the same error again + EventMatcher::new() + .with_level(Level::Error) + .with_target("shotover::server") + .with_message( + r#"encountered an error when flushing the chain redis for shutdown + +Caused by: + 0: RedisSinkSingle transform failed + 1: Failed to connect to destination "127.0.0.1:1111" + 2: Connection refused (os error 111)"#, + ) + .with_count(Count::Times(3)), invalid_frame_event(), ]) .await; diff --git a/shotover/src/server.rs b/shotover/src/server.rs index c28ebcec4..52f5a3801 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -444,7 +444,7 @@ async fn spawn_websocket_read_write_tasks< ); } -fn spawn_read_write_tasks< +pub fn spawn_read_write_tasks< C: CodecBuilder + 'static, R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -455,6 +455,7 @@ fn spawn_read_write_tasks< in_tx: mpsc::Sender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, + force_run_chain: Option>, ) { let (decoder, encoder) = codec.build(); let mut reader = FramedRead::new(rx, decoder); @@ -487,6 +488,9 @@ fn spawn_read_write_tasks< // main task has shutdown, this task is no longer needed return; } + if let Some(force_run_chain) = force_run_chain.as_ref() { + force_run_chain.notify_one(); + } } Err(CodecReadError::RespondAndThenCloseConnection(messages)) => { if let Err(err) = out_tx.send(messages) { @@ -652,6 +656,7 @@ impl Handler { in_tx, out_rx, out_tx.clone(), + None, ); } else { let (rx, tx) = stream.into_split(); @@ -662,6 +667,7 @@ impl Handler { in_tx, out_rx, out_tx.clone(), + None, ); }; } diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 78af0788a..34085b895 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -1,29 +1,22 @@ +use crate::codec::{CodecBuilder, Direction}; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; +use crate::server::spawn_read_write_tasks; use crate::tcp; -use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; +use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ Transform, TransformBuilder, TransformConfig, TransformContextBuilder, Wrapper, }; -use crate::{ - codec::{ - redis::{RedisCodecBuilder, RedisDecoder, RedisEncoder}, - CodecBuilder, CodecReadError, Direction, - }, - transforms::TransformContextConfig, -}; +use crate::{codec::redis::RedisCodecBuilder, transforms::TransformContextConfig}; use anyhow::{anyhow, Result}; use async_trait::async_trait; -use futures::{FutureExt, SinkExt, StreamExt}; use metrics::{counter, Counter}; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::pin::Pin; +use std::sync::Arc; use std::time::Duration; -use tokio::io::{ReadHalf, WriteHalf}; -use tokio::sync::mpsc; -use tokio_util::codec::{FramedRead, FramedWrite}; -use tracing::Instrument; +use tokio::io::split; +use tokio::sync::{mpsc, Notify}; #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] @@ -80,14 +73,14 @@ impl RedisSinkSingleBuilder { } impl TransformBuilder for RedisSinkSingleBuilder { - fn build(&self, _transform_context: TransformContextBuilder) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(RedisSinkSingle { address: self.address.clone(), tls: self.tls.clone(), connection: None, failed_requests: self.failed_requests.clone(), - pushed_messages_tx: None, connect_timeout: self.connect_timeout, + force_run_chain: transform_context.force_run_chain, }) } @@ -100,21 +93,18 @@ impl TransformBuilder for RedisSinkSingleBuilder { } } -type PinStream = Pin>; - -struct Connection { - outbound_tx: FramedWrite, RedisEncoder>, - response_messages_rx: mpsc::UnboundedReceiver, - sent_message_type_tx: mpsc::UnboundedSender, -} - pub struct RedisSinkSingle { address: String, tls: Option, connection: Option, failed_requests: Counter, - pushed_messages_tx: Option>, connect_timeout: Duration, + force_run_chain: Arc, +} + +struct Connection { + in_rx: mpsc::Receiver>, + out_tx: mpsc::UnboundedSender>, } #[async_trait] @@ -123,247 +113,85 @@ impl Transform for RedisSinkSingle { NAME } - async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - // Return immediately if we have no messages. - // If we tried to send no messages we would block forever waiting for a reply that will never come. - if requests_wrapper.requests.is_empty() { - return Ok(requests_wrapper.requests); - } - + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { if self.connection.is_none() { - let generic_stream = if let Some(tls) = self.tls.as_mut() { - let tls_stream = tls - .connect(self.connect_timeout, self.address.clone()) - .await?; - Box::pin(tls_stream) as Pin> + let (in_tx, in_rx) = mpsc::channel::(10_000); + let (out_tx, out_rx) = mpsc::unbounded_channel::(); + let codec = RedisCodecBuilder::new(Direction::Sink, "RedisSinkSingle".to_owned()); + if let Some(tls) = self.tls.as_mut() { + let tls_stream = tls.connect(self.connect_timeout, &self.address).await?; + let (rx, tx) = split(tls_stream); + spawn_read_write_tasks( + codec, + rx, + tx, + in_tx, + out_rx, + out_tx.clone(), + Some(self.force_run_chain.clone()), + ); } else { - let tcp_stream = - tcp::tcp_stream(self.connect_timeout, self.address.clone()).await?; - Box::pin(tcp_stream) as Pin> - }; - - let (decoder, encoder) = - RedisCodecBuilder::new(Direction::Sink, "RedisSinkSingle".to_owned()).build(); - let (stream_rx, stream_tx) = tokio::io::split(generic_stream); - let outbound_tx = FramedWrite::new(stream_tx, encoder); - let outbound_rx = FramedRead::new(stream_rx, decoder); - let (response_messages_tx, response_messages_rx) = mpsc::unbounded_channel(); - let (sent_message_type_tx, sent_message_type_rx) = mpsc::unbounded_channel(); - - tokio::spawn( - server_response_processing_task( - outbound_rx, - self.pushed_messages_tx.clone(), - response_messages_tx, - sent_message_type_rx, - ) - .in_current_span(), - ); - self.connection = Some(Connection { - response_messages_rx, - sent_message_type_tx, - outbound_tx, - }) + let tcp_stream = tcp::tcp_stream(self.connect_timeout, &self.address).await?; + let (rx, tx) = tcp_stream.into_split(); + spawn_read_write_tasks( + codec, + rx, + tx, + in_tx, + out_rx, + out_tx.clone(), + Some(self.force_run_chain.clone()), + ); + } + self.connection = Some(Connection { in_rx, out_tx }); } - let connection = self.connection.as_mut().unwrap(); - - for message in &mut requests_wrapper.requests { - let ty = if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let Some(RedisFrame::BulkString(bytes)) = array.first() { - match bytes.to_ascii_uppercase().as_slice() { - b"SUBSCRIBE" | b"PSUBSCRIBE" | b"SSUBSCRIBE" => MessageType::Subscribe, - b"UNSUBSCRIBE" | b"PUNSUBSCRIBE" | b"SUNSUBSCRIBE" => { - MessageType::Unsubscribe - } - b"RESET" => MessageType::Reset, - _ => MessageType::Other, + if requests_wrapper.requests.is_empty() { + // there are no requests, so no point sending any, but we should check for any responses without awaiting + if let Ok(mut responses) = self.connection.as_mut().unwrap().in_rx.try_recv() { + for response in &mut responses { + if let Some(Frame::Redis(RedisFrame::Error(_))) = response.frame() { + self.failed_requests.increment(1); } - } else { - MessageType::Other } + Ok(responses) } else { - MessageType::Other - }; - connection - .sent_message_type_tx - .send(ty) - .map_err(|_| anyhow!("Failed to send message type because RedisSinkSingle response processing task is dead"))?; - } - - let messages_len = requests_wrapper.requests.len(); - connection - .outbound_tx - .send(requests_wrapper.requests) - .await - .map_err(|err| anyhow!("Failed to send messages to redis destination: {err:?}"))?; - - let mut result = Vec::with_capacity(messages_len); - while result.len() < messages_len { - let mut message = connection - .response_messages_rx - .recv() - .await - .ok_or_else(|| anyhow!("Failed to receive message because RedisSinkSingle response processing task is dead"))?; - if let Some(Frame::Redis(RedisFrame::Error(_))) = message.frame() { - self.failed_requests.increment(1); + Ok(vec![]) } - result.push(message); - } - Ok(result) - } - - fn set_pushed_messages_tx(&mut self, pushed_messages_tx: mpsc::UnboundedSender) { - self.pushed_messages_tx = Some(pushed_messages_tx); - } -} - -/// Processes responses coming in from the server. -/// Responses are then filtered into either the regular chain or pushed messages chain -/// depending on if they are a subscription or response message. -/// -/// A separate task is needed to process the incoming messages so that subscription messages can be sent immediately -/// without waiting for an incoming request to trigger the RedisSinkSingle transform again. -/// -/// The task will end silently if either the RedisSinkSingle transform is dropped or the server closes the connection. -async fn server_response_processing_task( - mut outbound_rx: FramedRead, RedisDecoder>, - subscribe_tx: Option>, - response_messages_tx: mpsc::UnboundedSender, - mut sent_message_type: mpsc::UnboundedReceiver, -) { - let mut is_subscribed = true; - loop { - tokio::select! { - responses = outbound_rx.next().fuse() => { - if process_server_response( - responses, - &subscribe_tx, - &response_messages_tx, - &mut is_subscribed, - &mut sent_message_type - ).await { - return; - } - }, - _ = response_messages_tx.closed() => { - tracing::debug!("RedisSinkSingle dropped, redis single subscription task shutting down"); - return; - }, - } - } -} - -/// returns true when the task should shutdown -async fn process_server_response( - responses: Option>, - subscribe_tx: &Option>, - response_messages_tx: &mpsc::UnboundedSender, - is_subscribed: &mut bool, - sent_message_type: &mut mpsc::UnboundedReceiver, -) -> bool { - match responses { - Some(Ok(messages)) => { - for mut message in messages { - // Notes on subscription responses - // - // There are 3 types of pubsub responses and the type is determined by the first value in the array: - // * `subscribe` - a response to a SUBSCRIBE, PSUBSCRIBE or SSUBSCRIBE request - // * `unsubscribe` - a response to an UNSUBSCRIBE, PUNSUBSCRIBE or SUNSUBSCRIBE request - // * `message` - a subscription message - // - // Additionally redis will: - // * accept a few regular commands while in pubsub mode: PING, RESET and QUIT - // * return an error response when a nonexistent or non pubsub compatible command is used - // - // Note: PING has a custom response when in pubsub mode. - // It returns an array ['pong', $pingMessage] instead of directly returning $pingMessage. - // But this doesnt cause any problems for us. - - // Determine if message is a `message` subscription message - // - // Because PING, RESET, QUIT and error responses never return a RedisFrame::Array starting with `message`, - // they have no way to collide with the `message` value of a subscription message. - // So while we are in subscription mode we can use that to determine if an - // incoming message is a subscription message. - let is_subscription_message = if *is_subscribed { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let [RedisFrame::BulkString(ty), ..] = array.as_slice() { - ty.as_ref() == b"message" - } else { - false - } - } else { - false - } - } else { - false - }; - - // Update is_subscribed state - // - // In order to make sense of a response we need the main task to - // send us the type of its corresponding request. - // - // In order to keep the incoming request MessageTypes in sync with their corresponding responses - // we must only process a MessageType when the message is not a subscription message. - // This is fine because subscription messages cannot affect the is_subscribed state. - if !is_subscription_message { - match sent_message_type.recv().await { - Some(MessageType::Subscribe) | Some(MessageType::Unsubscribe) => { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let Some(RedisFrame::Integer(number_of_subscribed_channels)) = - array.get(2) - { - *is_subscribed = *number_of_subscribed_channels != 0; - } - } - } - Some(MessageType::Other) => {} - Some(MessageType::Reset) => { - *is_subscribed = false; - } - None => { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; - } + } else { + let requests_count = requests_wrapper.requests.len(); + self.connection + .as_mut() + .unwrap() + .out_tx + .send(requests_wrapper.requests) + .map_err(|err| anyhow!("Failed to send messages to redis destination: {err:?}"))?; + + let mut result = vec![]; + let mut responses_count = 0; + while responses_count < requests_count { + let mut responses = self + .connection + .as_mut() + .unwrap() + .in_rx + .recv() + .await + .ok_or_else(|| { + anyhow!("Failed to receive message because recv task is dead") + })?; + + for response in &mut responses { + if let Some(Frame::Redis(RedisFrame::Error(_))) = response.frame() { + self.failed_requests.increment(1); } - } - - // Route the message down the correct path: - // * `message` subscription messages: - // needs to be routed down the pushed_messages chain - // * everything else: - // needs to be routed down the regular chain - if is_subscription_message { - // subscribe_tx may not exist if we are e.g. in an alternate chain of a tee transform - if let Some(subscribe_tx) = subscribe_tx { - if let Err(mpsc::error::SendError(_)) = subscribe_tx.send(vec![message]) { - tracing::debug!("shotover chain is terminated, will continue running until Transform is dropped"); - } + if response.request_id().is_some() { + responses_count += 1; } - } else if let Err(mpsc::error::SendError(_)) = response_messages_tx.send(message) { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; } + result.extend(responses); } - false - } - Some(Err(err)) => { - tracing::error!("encountered error in redis stream: {err:?}"); - true - } - None => { - tracing::debug!("sink stream ended, redis single subscription task shutting down"); - true + Ok(result) } } } - -#[derive(Debug)] -enum MessageType { - Other, - Subscribe, - Unsubscribe, - Reset, -}