diff --git a/shotover-proxy/src/transforms/cassandra/connection.rs b/shotover-proxy/src/transforms/cassandra/connection.rs index dcdda656e..03a12406e 100644 --- a/shotover-proxy/src/transforms/cassandra/connection.rs +++ b/shotover-proxy/src/transforms/cassandra/connection.rs @@ -145,15 +145,17 @@ impl CassandraConnection { /// /// If an internal invariant is broken the internal tasks may panic and external invariants will no longer be upheld. /// But this indicates a bug within CassandraConnection and should be fixed here. - pub fn send(&self, message: Message, return_chan: oneshot::Sender) -> Result<()> { + pub fn send(&self, message: Message) -> Result> { + let (return_chan_tx, return_chan_rx) = oneshot::channel(); // Convert the message to `Request` and send upstream if let Some(stream_id) = message.stream_id() { self.connection .send(Request { message, - return_chan, + return_chan: return_chan_tx, stream_id, }) + .map(|_| return_chan_rx) .map_err(|x| x.into()) } else { Err(anyhow!("no cassandra frame found")) diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index 6c6a470ab..df0626f8b 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -342,7 +342,7 @@ impl CassandraSinkCluster { ); } - // This is purely an optimization: To avoid opening these connections sequentially later later on, we open them concurrently now. + // This is purely an optimization: To avoid opening these connections sequentially later on, we open them concurrently now. try_join_all( self.pool .nodes() @@ -400,18 +400,14 @@ impl CassandraSinkCluster { let mut nodes_to_prepare_on: Vec = vec![]; for (i, mut message) in messages.into_iter().enumerate() { - let (return_chan_tx, return_chan_rx) = oneshot::channel(); - if self.pool.nodes().is_empty() + let return_chan_rx = if self.pool.nodes().is_empty() || !self.init_handshake_complete // system.local and system.peers must be routed to the same node otherwise the system.local node will be amongst the system.peers nodes and a node will be missing // DDL statements and system.local must be routed through the same connection, so that schema_version changes appear immediately in system.local || is_ddl_statement(&mut message) || self.is_system_query(&mut message) { - self.control_connection - .as_mut() - .unwrap() - .send(message, return_chan_tx)?; + self.control_connection.as_mut().unwrap().send(message)? } else if is_use_statement(&mut message) { // Adding the USE statement to the handshake ensures that any new connection // created will have the correct keyspace setup. @@ -420,18 +416,13 @@ impl CassandraSinkCluster { // Send the USE statement to all open connections to ensure they are all in sync for (node_index, node) in self.pool.nodes().iter().enumerate() { if let Some(connection) = &node.outbound { - let (return_chan_tx, return_chan_rx) = oneshot::channel(); - connection.send(message.clone(), return_chan_tx)?; - responses_future_use.push_back(return_chan_rx); + responses_future_use.push_back(connection.send(message.clone())?); use_future_index_to_node_index.push(node_index); } } // Send the USE statement to the handshake connection and use the response as shotovers response - self.control_connection - .as_mut() - .unwrap() - .send(message, return_chan_tx)?; + self.control_connection.as_mut().unwrap().send(message)? } else if is_prepare_message(&mut message) { if let Some(rewrite) = tables_to_rewrite.iter().find(|x| x.outgoing_index == i) { if let RewriteTableTy::Prepare { destination_nodes } = &rewrite.ty { @@ -448,41 +439,37 @@ impl CassandraSinkCluster { .ok_or_else(|| anyhow!("node {next_host_id} has dissapeared"))? .get_connection(&self.connection_factory) .await? - .send(message, return_chan_tx)?; - } else { + .send(message)? + } else if let Some((execute, metadata)) = get_execute_message(&mut message) { // If the message is an execute we should perform token aware routing - if let Some((execute, metadata)) = get_execute_message(&mut message) { - match self + match self + .pool + .get_replica_node_in_dc( + execute, + &self.local_shotover_node.rack, + self.version.unwrap(), + &mut self.rng, + ) + .await + { + Ok(replica_node) => replica_node + .get_connection(&self.connection_factory) + .await? + .send(message)?, + Err(GetReplicaErr::NoReplicasFound | GetReplicaErr::NoKeyspaceMetadata) => self .pool - .get_replica_node_in_dc( - execute, - &self.local_shotover_node.rack, - self.version.unwrap(), - &mut self.rng, - ) - .await - { - Ok(replica_node) => { - replica_node - .get_connection(&self.connection_factory) - .await? - .send(message, return_chan_tx)?; - } - Err(GetReplicaErr::NoReplicasFound | GetReplicaErr::NoKeyspaceMetadata) => { - let node = self - .pool - .get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack); - node.get_connection(&self.connection_factory) - .await? - .send(message, return_chan_tx)?; - } - Err(GetReplicaErr::NoPreparedMetadata) => { - let id = execute.id.clone(); - tracing::info!("forcing re-prepare on {:?}", id); - // this shotover node doesn't have the metadata. - // send an unprepared error in response to force - // the client to reprepare the query - return_chan_tx + .get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack) + .get_connection(&self.connection_factory) + .await? + .send(message)?, + Err(GetReplicaErr::NoPreparedMetadata) => { + let (return_chan_tx, return_chan_rx) = oneshot::channel(); + let id = execute.id.clone(); + tracing::info!("forcing re-prepare on {:?}", id); + // this shotover node doesn't have the metadata. + // send an unprepared error in response to force + // the client to reprepare the query + return_chan_tx .send(Ok(Message::from_frame(Frame::Cassandra( CassandraFrame { operation: CassandraOperation::Error(ErrorBody { @@ -495,22 +482,20 @@ impl CassandraSinkCluster { warnings: vec![], }, )))).expect("the receiver is guaranteed to be alive, so this must succeed"); - } - Err(GetReplicaErr::Other(err)) => { - return Err(err); - } - }; - - // otherwise just send to a random node - } else { - let node = self - .pool - .get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack); - node.get_connection(&self.connection_factory) - .await? - .send(message, return_chan_tx)?; + return_chan_rx + } + Err(GetReplicaErr::Other(err)) => { + return Err(err); + } } - } + } else { + // otherwise just send to a random node + self.pool + .get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack) + .get_connection(&self.connection_factory) + .await? + .send(message)? + }; responses_future.push_back(return_chan_rx) } diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs index a04093bfb..da981765e 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs @@ -11,7 +11,7 @@ use derivative::Derivative; use std::net::SocketAddr; use std::time::Duration; use tokio::net::ToSocketAddrs; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; use uuid::Uuid; #[derive(Clone, Derivative)] @@ -116,30 +116,31 @@ impl ConnectionFactory { .map_err(|e| e.context("Failed to create new connection"))?; for handshake_message in &self.init_handshake { - let (return_chan_tx, return_chan_rx) = oneshot::channel(); outbound - .send(handshake_message.clone(), return_chan_tx) + .send(handshake_message.clone()) .map_err(|e| { anyhow!(e) .context("Failed to initialize new connection with handshake, tx failed") - })?; - return_chan_rx.await.map_err(|e| { - anyhow!(e).context("Failed to initialize new connection with handshake, rx failed") - })??; + })? + .await + .map_err(|e| { + anyhow!(e) + .context("Failed to initialize new connection with handshake, rx failed") + })??; } if let Some(use_message) = &self.use_message { - let (return_chan_tx, return_chan_rx) = oneshot::channel(); outbound - .send(use_message.clone(), return_chan_tx) + .send(use_message.clone()) .map_err(|e| { anyhow!(e) .context("Failed to initialize new connection with use message, tx failed") - })?; - return_chan_rx.await.map_err(|e| { - anyhow!(e) - .context("Failed to initialize new connection with use message, rx failed") - })??; + })? + .await + .map_err(|e| { + anyhow!(e) + .context("Failed to initialize new connection with use message, rx failed") + })??; } Ok(outbound) diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs index 152f2c39c..08157ec0e 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs @@ -15,7 +15,7 @@ use cassandra_protocol::token::Murmur3Token; use std::collections::HashMap; use std::net::SocketAddr; use tokio::sync::mpsc::unbounded_channel; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; #[derive(Debug)] pub struct TaskConnectionInfo { @@ -188,27 +188,24 @@ async fn register_for_topology_and_status_events( connection: &CassandraConnection, version: Version, ) -> Result<()> { - let (tx, rx) = oneshot::channel(); - connection - .send( - Message::from_frame(Frame::Cassandra(CassandraFrame { - version, - stream_id: 0, - tracing: Tracing::Request(false), - warnings: vec![], - operation: CassandraOperation::Register(BodyReqRegister { - events: vec![ - SimpleServerEvent::TopologyChange, - SimpleServerEvent::StatusChange, - SimpleServerEvent::SchemaChange, - ], - }), - })), - tx, - ) - .unwrap(); - - if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = rx.await??.frame() { + let mut response = connection + .send(Message::from_frame(Frame::Cassandra(CassandraFrame { + version, + stream_id: 0, + tracing: Tracing::Request(false), + warnings: vec![], + operation: CassandraOperation::Register(BodyReqRegister { + events: vec![ + SimpleServerEvent::TopologyChange, + SimpleServerEvent::StatusChange, + SimpleServerEvent::SchemaChange, + ], + }), + }))) + .unwrap() + .await??; + + if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = response.frame() { match operation { CassandraOperation::Ready(_) => Ok(()), operation => Err(anyhow!("Expected Cassandra to respond to a Register with a Ready. Instead it responded with {:?}", operation)) @@ -244,10 +241,8 @@ mod system_keyspaces { data_center: &str, version: Version, ) -> Result> { - let (tx, rx) = oneshot::channel(); - - connection.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { + let response = connection + .send(Message::from_frame(Frame::Cassandra(CassandraFrame { version, stream_id: 0, tracing: Tracing::Request(false), @@ -259,11 +254,8 @@ mod system_keyspaces { params: Box::default(), }, - })), - tx, - )?; - - let response = rx.await??; + })))? + .await??; into_keyspaces(response, data_center) } @@ -373,9 +365,8 @@ mod system_local { address: SocketAddr, version: Version, ) -> Result> { - let (tx, rx) = oneshot::channel(); - connection.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { + let response = connection + .send(Message::from_frame(Frame::Cassandra(CassandraFrame { version, stream_id: 1, tracing: Tracing::Request(false), @@ -386,11 +377,10 @@ mod system_local { )), params: Box::default(), }, - })), - tx, - )?; + })))? + .await??; - into_nodes(rx.await??, data_center, address) + into_nodes(response, data_center, address) } fn into_nodes( @@ -460,8 +450,7 @@ mod system_peers { data_center: &str, version: Version, ) -> Result> { - let (tx, rx) = oneshot::channel(); - connection.send( + let mut response = connection.send( Message::from_frame(Frame::Cassandra(CassandraFrame { version, stream_id: 0, @@ -474,15 +463,11 @@ mod system_peers { params: Box::default(), }, })), - tx, - )?; - - let mut response = rx.await??; + )?.await??; if is_peers_v2_does_not_exist_error(&mut response) { - let (tx, rx) = oneshot::channel(); - connection.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { + response = connection + .send(Message::from_frame(Frame::Cassandra(CassandraFrame { version, stream_id: 0, tracing: Tracing::Request(false), @@ -493,10 +478,8 @@ mod system_peers { )), params: Box::default(), }, - })), - tx, - )?; - response = rx.await??; + })))? + .await??; } into_nodes(response, data_center) diff --git a/shotover-proxy/src/transforms/cassandra/sink_single.rs b/shotover-proxy/src/transforms/cassandra/sink_single.rs index 3f03018f4..1a29b201f 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_single.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_single.rs @@ -141,15 +141,8 @@ impl CassandraSinkSingle { trace!("sending frame upstream"); let outbound = self.outbound.as_mut().unwrap(); - let responses_future: Result>> = messages - .into_iter() - .map(|m| { - let (return_chan_tx, return_chan_rx) = oneshot::channel(); - outbound.send(m, return_chan_tx)?; - - Ok(return_chan_rx) - }) - .collect(); + let responses_future: Result>> = + messages.into_iter().map(|m| outbound.send(m)).collect(); super::connection::receive(self.read_timeout, &self.failed_requests, responses_future?) .await diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster/single_rack_v4.rs b/shotover-proxy/tests/cassandra_int_tests/cluster/single_rack_v4.rs index 7d69eb502..5e074ad66 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster/single_rack_v4.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster/single_rack_v4.rs @@ -295,7 +295,7 @@ pub async fn test_node_going_down(compose: &DockerCompose, driver: CassandraDriv { // stop one of the containers to trigger a status change event. - // event_connection_direct is connecting to cassandra-one, so make sure to instead kill caassandra-two. + // event_connections.direct is connecting to cassandra-one, so make sure to instead kill caassandra-two. compose.stop_service("cassandra-two"); assert_down_event(&mut event_connections).await;