Skip to content

Commit

Permalink
CassandraConnection::send returns rx channel (#1091)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Mar 21, 2023
1 parent 71c02c7 commit cca5d2b
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 139 deletions.
6 changes: 4 additions & 2 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,17 @@ impl CassandraConnection {
///
/// If an internal invariant is broken the internal tasks may panic and external invariants will no longer be upheld.
/// But this indicates a bug within CassandraConnection and should be fixed here.
pub fn send(&self, message: Message, return_chan: oneshot::Sender<Response>) -> Result<()> {
pub fn send(&self, message: Message) -> Result<oneshot::Receiver<Response>> {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
// Convert the message to `Request` and send upstream
if let Some(stream_id) = message.stream_id() {
self.connection
.send(Request {
message,
return_chan,
return_chan: return_chan_tx,
stream_id,
})
.map(|_| return_chan_rx)
.map_err(|x| x.into())
} else {
Err(anyhow!("no cassandra frame found"))
Expand Down
109 changes: 47 additions & 62 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl CassandraSinkCluster {
);
}

// This is purely an optimization: To avoid opening these connections sequentially later later on, we open them concurrently now.
// This is purely an optimization: To avoid opening these connections sequentially later on, we open them concurrently now.
try_join_all(
self.pool
.nodes()
Expand Down Expand Up @@ -400,18 +400,14 @@ impl CassandraSinkCluster {
let mut nodes_to_prepare_on: Vec<Uuid> = vec![];

for (i, mut message) in messages.into_iter().enumerate() {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
if self.pool.nodes().is_empty()
let return_chan_rx = if self.pool.nodes().is_empty()
|| !self.init_handshake_complete
// system.local and system.peers must be routed to the same node otherwise the system.local node will be amongst the system.peers nodes and a node will be missing
// DDL statements and system.local must be routed through the same connection, so that schema_version changes appear immediately in system.local
|| is_ddl_statement(&mut message)
|| self.is_system_query(&mut message)
{
self.control_connection
.as_mut()
.unwrap()
.send(message, return_chan_tx)?;
self.control_connection.as_mut().unwrap().send(message)?
} else if is_use_statement(&mut message) {
// Adding the USE statement to the handshake ensures that any new connection
// created will have the correct keyspace setup.
Expand All @@ -420,18 +416,13 @@ impl CassandraSinkCluster {
// Send the USE statement to all open connections to ensure they are all in sync
for (node_index, node) in self.pool.nodes().iter().enumerate() {
if let Some(connection) = &node.outbound {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
connection.send(message.clone(), return_chan_tx)?;
responses_future_use.push_back(return_chan_rx);
responses_future_use.push_back(connection.send(message.clone())?);
use_future_index_to_node_index.push(node_index);
}
}

// Send the USE statement to the handshake connection and use the response as shotovers response
self.control_connection
.as_mut()
.unwrap()
.send(message, return_chan_tx)?;
self.control_connection.as_mut().unwrap().send(message)?
} else if is_prepare_message(&mut message) {
if let Some(rewrite) = tables_to_rewrite.iter().find(|x| x.outgoing_index == i) {
if let RewriteTableTy::Prepare { destination_nodes } = &rewrite.ty {
Expand All @@ -448,41 +439,37 @@ impl CassandraSinkCluster {
.ok_or_else(|| anyhow!("node {next_host_id} has dissapeared"))?
.get_connection(&self.connection_factory)
.await?
.send(message, return_chan_tx)?;
} else {
.send(message)?
} else if let Some((execute, metadata)) = get_execute_message(&mut message) {
// If the message is an execute we should perform token aware routing
if let Some((execute, metadata)) = get_execute_message(&mut message) {
match self
match self
.pool
.get_replica_node_in_dc(
execute,
&self.local_shotover_node.rack,
self.version.unwrap(),
&mut self.rng,
)
.await
{
Ok(replica_node) => replica_node
.get_connection(&self.connection_factory)
.await?
.send(message)?,
Err(GetReplicaErr::NoReplicasFound | GetReplicaErr::NoKeyspaceMetadata) => self
.pool
.get_replica_node_in_dc(
execute,
&self.local_shotover_node.rack,
self.version.unwrap(),
&mut self.rng,
)
.await
{
Ok(replica_node) => {
replica_node
.get_connection(&self.connection_factory)
.await?
.send(message, return_chan_tx)?;
}
Err(GetReplicaErr::NoReplicasFound | GetReplicaErr::NoKeyspaceMetadata) => {
let node = self
.pool
.get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack);
node.get_connection(&self.connection_factory)
.await?
.send(message, return_chan_tx)?;
}
Err(GetReplicaErr::NoPreparedMetadata) => {
let id = execute.id.clone();
tracing::info!("forcing re-prepare on {:?}", id);
// this shotover node doesn't have the metadata.
// send an unprepared error in response to force
// the client to reprepare the query
return_chan_tx
.get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack)
.get_connection(&self.connection_factory)
.await?
.send(message)?,
Err(GetReplicaErr::NoPreparedMetadata) => {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
let id = execute.id.clone();
tracing::info!("forcing re-prepare on {:?}", id);
// this shotover node doesn't have the metadata.
// send an unprepared error in response to force
// the client to reprepare the query
return_chan_tx
.send(Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame {
operation: CassandraOperation::Error(ErrorBody {
Expand All @@ -495,22 +482,20 @@ impl CassandraSinkCluster {
warnings: vec![],
},
)))).expect("the receiver is guaranteed to be alive, so this must succeed");
}
Err(GetReplicaErr::Other(err)) => {
return Err(err);
}
};

// otherwise just send to a random node
} else {
let node = self
.pool
.get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack);
node.get_connection(&self.connection_factory)
.await?
.send(message, return_chan_tx)?;
return_chan_rx
}
Err(GetReplicaErr::Other(err)) => {
return Err(err);
}
}
}
} else {
// otherwise just send to a random node
self.pool
.get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack)
.get_connection(&self.connection_factory)
.await?
.send(message)?
};

responses_future.push_back(return_chan_rx)
}
Expand Down
29 changes: 15 additions & 14 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use derivative::Derivative;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::ToSocketAddrs;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::mpsc;
use uuid::Uuid;

#[derive(Clone, Derivative)]
Expand Down Expand Up @@ -116,30 +116,31 @@ impl ConnectionFactory {
.map_err(|e| e.context("Failed to create new connection"))?;

for handshake_message in &self.init_handshake {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
outbound
.send(handshake_message.clone(), return_chan_tx)
.send(handshake_message.clone())
.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with handshake, tx failed")
})?;
return_chan_rx.await.map_err(|e| {
anyhow!(e).context("Failed to initialize new connection with handshake, rx failed")
})??;
})?
.await
.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with handshake, rx failed")
})??;
}

if let Some(use_message) = &self.use_message {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
outbound
.send(use_message.clone(), return_chan_tx)
.send(use_message.clone())
.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with use message, tx failed")
})?;
return_chan_rx.await.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with use message, rx failed")
})??;
})?
.await
.map_err(|e| {
anyhow!(e)
.context("Failed to initialize new connection with use message, rx failed")
})??;
}

