diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs
index 07edbbf74..c05e63331 100644
--- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs
+++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs
@@ -1,6 +1,6 @@
 use self::connection::CassandraConnection;
 use self::node_pool::{get_accessible_owned_connection, NodePoolBuilder, PreparedMetadata};
-use self::rewrite::MessageRewriter;
+use self::rewrite::{BatchMode, MessageRewriter};
 use crate::frame::cassandra::{CassandraMetadata, Tracing};
 use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
 use crate::message::{Message, MessageIdMap, Messages, Metadata};
@@ -279,7 +279,8 @@ impl CassandraSinkCluster {
 
         let mut responses = vec![];
 
-        self.message_rewriter
+        let batch_mode = self
+            .message_rewriter
             .rewrite_requests(
                 &mut messages,
                 &self.connection_factory,
@@ -288,6 +289,19 @@ impl CassandraSinkCluster {
             )
             .await?;
 
+        if let BatchMode::Isolated = batch_mode {
+            if let Some(connection) = self.control_connection.as_mut() {
+                connection
+                    .recv_all_pending(&mut responses, self.version.unwrap())
+                    .await
+                    .ok();
+            }
+            for node in self.pool.nodes_mut().iter_mut() {
+                node.recv_all_pending(&mut responses, self.version.unwrap())
+                    .await;
+            }
+        }
+
         // Create the initial connection.
         // Messages will be sent through this connection until we have extracted the handshake.
         if self.control_connection.is_none() {
@@ -486,15 +500,29 @@ impl CassandraSinkCluster {
         }
 
         // receive messages from all connections
-        if let Some(connection) = self.control_connection.as_mut() {
-            connection
-                .recv_all_pending(&mut responses, self.version.unwrap())
-                .await
-                .ok();
-        }
-        for node in self.pool.nodes_mut().iter_mut() {
-            node.recv_all_pending(&mut responses, self.version.unwrap())
-                .await;
+        match batch_mode {
+            BatchMode::Isolated => {
+                if let Some(connection) = self.control_connection.as_mut() {
+                    connection
+                        .recv_all_pending(&mut responses, self.version.unwrap())
+                        .await
+                        .ok();
+                }
+                for node in self.pool.nodes_mut().iter_mut() {
+                    node.recv_all_pending(&mut responses, self.version.unwrap())
+                        .await;
+                }
+            }
+            BatchMode::Pipelined => {
+                if let Some(connection) = self.control_connection.as_mut() {
+                    connection
+                        .try_recv(&mut responses, self.version.unwrap())
+                        .ok();
+                }
+                for node in self.pool.nodes_mut().iter_mut() {
+                    node.try_recv(&mut responses, self.version.unwrap());
+                }
+            }
         }
 
         // When the server indicates that it is ready for normal operation via Ready or AuthSuccess,
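The mod.rs changes above branch on the BatchMode value returned by the rewriter in rewrite.rs below. As a rough illustration of the send/receive discipline that BatchMode::Isolated asks for, here is a minimal self-contained Rust sketch; Connection, Request, Response, drain_all and process_isolated_batch are hypothetical stand-ins for illustration only, not the shotover API:

/// Hypothetical stand-ins for shotover's connection and message types
/// (the real ones live in sink_cluster::connection and crate::message).
#[derive(Debug)]
struct Response {
    stream_id: i16,
}

struct Request {
    stream_id: i16,
}

struct Connection {
    /// Responses the server has produced but we have not surfaced yet.
    pending: Vec<Response>,
}

impl Connection {
    /// Stand-in for recv_all_pending: surface every in-flight response so no
    /// client-chosen stream id is still outstanding.
    fn drain_all(&mut self, out: &mut Vec<Response>) {
        out.append(&mut self.pending);
    }

    /// Stand-in for the real send path; here we just fake an immediate response.
    fn send(&mut self, request: Request) {
        self.pending.push(Response {
            stream_id: request.stream_id,
        });
    }
}

/// The discipline BatchMode::Isolated asks for: drain before sending the batch,
/// so the ids shotover picks for injected requests cannot collide with ids the
/// client still has in flight, then drain again so the next batch starts clean.
fn process_isolated_batch(connection: &mut Connection, batch: Vec<Request>) -> Vec<Response> {
    let mut responses = Vec::new();
    connection.drain_all(&mut responses);
    for request in batch {
        connection.send(request);
    }
    connection.drain_all(&mut responses);
    responses
}

fn main() {
    let mut connection = Connection {
        pending: vec![Response { stream_id: 1 }],
    };
    let responses = process_isolated_batch(&mut connection, vec![Request { stream_id: 358 }]);
    println!("{responses:?}");
}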
diff --git a/shotover/src/transforms/cassandra/sink_cluster/rewrite.rs b/shotover/src/transforms/cassandra/sink_cluster/rewrite.rs
index d2612101f..ac54a76bb 100644
--- a/shotover/src/transforms/cassandra/sink_cluster/rewrite.rs
+++ b/shotover/src/transforms/cassandra/sink_cluster/rewrite.rs
@@ -59,6 +59,16 @@ pub struct MessageRewriter {
     pub prepare_requests_to_destination_nodes: MessageIdMap,
 }
 
+pub enum BatchMode {
+    /// When processing the current request batch, ensure that:
+    /// * all responses are flushed before sending the requests
+    /// * all responses are flushed after sending the requests
+    /// This must be done to ensure that stream_ids generated by shotover do not collide with stream_ids generated by the client.
+    Isolated,
+    /// There are no extra requirements on the current request batch.
+    Pipelined,
+}
+
 impl MessageRewriter {
     /// Insert any extra requests into requests.
     /// A Vec is returned which keeps track of the requests added
@@ -69,7 +79,8 @@ impl MessageRewriter {
         connection_factory: &ConnectionFactory,
         pool: &mut NodePool,
         version: Version,
-    ) -> Result<()> {
+    ) -> Result<BatchMode> {
+        let mut batch_mode = BatchMode::Pipelined;
         let mut new_rewrites: Vec<_> = messages
             .iter_mut()
             .enumerate()
@@ -86,6 +97,7 @@ impl MessageRewriter {
                         .collected_messages
                         .push(MessageOrId::Id(message.id()));
                     messages.push(message);
+                    batch_mode = BatchMode::Isolated;
                 }
                 RewriteTableTy::Peers => {
                     let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers";
@@ -101,6 +113,7 @@ impl MessageRewriter {
                        .collected_messages
                        .push(MessageOrId::Id(message.id()));
                     messages.push(message);
+                    batch_mode = BatchMode::Isolated;
                 }
                 RewriteTableTy::Prepare { clone_index } => {
                     let mut first = true;
@@ -135,12 +148,16 @@ impl MessageRewriter {
                             .map(|node| node.get_connection(connection_factory)),
                     )
                     .await?;
+                    // Theoretically prepare statements shouldn't need to be isolated,
+                    // but for some reason we are hitting issues with Pipelined, so for now we use Isolated.
+                    // In the future we should investigate using Pipelined instead.
+                    batch_mode = BatchMode::Isolated;
                 }
             }
         }
 
         self.to_rewrite.extend(new_rewrites);
-        Ok(())
+        Ok(batch_mode)
     }
 
     /// Returns any information required to correctly rewrite the response.
@@ -668,6 +685,7 @@ fn create_query(messages: &Messages, query: &str, version: Version) -> Result<Message> {
     // start at an unusual number to hopefully avoid looping many times when we receive stream ids that look like [0, 1, 2, ..]
+    // We can quite happily give up 358 stream ids as that still allows for shotover message batches containing 2 ** 16 - 358 = 65178 messages
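The final hunk's comment explains the choice of 358 as the starting point for shotover-generated stream ids: clients typically allocate ids from 0 upward, so giving up the first 358 ids avoids collisions while still leaving 2^16 - 358 = 65178 ids for client messages in a batch. A minimal sketch of that selection idea follows; Request and pick_unused_stream_id are hypothetical names for illustration, not the helper shotover actually uses:

use std::collections::HashSet;

/// Hypothetical, simplified request type; the real shotover Message carries a
/// full Cassandra frame, this sketch only keeps the client-chosen stream id.
struct Request {
    stream_id: i16,
}

/// Pick a stream id for a shotover-generated request that cannot collide with
/// any id the client already used in this batch. Starting the search at 358
/// mirrors the comment above: ids chosen by clients usually count up from 0,
/// so an unusual starting point normally succeeds on the first try.
/// (For simplicity this sketch only scans the non-negative half of the range.)
fn pick_unused_stream_id(batch: &[Request]) -> Option<i16> {
    let used: HashSet<i16> = batch.iter().map(|request| request.stream_id).collect();
    (358..=i16::MAX).find(|id| !used.contains(id))
}

fn main() {
    let batch = vec![Request { stream_id: 0 }, Request { stream_id: 1 }];
    assert_eq!(pick_unused_stream_id(&batch), Some(358));
}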