Skip to content

Commit

Permalink
Fix connections with no messages sent during node outage (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Dec 8, 2022
1 parent f4d014b commit 8e24912
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 189 deletions.
96 changes: 50 additions & 46 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use derivative::Derivative;
use futures::stream::FuturesOrdered;
use futures::{SinkExt, StreamExt};
use halfbrown::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{split, AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::net::ToSocketAddrs;
Expand All @@ -27,10 +28,25 @@ struct Request {
stream_id: i16,
}

#[derive(Debug)]
pub struct Response {
pub type Response = Result<Message, ResponseError>;

#[derive(Debug, thiserror::Error)]
#[error("Connection to destination cassandra node {destination} was closed: {cause}")]
pub struct ResponseError {
#[source]
pub cause: anyhow::Error,
pub destination: SocketAddr,
pub stream_id: i16,
pub response: Result<Message>,
}

impl ResponseError {
    /// Builds the error message that is returned to the client in place of the
    /// response that could not be obtained.
    ///
    /// The message is a shotover error frame carrying this error's `stream_id`
    /// and the given protocol `version`. Its text is `cause` rendered with the
    /// alternate (`{:#}`) format; `destination` is not part of the text.
    pub fn to_response(&self, version: Version) -> Message {
        // `{:#}` is the alternate Display form of the underlying error.
        let description = format!("{:#}", self.cause);
        let frame = CassandraFrame::shotover_error(self.stream_id, version, &description);
        Message::from_frame(Frame::Cassandra(frame))
    }
}

#[derive(Debug)]
Expand Down Expand Up @@ -59,7 +75,7 @@ impl CassandraConnection {
let (return_tx, return_rx) = mpsc::unbounded_channel::<ReturnChannel>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<String>();

let destination = format!("{host:?}");
let destination = tokio::net::lookup_host(&host).await?.next().unwrap();

if let Some(tls) = tls.as_mut() {
let tls_stream = tls.connect(connect_timeout, host).await?;
Expand All @@ -71,7 +87,7 @@ impl CassandraConnection {
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
destination.clone(),
destination,
)
.in_current_span(),
);
Expand All @@ -96,7 +112,7 @@ impl CassandraConnection {
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
destination.clone(),
destination,
)
.in_current_span(),
);
Expand Down Expand Up @@ -156,7 +172,7 @@ async fn tx_process<T: AsyncWrite>(
codec: CassandraCodec,
mut rx_process_has_shutdown_rx: oneshot::Receiver<String>,
// Only used for error reporting
destination: String,
destination: SocketAddr,
) {
let mut in_w = FramedWrite::new(write, codec);

Expand All @@ -167,10 +183,10 @@ async fn tx_process<T: AsyncWrite>(
loop {
if let Some(request) = out_rx.recv().await {
if let Some(error) = &connection_dead_error {
send_error_to_request(request.return_chan, request.stream_id, &destination, error);
send_error_to_request(request.return_chan, request.stream_id, destination, error);
} else if let Err(error) = in_w.send(vec![request.message]).await {
let error = format!("{:?}", error);
send_error_to_request(request.return_chan, request.stream_id, &destination, &error);
send_error_to_request(request.return_chan, request.stream_id, destination, &error);
connection_dead_error = Some(error.clone());
} else if let Err(mpsc::error::SendError(return_chan)) = return_tx.send(ReturnChannel {
return_chan: request.return_chan,
Expand All @@ -182,7 +198,7 @@ async fn tx_process<T: AsyncWrite>(
send_error_to_request(
return_chan.return_chan,
return_chan.stream_id,
&destination,
destination,
&error,
);
connection_dead_error = Some(error.clone());
Expand Down Expand Up @@ -214,16 +230,15 @@ async fn tx_process<T: AsyncWrite>(
fn send_error_to_request(
return_chan: oneshot::Sender<Response>,
stream_id: i16,
destination: &str,
destination: SocketAddr,
error: &str,
) {
return_chan
.send(Response {
.send(Err(ResponseError {
cause: anyhow!(error.to_owned()),
destination,
stream_id,
response: Err(anyhow!(
"Connection to destination cassandra node {destination} was closed: {error}"
)),
})
}))
.ok();
}

Expand All @@ -234,7 +249,7 @@ async fn rx_process<T: AsyncRead>(
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
// Only used for error reporting
destination: String,
destination: SocketAddr,
) {
let mut reader = FramedRead::new(read, codec);

Expand Down Expand Up @@ -274,7 +289,7 @@ async fn rx_process<T: AsyncRead>(
from_server.insert(stream_id, m);
},
Some(return_tx) => {
return_tx.send(Response { stream_id, response: Ok(m) }).ok();
return_tx.send(Ok(m)).ok();
}
}
}
Expand Down Expand Up @@ -306,7 +321,7 @@ async fn rx_process<T: AsyncRead>(
from_tx_process.insert(stream_id, return_chan);
}
Some(m) => {
return_chan.send(Response { stream_id, response: Ok(m) }).ok();
return_chan.send(Ok(m)).ok();
}
}
} else {
Expand All @@ -325,7 +340,7 @@ async fn send_errors_and_shutdown(
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
mut waiting: HashMap<i16, oneshot::Sender<Response>>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
destination: String,
destination: SocketAddr,
message: &str,
) {
// Ensure we send this before closing return_rx.
Expand All @@ -337,26 +352,25 @@ async fn send_errors_and_shutdown(

return_rx.close();

let full_message =
format!("Connection to destination cassandra node {destination} was closed: {message}");

for (stream_id, return_tx) in waiting.drain() {
return_tx
.send(Response {
.send(Err(ResponseError {
cause: anyhow!(message.to_owned()),
destination,
stream_id,
response: Err(anyhow!(full_message.to_owned())),
})
}))
.ok();
}

// return_rx is already closed, so by looping over all remaining values we ensure there are no dropped unused return_chans
while let Some(return_chan) = return_rx.recv().await {
return_chan
.return_chan
.send(Response {
.send(Err(ResponseError {
cause: anyhow!(message.to_owned()),
destination,
stream_id: return_chan.stream_id,
response: Err(anyhow!(full_message.to_owned())),
})
}))
.ok();
}
}
Expand All @@ -365,20 +379,19 @@ pub async fn receive(
timeout_duration: Option<Duration>,
failed_requests: &metrics::Counter,
mut results: FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Messages> {
) -> Result<Vec<Result<Message, ResponseError>>> {
let expected_size = results.len();
let mut responses = Vec::with_capacity(expected_size);
while responses.len() < expected_size {
if let Some(timeout_duration) = timeout_duration {
match timeout(
timeout_duration,
receive_message(failed_requests, &mut results, version),
receive_message(failed_requests, &mut results),
)
.await
{
Ok(response) => {
responses.push(response?);
responses.push(response);
}
Err(_) => {
return Err(anyhow!(
Expand All @@ -389,7 +402,7 @@ pub async fn receive(
}
}
} else {
responses.push(receive_message(failed_requests, &mut results, version).await?);
responses.push(receive_message(failed_requests, &mut results).await);
}
}
Ok(responses)
Expand All @@ -398,14 +411,10 @@ pub async fn receive(
async fn receive_message(
failed_requests: &metrics::Counter,
results: &mut FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Message> {
) -> Result<Message, ResponseError> {
match results.next().await {
Some(result) => match result.expect("The tx_process task must always return a value") {
Response {
response: Ok(message),
..
} => {
Ok(message) => {
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Error,
..
Expand All @@ -415,12 +424,7 @@ async fn receive_message(
}
Ok(message)
}
Response {
stream_id,
response: Err(err),
} => Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &format!("{:?}", err)),
))),
err => err,
},
None => unreachable!("Ran out of responses"),
}
Expand Down
74 changes: 43 additions & 31 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ impl CassandraSinkCluster {
}

if self.nodes_rx.has_changed()? {
// This approach to keeping the nodes list up to date has a problem when a node goes down and then up again before this transform instance can process the node going down.
// When this happens we never detect that the node went down and a dead connection is left around.
// Broadcast channel's SendError::Lagged would solve this problem but we can't use broadcast channels because cloning them doesn't keep a past value.
// It might be worth implementing a custom watch channel that supports Lagged errors to improve correctness.
//
// However none of this is actually a problem because dead connection detection logic handles this case for us.
self.pool.update_nodes(&mut self.nodes_rx);

// recreate the control connection if it is down
Expand Down Expand Up @@ -434,23 +440,18 @@ impl CassandraSinkCluster {
// send an unprepared error in response to force
// the client to reprepare the query
return_chan_tx
.send(Response {
stream_id: metadata.stream_id,
response: Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame {
operation: CassandraOperation::Error(ErrorBody {
message: "Shotover does not have this query's metadata. Please re-prepare on this Shotover host before sending again.".into(),
ty: ErrorType::Unprepared(UnpreparedError {
id,
}),
}),
stream_id: metadata.stream_id,
tracing: Tracing::Response(None), // We didn't actually hit a node so we don't have a tracing id
version: self.version.unwrap(),
warnings: vec![],
},
))),
}).expect("the receiver is guaranteed to be alive, so this must succeed");
.send(Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame {
operation: CassandraOperation::Error(ErrorBody {
message: "Shotover does not have this query's metadata. Please re-prepare on this Shotover host before sending again.".into(),
ty: ErrorType::Unprepared(UnpreparedError { id }),
}),
stream_id: metadata.stream_id,
tracing: Tracing::Response(None), // We didn't actually hit a node so we don't have a tracing id
version: self.version.unwrap(),
warnings: vec![],
},
)))).expect("the receiver is guaranteed to be alive, so this must succeed");
}
Err(GetReplicaErr::Other(err)) => {
return Err(err);
Expand All @@ -471,22 +472,37 @@ impl CassandraSinkCluster {
responses_future.push_back(return_chan_rx)
}

let mut responses = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future,
self.version.unwrap(),
)
.await?;
let response_results =
super::connection::receive(self.read_timeout, &self.failed_requests, responses_future)
.await?;
let mut responses = vec![];
for response in response_results {
match response {
Ok(response) => responses.push(response),
Err(error) => {
self.pool.report_issue_with_node(error.destination);
responses.push(error.to_response(self.version.unwrap()));
}
}
}

{
let mut prepare_responses = super::connection::receive(
let prepare_response_results = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future_prepare,
self.version.unwrap(),
)
.await?;
let mut prepare_responses = vec![];
for response in prepare_response_results {
match response {
Ok(response) => prepare_responses.push(response),
Err(error) => {
self.pool.report_issue_with_node(error.destination);
prepare_responses.push(error.to_response(self.version.unwrap()));
}
}
}

let prepared_results: Vec<&mut Box<BodyResResultPrepared>> = prepare_responses
.iter_mut()
Expand Down Expand Up @@ -1119,11 +1135,7 @@ fn is_ddl_statement(request: &mut Message) -> bool {
}

fn is_use_statement_successful(response: Option<Result<Response>>) -> bool {
if let Some(Ok(Response {
response: Ok(mut response),
..
})) = response
{
if let Some(Ok(Ok(mut response))) = response {
if let Some(Frame::Cassandra(CassandraFrame {
operation: CassandraOperation::Result(CassandraResult::SetKeyspace(_)),
..
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl ConnectionFactory {
})?;
return_chan_rx.await.map_err(|e| {
anyhow!(e).context("Failed to initialize new connection with handshake, rx failed")
})?;
})??;
}

if let Some(use_message) = &self.use_message {
Expand All @@ -138,7 +138,7 @@ impl ConnectionFactory {
return_chan_rx.await.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with use message, rx failed")
})?;
})??;
}

Ok(outbound)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ impl NodePool {
self.token_map = TokenMap::new(self.nodes.as_slice());
}

/// Records that a problem was observed with the node at `address`.
///
/// Every node in the pool whose address equals `address` is flagged as down
/// (`is_up = false`) and has its `outbound` value cleared. Nodes with a
/// different address are left untouched; if no node matches, nothing happens.
pub fn report_issue_with_node(&mut self, address: SocketAddr) {
    let affected = self
        .nodes
        .iter_mut()
        .filter(|candidate| candidate.address == address);

    for candidate in affected {
        candidate.is_up = false;
        // NOTE(review): `outbound` is presumably the cached connection to the node,
        // so clearing it should force a reconnect on next use — confirm.
        candidate.outbound = None;
    }
}

pub async fn update_keyspaces(&mut self, keyspaces_rx: &mut KeyspaceChanRx) {
let updated_keyspaces = keyspaces_rx.borrow_and_update().clone();
self.keyspace_metadata = updated_keyspaces;
Expand Down
Loading

0 comments on commit 8e24912

Please sign in to comment.