Skip to content

Commit

Permalink
Close connection on transform error (#707)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Jul 26, 2022
1 parent fb5d5a9 commit a42da2c
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 43 deletions.
134 changes: 95 additions & 39 deletions shotover-proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::message::Messages;
use crate::tls::TlsAcceptor;
use crate::transforms::chain::TransformChain;
use crate::transforms::Wrapper;
use anyhow::{anyhow, Result};
use futures::StreamExt;
use anyhow::{anyhow, Context, Result};
use futures::{SinkExt, StreamExt};
use metrics::{register_gauge, Gauge};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
Expand All @@ -13,7 +13,6 @@ use tokio::sync::{mpsc, watch, Semaphore};
use tokio::time;
use tokio::time::timeout;
use tokio::time::Duration;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{Decoder, Encoder};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::Instrument;
Expand Down Expand Up @@ -231,8 +230,8 @@ impl<C: Codec + 'static> TcpCodecListener<C> {
// Receive shutdown notifications.
shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()),

terminate_tasks: None,
tls: self.tls.clone(),

timeout: self.timeout,
};

Expand Down Expand Up @@ -356,6 +355,7 @@ pub struct Handler<C: Codec> {
/// which point the connection is terminated.
shutdown: Shutdown,

terminate_tasks: Option<watch::Sender<()>>,
tls: Option<TlsAcceptor>,

/// Timeout in seconds after which to kill an idle connection. No timeout means connections will never be timed out.
Expand All @@ -371,28 +371,54 @@ fn spawn_read_write_tasks<
rx: R,
tx: W,
in_tx: UnboundedSender<Messages>,
out_rx: UnboundedReceiver<Messages>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
mut terminate_tasks_rx: watch::Receiver<()>,
) {
let mut reader = FramedRead::new(rx, codec.clone());
let writer = FramedWrite::new(tx, codec);

let mut writer = FramedWrite::new(tx, codec);

// Shutdown flows
//
// main task shuts down due to transform error:
// 1. The main task terminates, sending terminate_tasks_tx and dropping the first out_tx
// 2. The reader task detects change on terminate_tasks_rx and terminates, the last out_tx instance is dropped
// 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates
//
// client closes connection:
// 1. The reader task detects that the client has closed the connection via reader returning None and terminates, dropping in_tx and the first out_tx
// 2. The main task detects that in_tx is dropped by in_rx returning None and terminates, dropping the last out_tx
// 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates

// reader task
tokio::spawn(
async move {
while let Some(message) = reader.next().await {
match message {
Ok(message) => {
let remaining_messages =
process_return_to_sender_messages(message, &out_tx);
if !remaining_messages.is_empty() {
if let Err(error) = in_tx.send(remaining_messages) {
warn!("failed to pass on received message: {}", error);
return;
loop {
tokio::select! {
result = reader.next() => {
if let Some(message) = result {
match message {
Ok(message) => {
let remaining_messages =
process_return_to_sender_messages(message, &out_tx);
if !remaining_messages.is_empty() {
if let Err(error) = in_tx.send(remaining_messages) {
warn!("failed to pass on received message: {}", error);
return;
}
}
}
Err(error) => {
warn!("failed to receive or decode message: {:?}", error);
return;
}
}
} else {
debug!("client has closed the connection");
return;
}
}
Err(error) => {
warn!("failed to receive or decode message: {:?}", error);
_ = terminate_tasks_rx.changed() => {
return;
}
}
Expand All @@ -401,11 +427,27 @@ fn spawn_read_write_tasks<
.in_current_span(),
);

// sender task
tokio::spawn(
async move {
let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok);
if let Err(err) = rx_stream.forward(writer).await {
error!("failed to send or encode message: {:?}", err);
loop {
if let Some(message) = out_rx.recv().await {
if let Err(err) = writer.send(message).await {
error!("failed to send or encode message: {:?}", err);
}
} else {
// Main task has ended.
// First flush out any remaining messages.
// Then end the task thus closing the connection by dropping the write half
while let Ok(message) = out_rx.try_recv() {
if let Err(err) = writer.send(message).await {
error!(
"while flushing messages: failed to send or encode message: {err:?}",
);
}
}
break;
}
}
}
.in_current_span(),
Expand All @@ -430,16 +472,35 @@ impl<C: Codec + 'static> Handler<C> {
// new request frame.
let mut idle_time_seconds: u64 = 1;

let (terminate_tx, terminate_rx) = watch::channel::<()>(());
self.terminate_tasks = Some(terminate_tx);

let (in_tx, mut in_rx) = mpsc::unbounded_channel::<Messages>();
let (out_tx, out_rx) = mpsc::unbounded_channel::<Messages>();

if let Some(tls) = &self.tls {
let tls_stream = tls.accept(stream).await?;
let (rx, tx) = tokio::io::split(tls_stream);
spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone());
spawn_read_write_tasks(
self.codec.clone(),
rx,
tx,
in_tx,
out_rx,
out_tx.clone(),
terminate_rx,
);
} else {
let (rx, tx) = stream.into_split();
spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone());
spawn_read_write_tasks(
self.codec.clone(),
rx,
tx,
in_tx,
out_rx,
out_tx.clone(),
terminate_rx,
);
};

while !self.shutdown.is_shutdown() {
Expand All @@ -448,7 +509,7 @@ impl<C: Codec + 'static> Handler<C> {
let mut reverse_chain = false;

let messages = tokio::select! {
res = timeout(Duration::from_secs(idle_time_seconds) , in_rx.recv()) => {
res = timeout(Duration::from_secs(idle_time_seconds), in_rx.recv()) => {
match res {
Ok(maybe_message) => {
idle_time_seconds = 1;
Expand Down Expand Up @@ -491,27 +552,18 @@ impl<C: Codec + 'static> Handler<C> {
self.chain.name.clone(),
);

let chain_result = if reverse_chain {
let modified_messages = if reverse_chain {
self.chain.process_request_rev(wrapper).await
} else {
self.chain
.process_request(wrapper, self.client_details.clone())
.await
};

match chain_result {
Ok(modified_messages) => {
debug!("sending message: {:?}", modified_messages);
// send the result of the process up stream
out_tx.send(modified_messages)?;
}
Err(e) => {
error!(
"{:?}",
e.context("chain failed to send and/or receive messages")
);
}
}
.context("chain failed to send and/or receive messages")?;

debug!("sending message: {:?}", modified_messages);
// send the result of the process up stream
out_tx.send(modified_messages)?;
}
Ok(())
}
Expand All @@ -531,6 +583,10 @@ impl<C: Codec> Drop for Handler<C> {
// semaphore.

self.limit_connections.add_permits(1);

if let Some(terminate_tasks) = &self.terminate_tasks {
terminate_tasks.send(()).ok();
}
}
}
/// Listens for the server shutdown signal.
Expand Down
5 changes: 1 addition & 4 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,7 @@ impl Transform for RedisSinkSingle {

// self.outbound is guaranteed to be Some by the previous block
let outbound_framed_codec = self.outbound.as_mut().unwrap();
outbound_framed_codec
.send(message_wrapper.messages)
.await
.ok();
outbound_framed_codec.send(message_wrapper.messages).await?;

match outbound_framed_codec.next().fuse().await {
Some(mut a) => {
Expand Down
63 changes: 63 additions & 0 deletions shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use redis::{AsyncCommands, Commands, ErrorKind, RedisError, Value};
use serial_test::serial;
use shotover_proxy::tls::TlsConfig;
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::path::Path;
use std::thread::sleep;
use std::time::Duration;
Expand Down Expand Up @@ -1212,6 +1213,55 @@ async fn test_dr_auth(shotover_manager: &ShotoverManager) {
);
}

/// Driver-level variant of the transform failure check.
///
/// Issues `SET foo 42` through the supplied async connection and asserts that
/// the query fails with an error whose display text is exactly
/// "unexpected end of file". Having a driver variant ensures that at least one
/// real driver reacts to this situation the way we expect.
async fn test_trigger_transform_failure_driver(connection: &mut Connection) {
    // The query must fail; `unwrap_err` panics if it unexpectedly succeeds.
    let failure = redis::cmd("SET")
        .arg("foo")
        .arg(42)
        .query_async::<_, ()>(connection)
        .await
        .unwrap_err();

    assert_eq!(failure.to_string(), "unexpected end of file");
}

/// Raw-socket variant of the transform failure check, provided so that we can
/// make a strong assertion about the way shotover handles this case.
///
/// Writes a single RESP array holding the bulk string `ping` to
/// `127.0.0.1:6379` and asserts that the next read returns 0 bytes, i.e. that
/// the peer closed the connection.
///
/// CAREFUL: This lacks any kind of check that shotover is ready,
/// so make sure shotover_manager.redis_connection is run on 6379 before calling this.
fn test_trigger_transform_failure_raw() {
    let mut stream = std::net::TcpStream::connect("127.0.0.1:6379").unwrap();

    // NOTE(review): this was previously described as an "invalid redis command",
    // but the payload looks like a well-formed PING; presumably the failure is
    // triggered by the transform chain itself (e.g. upstream unreachable) — confirm.
    // To correctly handle the failure shotover should close the connection.
    stream.write_all(b"*1\r\n$4\r\nping\r\n").unwrap();

    let read_deadline = Duration::from_secs(10);
    stream.set_read_timeout(Some(read_deadline)).unwrap();

    // If the connection was closed by shotover then we will successfully read 0 bytes.
    // If it was not closed then the read blocks until the 10 second timeout is hit,
    // returns an error, and the unwrap panics.
    let mut probe = [0u8; 1];
    let bytes_read = stream.read(&mut probe).unwrap();
    assert_eq!(bytes_read, 0);
}

/// Sends bytes that are not a valid redis frame over a raw TCP connection to
/// `127.0.0.1:6379` and asserts that the next read returns 0 bytes, i.e. that
/// the peer closed the connection.
///
/// CAREFUL: This lacks any kind of check that shotover is ready,
/// so make sure shotover_manager.redis_connection is run on 6379 before calling this.
fn test_invalid_frame() {
    let mut stream = std::net::TcpStream::connect("127.0.0.1:6379").unwrap();

    // Send an invalid redis frame.
    // To correctly handle this shotover should close the connection.
    stream.write_all(b"invalid_redis_frame\r\n").unwrap();

    let read_deadline = Duration::from_secs(10);
    stream.set_read_timeout(Some(read_deadline)).unwrap();

    // If the connection was closed by shotover then we will successfully read 0 bytes.
    // If it was not closed then the read blocks until the 10 second timeout is hit,
    // returns an error, and the unwrap panics.
    let mut probe = [0u8; 1];
    let bytes_read = stream.read(&mut probe).unwrap();
    assert_eq!(bytes_read, 0);
}

#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn test_passthrough() {
Expand All @@ -1223,6 +1273,19 @@ async fn test_passthrough() {
Flusher::new_single_connection(shotover_manager.redis_connection_async(6379).await).await;

run_all(&mut connection, &mut flusher).await;
test_invalid_frame();
}

/// Runs the connection-closing checks against the redis-passthrough topology.
///
/// NOTE(review): the test name suggests no redis instance is running behind
/// shotover here, which would be what makes the transform fail — confirm.
#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn test_passthrough_redis_down() {
    let topology = "example-configs/redis-passthrough/topology.yaml";
    let manager = ShotoverManager::from_topology_file(topology);

    // The driver connection is opened first: the raw helpers below have no
    // readiness check of their own and rely on redis_connection having been
    // run on 6379 before they are called.
    let mut driver_connection = manager.redis_connection_async(6379).await;

    test_trigger_transform_failure_driver(&mut driver_connection).await;
    test_trigger_transform_failure_raw();
    test_invalid_frame();
}

#[tokio::test(flavor = "multi_thread")]
Expand Down

0 comments on commit a42da2c

Please sign in to comment.