diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index f6620636e..075ea2025 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -8,7 +8,7 @@ use cassandra_protocol::frame::events::ServerEvent; use cassandra_protocol::frame::message_batch::{ BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch, }; -use cassandra_protocol::frame::message_error::ErrorBody; +use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType}; use cassandra_protocol::frame::message_event::BodyResEvent; use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned; use cassandra_protocol::frame::message_query::BodyReqQuery; @@ -161,6 +161,19 @@ impl CassandraFrame { }) } + pub fn shotover_error(stream_id: i16, version: Version, message: &str) -> Self { + CassandraFrame { + version, + stream_id, + operation: CassandraOperation::Error(ErrorBody { + message: format!("Internal shotover error: {message}"), + ty: ErrorType::Server, + }), + tracing: Tracing::Response(None), + warnings: vec![], + } + } + pub fn from_bytes(bytes: Bytes) -> Result { let frame = RawCassandraFrame::from_buffer(&bytes, Compression::None) .map_err(|e| anyhow!("{e:?}"))? diff --git a/shotover-proxy/src/transforms/cassandra/connection.rs b/shotover-proxy/src/transforms/cassandra/connection.rs index 5cb1312de..5c2f6c7a2 100644 --- a/shotover-proxy/src/transforms/cassandra/connection.rs +++ b/shotover-proxy/src/transforms/cassandra/connection.rs @@ -1,13 +1,13 @@ use crate::codec::cassandra::CassandraCodec; use crate::frame::cassandra::CassandraMetadata; +use crate::frame::{CassandraFrame, Frame}; use crate::message::{Message, Metadata}; use crate::server::CodecReadError; use crate::tcp; use crate::tls::{TlsConnector, ToHostname}; -use crate::transforms::util::Response; use crate::transforms::Messages; use anyhow::{anyhow, Result}; -use cassandra_protocol::frame::Opcode; +use cassandra_protocol::frame::{Opcode, Version}; use derivative::Derivative; use futures::stream::FuturesOrdered; use futures::{SinkExt, StreamExt}; @@ -25,7 +25,19 @@ use tracing::{error, Instrument}; struct Request { message: Message, return_chan: oneshot::Sender, - message_id: i16, + stream_id: i16, +} + +#[derive(Debug)] +pub struct Response { + pub stream_id: i16, + pub response: Result, +} + +#[derive(Debug)] +struct ReturnChannel { + return_chan: oneshot::Sender, + stream_id: i16, } #[derive(Clone, Derivative)] @@ -43,7 +55,7 @@ impl CassandraConnection { pushed_messages_tx: Option>, ) -> Result { let (out_tx, out_rx) = mpsc::unbounded_channel::(); - let (return_tx, return_rx) = mpsc::unbounded_channel::(); + let (return_tx, return_rx) = mpsc::unbounded_channel::(); let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>(); if let Some(tls) = tls.as_mut() { @@ -100,12 +112,12 @@ impl CassandraConnection { /// Send a `Message` to this `CassandraConnection` and expect a response on `return_chan` pub fn send(&self, message: Message, return_chan: oneshot::Sender) -> Result<()> { // Convert the message to `Request` and send upstream - if let Some(message_id) = message.stream_id() { + if let Some(stream_id) = message.stream_id() { self.connection .send(Request { message, return_chan, - message_id, + stream_id, }) .map_err(|x| x.into()) } else { @@ -117,7 +129,7 @@ impl CassandraConnection { async fn tx_process( write: WriteHalf, out_rx: mpsc::UnboundedReceiver, - return_tx: mpsc::UnboundedSender, + return_tx: mpsc::UnboundedSender, codec: CassandraCodec, rx_process_has_shutdown_rx: oneshot::Receiver<()>, ) { @@ -131,15 +143,18 @@ async fn tx_process( async fn tx_process_fallible( write: WriteHalf, mut out_rx: mpsc::UnboundedReceiver, - return_tx: mpsc::UnboundedSender, + return_tx: mpsc::UnboundedSender, codec: CassandraCodec, rx_process_has_shutdown_rx: oneshot::Receiver<()>, ) -> Result<()> { let mut in_w = FramedWrite::new(write, codec); loop { if let Some(request) = out_rx.recv().await { - in_w.send(vec![request.message.clone()]).await?; - return_tx.send(request)?; + in_w.send(vec![request.message]).await?; + return_tx.send(ReturnChannel { + return_chan: request.return_chan, + stream_id: request.stream_id, + })?; } else { // transform is shutting down, time to cleanly shutdown both tx_process and rx_process. // We need to ensure that the rx_process task has shutdown before closing the write half of the tcpstream @@ -163,7 +178,7 @@ async fn tx_process_fallible( async fn rx_process( read: ReadHalf, - return_rx: mpsc::UnboundedReceiver, + return_rx: mpsc::UnboundedReceiver, codec: CassandraCodec, pushed_messages_tx: Option>, rx_process_has_shutdown_tx: oneshot::Sender<()>, @@ -178,7 +193,7 @@ async fn rx_process( async fn rx_process_fallible( read: ReadHalf, - mut return_rx: mpsc::UnboundedReceiver, + mut return_rx: mpsc::UnboundedReceiver, codec: CassandraCodec, pushed_messages_tx: Option>, ) -> Result<()> { @@ -192,13 +207,13 @@ async fn rx_process_fallible( // Implementation: // To process a message we need to receive things from two different sources: // 1. the response from the cassandra server - // 2. the oneshot::Sender and original message from the tx_process task + // 2. the oneshot::Sender from the tx_process task // // We can receive these in any order. // In order to handle that we have two seperate maps. // - // We store the sender + original message here if we receive from the tx_process task first - let mut from_tx_process: HashMap, Message)> = HashMap::new(); + // We store the sender here if we receive from the tx_process task first + let mut from_tx_process: HashMap> = HashMap::new(); // We store the response message here if we receive from the server first. let mut from_server: HashMap = HashMap::new(); @@ -209,7 +224,8 @@ async fn rx_process_fallible( match response { Some(Ok(response)) => { for m in response { - if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() { + let meta = m.metadata(); + if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta { if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { pushed_messages_tx.send(vec![m]).unwrap(); } @@ -218,8 +234,8 @@ async fn rx_process_fallible( None => { from_server.insert(stream_id, m); }, - Some((return_tx, original)) => { - return_tx.send(Response { original, response: Ok(m) }) + Some(return_tx) => { + return_tx.send(Response { stream_id, response: Ok(m) }) .map_err(|_| anyhow!("couldn't send message"))?; } } @@ -236,14 +252,14 @@ async fn rx_process_fallible( None => return Ok(()) } }, - original_request = return_rx.recv() => { - if let Some(Request { message, return_chan, message_id }) = original_request { - match from_server.remove(&message_id) { + return_chan = return_rx.recv() => { + if let Some(ReturnChannel { return_chan, stream_id }) = return_chan { + match from_server.remove(&stream_id) { None => { - from_tx_process.insert(message_id, (return_chan, message)); + from_tx_process.insert(stream_id, return_chan); } Some(m) => { - return_chan.send(Response { original: message, response: Ok(m) }) + return_chan.send(Response { stream_id, response: Ok(m) }) .map_err(|_| anyhow!("couldn't send message"))?; } } @@ -259,6 +275,7 @@ pub async fn receive( timeout_duration: Option, failed_requests: &metrics::Counter, mut results: FuturesOrdered>, + version: Version, ) -> Result { let expected_size = results.len(); let mut responses = Vec::with_capacity(expected_size); @@ -266,7 +283,7 @@ pub async fn receive( if let Some(timeout_duration) = timeout_duration { match timeout( timeout_duration, - receive_message(failed_requests, &mut results), + receive_message(failed_requests, &mut results, version), ) .await { @@ -282,7 +299,7 @@ pub async fn receive( } } } else { - responses.push(receive_message(failed_requests, &mut results).await?); + responses.push(receive_message(failed_requests, &mut results, version).await?); } } Ok(responses) @@ -291,6 +308,7 @@ pub async fn receive( pub async fn receive_message( failed_requests: &metrics::Counter, results: &mut FuturesOrdered>, + version: Version, ) -> Result { match results.next().await { Some(result) => match result? { @@ -308,12 +326,11 @@ pub async fn receive_message( Ok(message) } Response { - mut original, + stream_id, response: Err(err), - } => { - original.set_error(err.to_string()); - Ok(original) - } + } => Ok(Message::from_frame(Frame::Cassandra( + CassandraFrame::shotover_error(stream_id, version, &err.to_string()), + ))), }, None => unreachable!("Ran out of responses"), } diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index ac0b1d278..da0efa74d 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -1,10 +1,9 @@ use crate::error::ChainResponse; use crate::frame::cassandra::{parse_statement_single, CassandraMetadata, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use crate::message::{IntSize, Message, MessageValue, Messages}; +use crate::message::{IntSize, Message, MessageValue, Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; -use crate::transforms::cassandra::connection::CassandraConnection; -use crate::transforms::util::Response; +use crate::transforms::cassandra::connection::{CassandraConnection, Response}; use crate::transforms::{Transform, Transforms, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -111,6 +110,7 @@ pub struct CassandraSinkCluster { control_connection_address: Option, init_handshake_complete: bool, + version: Option, chain_name: String, failed_requests: Counter, read_timeout: Option, @@ -138,6 +138,7 @@ impl Clone for CassandraSinkCluster { connection_factory: self.connection_factory.new_with_same_config(), control_connection_address: None, init_handshake_complete: false, + version: self.version, chain_name: self.chain_name.clone(), failed_requests: self.failed_requests.clone(), read_timeout: self.read_timeout, @@ -192,6 +193,7 @@ impl CassandraSinkCluster { control_connection: None, control_connection_address: None, init_handshake_complete: false, + version: None, chain_name, failed_requests, read_timeout: receive_timeout, @@ -229,6 +231,25 @@ fn create_query(messages: &Messages, query: &str, version: Version) -> Result ChainResponse { + if self.version.is_none() { + if let Some(message) = messages.first() { + if let Ok(Metadata::Cassandra(CassandraMetadata { version, .. })) = + message.metadata() + { + self.version = Some(version); + } else { + return Err(anyhow!( + "Failed to extract cassandra version from incoming message: Not a valid cassandra message" + )); + } + } else { + // It's an invariant that self.version is Some. + // Since we were unable to set it, we need to return immediately. + // This is ok because if there are no messages then we have no work to do anyway. + return Ok(vec![]); + } + } + if self.nodes_rx.has_changed()? { self.pool.update_nodes(&mut self.nodes_rx); @@ -265,13 +286,13 @@ impl CassandraSinkCluster { let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers"; messages.insert( table_to_rewrite.index + 1, - create_query(&messages, query, table_to_rewrite.version)?, + create_query(&messages, query, self.version.unwrap())?, ); if let RewriteTableTy::Peers = table_to_rewrite.ty { let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.local"; messages.insert( table_to_rewrite.index + 2, - create_query(&messages, query, table_to_rewrite.version)?, + create_query(&messages, query, self.version.unwrap())?, ); } } @@ -386,7 +407,7 @@ impl CassandraSinkCluster { .get_replica_node_in_dc( execute, &self.local_shotover_node.rack, - &metadata.version, + self.version.unwrap(), &mut self.rng, ) .await @@ -408,12 +429,12 @@ impl CassandraSinkCluster { Err(GetReplicaErr::NoPreparedMetadata) => { let id = execute.id.clone(); tracing::info!("forcing re-prepare on {:?}", id); - // this shotover node doesn't have the metadata + // this shotover node doesn't have the metadata. // send an unprepared error in response to force // the client to reprepare the query return_chan_tx .send(Response { - original: message.clone(), + stream_id: metadata.stream_id, response: Ok(Message::from_frame(Frame::Cassandra( CassandraFrame { operation: CassandraOperation::Error(ErrorBody { @@ -424,7 +445,7 @@ impl CassandraSinkCluster { }), stream_id: metadata.stream_id, tracing: Tracing::Response(None), // We didn't actually hit a node so we don't have a tracing id - version: metadata.version, + version: self.version.unwrap(), warnings: vec![], }, ))), @@ -449,15 +470,20 @@ impl CassandraSinkCluster { responses_future.push_back(return_chan_rx) } - let mut responses = - super::connection::receive(self.read_timeout, &self.failed_requests, responses_future) - .await?; + let mut responses = super::connection::receive( + self.read_timeout, + &self.failed_requests, + responses_future, + self.version.unwrap(), + ) + .await?; { let mut prepare_responses = super::connection::receive( self.read_timeout, &self.failed_requests, responses_future_prepare, + self.version.unwrap(), ) .await?; @@ -635,7 +661,6 @@ impl CassandraSinkCluster { index, ty, warnings, - version: cassandra.version, selects: select.columns.clone(), }); } @@ -997,7 +1022,6 @@ impl CassandraSinkCluster { struct TableToRewrite { index: usize, ty: RewriteTableTy, - version: Version, selects: Vec, warnings: Vec, } diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs index 41e11c31f..313656671 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs @@ -134,7 +134,7 @@ impl NodePool { &mut self, execute: &BodyReqExecuteOwned, rack: &str, - version: &Version, + version: Version, rng: &mut SmallRng, ) -> Result, GetReplicaErr> { let metadata = { @@ -161,7 +161,7 @@ impl NodePool { execute.query_parameters.values.as_ref().ok_or_else(|| { GetReplicaErr::Other(anyhow!("Execute body does not have query parameters")) })?, - *version, + version, ) .unwrap(); diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs index 7973e2995..05578b0ed 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs @@ -81,7 +81,7 @@ mod test_token_aware_router { .get_replica_node_in_dc( &execute_body(id.clone(), query_parameters), "rack1", - &Version::V4, + Version::V4, &mut rng, ) .await diff --git a/shotover-proxy/src/transforms/cassandra/sink_single.rs b/shotover-proxy/src/transforms/cassandra/sink_single.rs index 5a34c25b0..6b7ae9844 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_single.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_single.rs @@ -1,12 +1,14 @@ use super::connection::CassandraConnection; use crate::codec::cassandra::CassandraCodec; use crate::error::ChainResponse; -use crate::message::Messages; +use crate::frame::cassandra::CassandraMetadata; +use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; -use crate::transforms::util::Response; +use crate::transforms::cassandra::connection::Response; use crate::transforms::{Transform, Transforms, Wrapper}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use async_trait::async_trait; +use cassandra_protocol::frame::Version; use futures::stream::FuturesOrdered; use metrics::{register_counter, Counter}; use serde::Deserialize; @@ -37,6 +39,7 @@ impl CassandraSinkSingleConfig { } pub struct CassandraSinkSingle { + version: Option, address: String, outbound: Option, chain_name: String, @@ -50,6 +53,7 @@ pub struct CassandraSinkSingle { impl Clone for CassandraSinkSingle { fn clone(&self) -> Self { CassandraSinkSingle { + version: self.version, address: self.address.clone(), outbound: None, chain_name: self.chain_name.clone(), @@ -74,6 +78,7 @@ impl CassandraSinkSingle { let receive_timeout = timeout.map(Duration::from_secs); CassandraSinkSingle { + version: None, address, outbound: None, chain_name, @@ -88,6 +93,25 @@ impl CassandraSinkSingle { impl CassandraSinkSingle { async fn send_message(&mut self, messages: Messages) -> ChainResponse { + if self.version.is_none() { + if let Some(message) = messages.first() { + if let Ok(Metadata::Cassandra(CassandraMetadata { version, .. })) = + message.metadata() + { + self.version = Some(version); + } else { + return Err(anyhow!( + "Failed to extract cassandra version from incoming message: Not a valid cassandra message" + )); + } + } else { + // It's an invariant that self.version is Some. + // Since we were unable to set it, we need to return immediately. + // This is ok because if there are no messages then we have no work to do anyway. + return Ok(vec![]); + } + } + if self.outbound.is_none() { trace!("creating outbound connection {:?}", self.address); self.outbound = Some( @@ -114,8 +138,13 @@ impl CassandraSinkSingle { }) .collect(); - super::connection::receive(self.read_timeout, &self.failed_requests, responses_future?) - .await + super::connection::receive( + self.read_timeout, + &self.failed_requests, + responses_future?, + self.version.unwrap(), + ) + .await } }