Skip to content

Commit

Permalink
Cassandra: Do not log error on TCP RST (#850)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Oct 13, 2022
1 parent c7c9f49 commit 6d4c9b8
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 34 deletions.
130 changes: 96 additions & 34 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Message, Metadata};
use crate::server::CodecReadError;
use crate::tls::TlsConnector;
use crate::transforms::util::Response;
use crate::transforms::Messages;
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Opcode;
use derivative::Derivative;
use futures::stream::FuturesOrdered;
use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use halfbrown::HashMap;
use std::time::Duration;
use tokio::io::{split, AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{error, Instrument};

Expand Down Expand Up @@ -55,19 +55,52 @@ impl CassandraConnection {

let (out_tx, out_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<Request>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>();

if let Some(tls) = tls.as_mut() {
let tls_stream = tls.connect(tcp_stream).await?;
let (read, write) = split(tls_stream);
tokio::spawn(tx_process(write, out_rx, return_tx, codec.clone()).in_current_span());
tokio::spawn(
rx_process(read, return_rx, codec.clone(), pushed_messages_tx).in_current_span(),
tx_process(
write,
out_rx,
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
)
.in_current_span(),
);
tokio::spawn(
rx_process(
read,
return_rx,
codec.clone(),
pushed_messages_tx,
rx_process_has_shutdown_tx,
)
.in_current_span(),
);
} else {
let (read, write) = split(tcp_stream);
tokio::spawn(tx_process(write, out_rx, return_tx, codec.clone()).in_current_span());
tokio::spawn(
rx_process(read, return_rx, codec.clone(), pushed_messages_tx).in_current_span(),
tx_process(
write,
out_rx,
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
)
.in_current_span(),
);
tokio::spawn(
rx_process(
read,
return_rx,
codec.clone(),
pushed_messages_tx,
rx_process_has_shutdown_tx,
)
.in_current_span(),
);
};

Expand Down Expand Up @@ -96,37 +129,61 @@ async fn tx_process<T: AsyncWrite>(
out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) {
if let Err(err) = tx_process_fallible(write, out_rx, return_tx, codec).await {
if let Err(err) =
tx_process_fallible(write, out_rx, return_tx, codec, rx_process_has_shutdown_rx).await
{
error!("{:?}", err.context("tx_process task terminated"));
}
}

async fn tx_process_fallible<T: AsyncWrite>(
write: WriteHalf<T>,
out_rx: mpsc::UnboundedReceiver<Request>,
mut out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let in_w = FramedWrite::new(write, codec);
let rx_stream = UnboundedReceiverStream::new(out_rx).map(|x| {
let ret = Ok(vec![x.message.clone()]);
return_tx.send(x)?;
ret
});
rx_stream.forward(in_w).await?;
Ok(())
let mut in_w = FramedWrite::new(write, codec);
loop {
if let Some(request) = out_rx.recv().await {
in_w.send(vec![request.message.clone()]).await?;
return_tx.send(request)?;
} else {
// The transform is shutting down, so it is time to cleanly shut down both tx_process and rx_process.
// We need to ensure that the rx_process task has shut down before closing the write half of the TCP stream.
// If we don't do this, rx_process may attempt to read from the TCP stream after the write half has closed.
// Closing the write half will send a TCP FIN ACK to the server.
// The server may then respond with a TCP RST, after which any reads from the read half would return a ConnectionReset error.

// First we drop return_tx, which will instruct rx_process to shut down.
std::mem::drop(return_tx);

// Wait for rx_process to shut down.
rx_process_has_shutdown_rx.await.ok();

// Now that rx_process is shutdown we can safely drop the write half of the
// tcp stream without the read half hitting errors due to the connection being closed or reset.
std::mem::drop(in_w);
return Ok(());
}
}
}

/// Runs the receive side of the connection and then signals that it has finished.
///
/// Awaits `rx_process_fallible` with the given read half, request channel, codec and
/// optional pushed-messages sender. If that returns an error, the error is logged with
/// the context "rx_process task terminated" rather than propagated.
///
/// In every case `rx_process_has_shutdown_tx` is dropped last; nothing is ever sent on
/// it, the drop itself is the shutdown notification for whoever holds the receiver.
async fn rx_process<T: AsyncRead>(
    read: ReadHalf<T>,
    return_rx: mpsc::UnboundedReceiver<Request>,
    codec: CassandraCodec,
    pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
    rx_process_has_shutdown_tx: oneshot::Sender<()>,
) {
    let outcome = rx_process_fallible(read, return_rx, codec, pushed_messages_tx).await;
    if let Err(err) = outcome {
        error!("{:?}", err.context("rx_process task terminated"));
    }

    // Dropping the sender (without sending) is what tells the other side we are done.
    drop(rx_process_has_shutdown_tx);
}

async fn rx_process_fallible<T: AsyncRead>(
Expand All @@ -142,12 +199,12 @@ async fn rx_process_fallible<T: AsyncRead>(

loop {
tokio::select! {
Some(response) = reader.next() => {
response = reader.next() => {
match response {
Ok(response) => {
Some(Ok(response)) => {
for m in response {
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() {
if let Some(ref pushed_messages_tx) = pushed_messages_tx {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![m]).unwrap();
}
} else if let Some(stream_id) = m.stream_id() {
Expand All @@ -163,26 +220,31 @@ async fn rx_process_fallible<T: AsyncRead>(
}
}
}
Err(e) => {
return Err(anyhow!("{:?}", e).context("Encountered error while communicating with destination cassandra node"));
Some(Err(CodecReadError::Io(err))) => {
return Err(anyhow!(err)
.context("Encountered IO error while communicating with destination cassandra node"))
}
Some(Err(err)) => {
return Err(anyhow!("{:?}", err).context("Encountered error while communicating with destination cassandra node"));
}
None => return Ok(())
}
},
Some(original_request) = return_rx.recv() => {
let Request { message, return_chan, message_id } = original_request;
match return_message_map.remove(&message_id) {
None => {
return_channel_map.insert(message_id, (return_chan, message));
}
Some(m) => {
return_chan.send(Response { original: message, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
original_request = return_rx.recv() => {
if let Some(Request { message, return_chan, message_id }) = original_request {
match return_message_map.remove(&message_id) {
None => {
return_channel_map.insert(message_id, (return_chan, message));
}
Some(m) => {
return_chan.send(Response { original: message, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
}
}
};
} else {
return Ok(())
}
},
else => {
return Ok(())
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ async fn rx_process<C: CodecReadHalf, R: AsyncRead + Unpin + Send + 'static>(
) -> Result<()> {
let mut reader = FramedRead::new(read, codec);

// TODO: This reader.next() may perform reads after tx_process has shut down the write half.
// This may result in unexpected ConnectionReset errors.
// Refer to the Cassandra connection logic.
while let Some(responses) = reader.next().await {
match responses {
Ok(responses) => {
Expand Down

0 comments on commit 6d4c9b8

Please sign in to comment.