Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close connection on transform error #707

Merged
merged 4 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 95 additions & 39 deletions shotover-proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::message::Messages;
use crate::tls::TlsAcceptor;
use crate::transforms::chain::TransformChain;
use crate::transforms::Wrapper;
use anyhow::{anyhow, Result};
use futures::StreamExt;
use anyhow::{anyhow, Context, Result};
use futures::{SinkExt, StreamExt};
use metrics::{register_gauge, Gauge};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
Expand All @@ -13,7 +13,6 @@ use tokio::sync::{mpsc, watch, Semaphore};
use tokio::time;
use tokio::time::timeout;
use tokio::time::Duration;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{Decoder, Encoder};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::Instrument;
Expand Down Expand Up @@ -231,8 +230,8 @@ impl<C: Codec + 'static> TcpCodecListener<C> {
// Receive shutdown notifications.
shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()),

terminate_tasks: None,
tls: self.tls.clone(),

timeout: self.timeout,
};

Expand Down Expand Up @@ -356,6 +355,7 @@ pub struct Handler<C: Codec> {
/// which point the connection is terminated.
shutdown: Shutdown,

terminate_tasks: Option<watch::Sender<()>>,
tls: Option<TlsAcceptor>,

/// Timeout in seconds after which to kill an idle connection. No timeout means connections will never be timed out.
Expand All @@ -371,28 +371,54 @@ fn spawn_read_write_tasks<
rx: R,
tx: W,
in_tx: UnboundedSender<Messages>,
out_rx: UnboundedReceiver<Messages>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
mut terminate_tasks_rx: watch::Receiver<()>,
) {
let mut reader = FramedRead::new(rx, codec.clone());
let writer = FramedWrite::new(tx, codec);

let mut writer = FramedWrite::new(tx, codec);

// Shutdown flows
//
// main task shuts down due to transform error:
// 1. The main task terminates, sending terminate_tasks_tx and dropping the first out_tx
// 2. The reader task detects change on terminate_tasks_rx and terminates, the last out_tx instance is dropped
// 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates
//
// client closes connection:
// 1. The reader task detects that the client has closed the connection via reader returning None and terminates, dropping in_tx and the first out_tx
// 2. The main task detects that in_tx is dropped by in_rx returning None and terminates, dropping the last out_tx
// 3. The writer task detects that the last out_tx is dropped by out_rx returning None and terminates

// reader task
tokio::spawn(
async move {
while let Some(message) = reader.next().await {
match message {
Ok(message) => {
let remaining_messages =
process_return_to_sender_messages(message, &out_tx);
if !remaining_messages.is_empty() {
if let Err(error) = in_tx.send(remaining_messages) {
warn!("failed to pass on received message: {}", error);
return;
loop {
tokio::select! {
result = reader.next() => {
if let Some(message) = result {
match message {
Ok(message) => {
let remaining_messages =
process_return_to_sender_messages(message, &out_tx);
if !remaining_messages.is_empty() {
if let Err(error) = in_tx.send(remaining_messages) {
warn!("failed to pass on received message: {}", error);
return;
}
}
}
Err(error) => {
warn!("failed to receive or decode message: {:?}", error);
return;
}
}
} else {
debug!("client has closed the connection");
return;
}
}
Err(error) => {
warn!("failed to receive or decode message: {:?}", error);
_ = terminate_tasks_rx.changed() => {
return;
}
}
Expand All @@ -401,11 +427,27 @@ fn spawn_read_write_tasks<
.in_current_span(),
);

// sender task
tokio::spawn(
async move {
let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok);
if let Err(err) = rx_stream.forward(writer).await {
error!("failed to send or encode message: {:?}", err);
loop {
if let Some(message) = out_rx.recv().await {
if let Err(err) = writer.send(message).await {
error!("failed to send or encode message: {:?}", err);
}
} else {
// Main task has ended.
// First flush out any remaining messages.
// Then end the task thus closing the connection by dropping the write half
while let Ok(message) = out_rx.try_recv() {
if let Err(err) = writer.send(message).await {
error!(
"while flushing messages: failed to send or encode message: {err:?}",
);
}
}
break;
}
}
}
.in_current_span(),
Expand All @@ -430,16 +472,35 @@ impl<C: Codec + 'static> Handler<C> {
// new request frame.
let mut idle_time_seconds: u64 = 1;

let (terminate_tx, terminate_rx) = watch::channel::<()>(());
self.terminate_tasks = Some(terminate_tx);

let (in_tx, mut in_rx) = mpsc::unbounded_channel::<Messages>();
let (out_tx, out_rx) = mpsc::unbounded_channel::<Messages>();

if let Some(tls) = &self.tls {
let tls_stream = tls.accept(stream).await?;
let (rx, tx) = tokio::io::split(tls_stream);
spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone());
spawn_read_write_tasks(
self.codec.clone(),
rx,
tx,
in_tx,
out_rx,
out_tx.clone(),
terminate_rx,
);
} else {
let (rx, tx) = stream.into_split();
spawn_read_write_tasks(self.codec.clone(), rx, tx, in_tx, out_rx, out_tx.clone());
spawn_read_write_tasks(
self.codec.clone(),
rx,
tx,
in_tx,
out_rx,
out_tx.clone(),
terminate_rx,
);
};

while !self.shutdown.is_shutdown() {
Expand All @@ -448,7 +509,7 @@ impl<C: Codec + 'static> Handler<C> {
let mut reverse_chain = false;

let messages = tokio::select! {
res = timeout(Duration::from_secs(idle_time_seconds) , in_rx.recv()) => {
res = timeout(Duration::from_secs(idle_time_seconds), in_rx.recv()) => {
match res {
Ok(maybe_message) => {
idle_time_seconds = 1;
Expand Down Expand Up @@ -491,27 +552,18 @@ impl<C: Codec + 'static> Handler<C> {
self.chain.name.clone(),
);

let chain_result = if reverse_chain {
let modified_messages = if reverse_chain {
self.chain.process_request_rev(wrapper).await
} else {
self.chain
.process_request(wrapper, self.client_details.clone())
.await
};

match chain_result {
Ok(modified_messages) => {
debug!("sending message: {:?}", modified_messages);
// send the result of the process up stream
out_tx.send(modified_messages)?;
}
Err(e) => {
error!(
"{:?}",
e.context("chain failed to send and/or receive messages")
);
}
}
.context("chain failed to send and/or receive messages")?;

debug!("sending message: {:?}", modified_messages);
// send the result of the process up stream
out_tx.send(modified_messages)?;
}
Ok(())
}
Expand All @@ -531,6 +583,10 @@ impl<C: Codec> Drop for Handler<C> {
// semaphore.

self.limit_connections.add_permits(1);

if let Some(terminate_tasks) = &self.terminate_tasks {
terminate_tasks.send(()).ok();
}
}
}
/// Listens for the server shutdown signal.
Expand Down
5 changes: 1 addition & 4 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,7 @@ impl Transform for RedisSinkSingle {

// self.outbound is gauranteed to be Some by the previous block
let outbound_framed_codec = self.outbound.as_mut().unwrap();
outbound_framed_codec
.send(message_wrapper.messages)
.await
.ok();
outbound_framed_codec.send(message_wrapper.messages).await?;

match outbound_framed_codec.next().fuse().await {
Some(mut a) => {
Expand Down
63 changes: 63 additions & 0 deletions shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use redis::{AsyncCommands, Commands, ErrorKind, RedisError, Value};
use serial_test::serial;
use shotover_proxy::tls::TlsConfig;
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::path::Path;
use std::thread::sleep;
use std::time::Duration;
Expand Down Expand Up @@ -1212,6 +1213,55 @@ async fn test_dr_auth(shotover_manager: &ShotoverManager) {
);
}

/// A driver variant of this test case is provided so that we can ensure that
/// at least one driver handles this as we expect.
async fn test_trigger_transform_failure_driver(connection: &mut Connection) {
assert_eq!(
redis::cmd("SET")
.arg("foo")
.arg(42)
.query_async::<_, ()>(connection)
.await
.unwrap_err()
.to_string(),
"unexpected end of file".to_string()
);
}

/// A raw variant of this test case is provided so that we can make a strong assertion about the way shotover handles this case.
///
/// CAREFUL: This lacks any kind of check that shotover is ready,
/// so make sure shotover_manager.redis_connection is run on 6379 before calling this.
fn test_trigger_transform_failure_raw() {
// Send invalid redis command
// To correctly handle this shotover should close the connection
let mut connection = std::net::TcpStream::connect("127.0.0.1:6379").unwrap();
connection.write_all(b"*1\r\n$4\r\nping\r\n").unwrap();
connection
.set_read_timeout(Some(Duration::from_secs(10)))
.unwrap();
// If the connection was closed by shotover then we will succesfully read 0 bytes.
// If the connection was not closed by shotover then read will block for 10 seconds until the time is hit and then the unwrap will panic.
let amount = connection.read(&mut [0; 1]).unwrap();
assert_eq!(amount, 0);
}

/// CAREFUL: This lacks any kind of check that shotover is ready,
/// so make sure shotover_manager.redis_connection is run on 6379 before calling this.
fn test_invalid_frame() {
// Send invalid redis command
// To correctly handle this shotover should close the connection
let mut connection = std::net::TcpStream::connect("127.0.0.1:6379").unwrap();
connection.write_all(b"invalid_redis_frame\r\n").unwrap();
connection
.set_read_timeout(Some(Duration::from_secs(10)))
.unwrap();
// If the connection was closed by shotover then we will succesfully read 0 bytes.
// If the connection was not closed by shotover then read will block for 10 seconds until the time is hit and then the unwrap will panic.
let amount = connection.read(&mut [0; 1]).unwrap();
assert_eq!(amount, 0);
}

#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn test_passthrough() {
Expand All @@ -1223,6 +1273,19 @@ async fn test_passthrough() {
Flusher::new_single_connection(shotover_manager.redis_connection_async(6379).await).await;

run_all(&mut connection, &mut flusher).await;
test_invalid_frame();
}

#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn test_passthrough_redis_down() {
let shotover_manager =
ShotoverManager::from_topology_file("example-configs/redis-passthrough/topology.yaml");
let mut connection = shotover_manager.redis_connection_async(6379).await;

test_trigger_transform_failure_driver(&mut connection).await;
test_trigger_transform_failure_raw();
test_invalid_frame();
}

#[tokio::test(flavor = "multi_thread")]
Expand Down