Ok(outbound)
Expand Down
85 changes: 34 additions & 51 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use cassandra_protocol::token::Murmur3Token;
use std::collections::HashMap;
use std::net::SocketAddr;
use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, watch};

#[derive(Debug)]
pub struct TaskConnectionInfo {
Expand Down Expand Up @@ -188,27 +188,24 @@ async fn register_for_topology_and_status_events(
connection: &CassandraConnection,
version: Version,
) -> Result<()> {
let (tx, rx) = oneshot::channel();
connection
.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 0,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Register(BodyReqRegister {
events: vec![
SimpleServerEvent::TopologyChange,
SimpleServerEvent::StatusChange,
SimpleServerEvent::SchemaChange,
],
}),
})),
tx,
)
.unwrap();

if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = rx.await??.frame() {
let mut response = connection
.send(Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 0,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Register(BodyReqRegister {
events: vec![
SimpleServerEvent::TopologyChange,
SimpleServerEvent::StatusChange,
SimpleServerEvent::SchemaChange,
],
}),
})))
.unwrap()
.await??;

if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = response.frame() {
match operation {
CassandraOperation::Ready(_) => Ok(()),
operation => Err(anyhow!("Expected Cassandra to respond to a Register with a Ready. Instead it responded with {:?}", operation))
Expand Down Expand Up @@ -244,10 +241,8 @@ mod system_keyspaces {
data_center: &str,
version: Version,
) -> Result<HashMap<String, KeyspaceMetadata>> {
let (tx, rx) = oneshot::channel();

connection.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
let response = connection
.send(Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 0,
tracing: Tracing::Request(false),
Expand All @@ -259,11 +254,8 @@ mod system_keyspaces {

params: Box::default(),
},
})),
tx,
)?;

let response = rx.await??;
})))?
.await??;
into_keyspaces(response, data_center)
}

Expand Down Expand Up @@ -373,9 +365,8 @@ mod system_local {
address: SocketAddr,
version: Version,
) -> Result<Vec<CassandraNode>> {
let (tx, rx) = oneshot::channel();
connection.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
let response = connection
.send(Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 1,
tracing: Tracing::Request(false),
Expand All @@ -386,11 +377,10 @@ mod system_local {
)),
params: Box::default(),
},
})),
tx,
)?;
})))?
.await??;

into_nodes(rx.await??, data_center, address)
into_nodes(response, data_center, address)
}

fn into_nodes(
Expand Down Expand Up @@ -460,8 +450,7 @@ mod system_peers {
data_center: &str,
version: Version,
) -> Result<Vec<CassandraNode>> {
let (tx, rx) = oneshot::channel();
connection.send(
let mut response = connection.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 0,
Expand All @@ -474,15 +463,11 @@ mod system_peers {
params: Box::default(),
},
})),
tx,
)?;

let mut response = rx.await??;
)?.await??;

if is_peers_v2_does_not_exist_error(&mut response) {
let (tx, rx) = oneshot::channel();
connection.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
response = connection
.send(Message::from_frame(Frame::Cassandra(CassandraFrame {
version,
stream_id: 0,
tracing: Tracing::Request(false),
Expand All @@ -493,10 +478,8 @@ mod system_peers {
)),
params: Box::default(),
},
})),
tx,
)?;
response = rx.await??;
})))?
.await??;
}

into_nodes(response, data_center)
Expand Down
Loading

0 comments on commit cca5d2b

Please sign in to comment.