diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
index 4f9160274..8dca9fdae 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
@@ -24,9 +24,10 @@ use futures::StreamExt;
 use itertools::Itertools;
 use metrics::{register_counter, Counter};
 use node::{CassandraNode, ConnectionFactory};
-use node_pool::{GetReplicaErr, NodePool};
+use node_pool::{GetReplicaErr, KeyspaceMetadata, NodePool};
 use rand::prelude::*;
 use serde::Deserialize;
+use std::collections::HashMap;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::time::Duration;
 use tokio::sync::{mpsc, oneshot, watch};
@@ -42,6 +43,9 @@ mod test_router;
 mod token_map;
 pub mod topology;
 
+pub type KeyspaceChanTx = watch::Sender<HashMap<String, KeyspaceMetadata>>;
+pub type KeyspaceChanRx = watch::Receiver<HashMap<String, KeyspaceMetadata>>;
+
 #[derive(Deserialize, Debug, Clone)]
 pub struct CassandraSinkClusterConfig {
     /// contact points must be within the specified data_center and rack.
@@ -120,6 +124,7 @@ pub struct CassandraSinkCluster {
     /// Additionally any changes to nodes_rx are observed and copied over.
     pool: NodePool,
     nodes_rx: watch::Receiver<Vec<CassandraNode>>,
+    keyspaces_rx: KeyspaceChanRx,
     rng: SmallRng,
     task_handshake_tx: mpsc::Sender<TaskConnectionInfo>,
 }
@@ -145,6 +150,7 @@ impl Clone for CassandraSinkCluster {
             // Because the self.nodes_rx is always copied from the original nodes_rx created before any node lists were sent,
             // once a single node list has been sent all new connections will immediately recognize it as a change.
             nodes_rx: self.nodes_rx.clone(),
+            keyspaces_rx: self.keyspaces_rx.clone(),
             rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
             task_handshake_tx: self.task_handshake_tx.clone(),
         }
@@ -167,10 +173,14 @@ impl CassandraSinkCluster {
         let connect_timeout = Duration::from_millis(connect_timeout_ms);
         let (local_nodes_tx, local_nodes_rx) = watch::channel(vec![]);
 
+        let (keyspaces_tx, keyspaces_rx): (KeyspaceChanTx, KeyspaceChanRx) =
+            watch::channel(HashMap::new());
+
         let (task_handshake_tx, task_handshake_rx) = mpsc::channel(1);
 
         create_topology_task(
             local_nodes_tx,
+            keyspaces_tx,
             task_handshake_rx,
             local_shotover_node.data_center.clone(),
         );
@@ -196,6 +206,7 @@ impl CassandraSinkCluster {
             local_shotover_node,
             pool: NodePool::new(vec![]),
             nodes_rx: local_nodes_rx,
+            keyspaces_rx,
             rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
             task_handshake_tx,
         }
@@ -240,6 +251,10 @@ impl CassandraSinkCluster {
             }
         }
 
+        if self.keyspaces_rx.has_changed()? {
+            self.pool.update_keyspaces(&mut self.keyspaces_rx).await;
+        }
+
         let tables_to_rewrite: Vec<TableToRewrite> = messages
             .iter_mut()
             .enumerate()
@@ -373,7 +388,6 @@ impl CassandraSinkCluster {
                     &self.local_shotover_node.rack,
                     &metadata.version,
                     &mut self.rng,
-                    1,
                 )
                 .await
             {
@@ -383,7 +397,7 @@
                         .await?
                         .send(message, return_chan_tx)?;
                 }
-                Ok(None) => {
+                Ok(None) | Err(GetReplicaErr::NoKeyspaceMetadata) => {
                     let node = self
                         .pool
                         .get_round_robin_node_in_dc_rack(&self.local_shotover_node.rack);
                     node.get_connection(&self.connection_factory)
                        .await?
                        .send(message, return_chan_tx)?;
                 }
-                Err(GetReplicaErr::NoMetadata) => {
+                Err(GetReplicaErr::NoPreparedMetadata) => {
                     let id = execute.id.clone();
                     tracing::info!("forcing re-prepare on {:?}", id);
                     // this shotover node doesn't have the metadata
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
index d05e6da46..c63290592 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
@@ -18,12 +18,12 @@ use uuid::Uuid;
 pub struct CassandraNode {
     pub address: SocketAddr,
     pub rack: String,
-
-    #[derivative(Debug = "ignore")]
-    pub tokens: Vec<Murmur3Token>,
     pub outbound: Option<CassandraConnection>,
     pub host_id: Uuid,
     pub is_up: bool,
+
+    #[derivative(Debug = "ignore")]
+    pub tokens: Vec<Murmur3Token>,
 }
 
 impl CassandraNode {
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
index b09f50cd6..41e11c31f 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
@@ -1,6 +1,7 @@
+use super::node::CassandraNode;
 use super::routing_key::calculate_routing_key;
 use super::token_map::TokenMap;
-use crate::transforms::cassandra::sink_cluster::node::CassandraNode;
+use super::KeyspaceChanRx;
 use anyhow::{anyhow, Error, Result};
 use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
 use cassandra_protocol::frame::message_result::PreparedMetadata;
@@ -15,13 +16,20 @@ use tokio::sync::{watch, RwLock};
 
 #[derive(Debug)]
 pub enum GetReplicaErr {
-    NoMetadata,
+    NoPreparedMetadata,
+    NoKeyspaceMetadata,
     Other(Error),
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct KeyspaceMetadata {
+    pub replication_factor: usize,
+}
+
 #[derive(Debug)]
 pub struct NodePool {
     prepared_metadata: Arc<RwLock<HashMap<CBytesShort, PreparedMetadata>>>,
+    keyspace_metadata: HashMap<String, KeyspaceMetadata>,
     token_map: TokenMap,
     nodes: Vec<CassandraNode>,
     prev_idx: usize,
@@ -31,6 +39,7 @@ impl Clone for NodePool {
     fn clone(&self) -> Self {
         Self {
             prepared_metadata: self.prepared_metadata.clone(),
+            keyspace_metadata: self.keyspace_metadata.clone(),
             token_map: TokenMap::new(&[]),
             nodes: vec![],
             prev_idx: 0,
@@ -44,6 +53,7 @@ impl NodePool {
             token_map: TokenMap::new(nodes.as_slice()),
             nodes,
             prepared_metadata: Arc::new(RwLock::new(HashMap::new())),
+            keyspace_metadata: HashMap::new(),
             prev_idx: 0,
         }
     }
@@ -70,6 +80,11 @@ impl NodePool {
         self.token_map = TokenMap::new(self.nodes.as_slice());
     }
 
+    pub async fn update_keyspaces(&mut self, keyspaces_rx: &mut KeyspaceChanRx) {
+        let updated_keyspaces = keyspaces_rx.borrow_and_update().clone();
+        self.keyspace_metadata = updated_keyspaces;
+    }
+
     pub async fn add_prepared_result(&mut self, id: CBytesShort, metadata: PreparedMetadata) {
         let mut write_lock = self.prepared_metadata.write().await;
         write_lock.insert(id, metadata);
@@ -121,29 +136,41 @@ impl NodePool {
         rack: &str,
         version: &Version,
         rng: &mut SmallRng,
-        rf: usize, // TODO this parameter should be removed
     ) -> Result<Option<&mut CassandraNode>, GetReplicaErr> {
         let metadata = {
             let read_lock = self.prepared_metadata.read().await;
             read_lock
                 .get(&execute.id)
-                .ok_or(GetReplicaErr::NoMetadata)?
+                .ok_or(GetReplicaErr::NoPreparedMetadata)?
                 .clone()
         };
 
+        let keyspace = self
+            .keyspace_metadata
+            .get(
+                &metadata
+                    .global_table_spec
+                    .as_ref()
+                    .ok_or(GetReplicaErr::NoKeyspaceMetadata)?
+                    .ks_name,
+            )
+            .ok_or(GetReplicaErr::NoKeyspaceMetadata)?;
+
         let routing_key = calculate_routing_key(
             &metadata.pk_indexes,
             execute.query_parameters.values.as_ref().ok_or_else(|| {
-                GetReplicaErr::Other(anyhow!("Execute body does not have query paramters"))
+                GetReplicaErr::Other(anyhow!("Execute body does not have query parameters"))
            })?,
             *version,
         )
         .unwrap();
 
-        // TODO this should use the keyspace info to properly select the replica count
         let replica_host_ids = self
             .token_map
-            .iter_replica_nodes(Murmur3Token::generate(&routing_key), rf)
+            .iter_replica_nodes_capped(
+                Murmur3Token::generate(&routing_key),
+                keyspace.replication_factor,
+            )
            .collect::<Vec<Uuid>>();
 
        let (dc_replicas, rack_replicas) = self
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
index bc8f14378..7973e2995 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
@@ -1,8 +1,9 @@
 #[cfg(test)]
 mod test_token_aware_router {
-    use super::super::node_pool::NodePool;
+    use super::super::node_pool::{KeyspaceMetadata, NodePool};
     use super::super::routing_key::calculate_routing_key;
     use crate::transforms::cassandra::sink_cluster::node::CassandraNode;
+    use crate::transforms::cassandra::sink_cluster::{KeyspaceChanRx, KeyspaceChanTx};
     use cassandra_protocol::consistency::Consistency::One;
     use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
     use cassandra_protocol::frame::message_result::PreparedMetadata;
@@ -18,6 +19,7 @@ mod test_token_aware_router {
     use rand::prelude::*;
     use std::collections::HashMap;
     use std::net::SocketAddr;
+    use tokio::sync::watch;
     use uuid::Uuid;
 
     #[tokio::test]
@@ -30,6 +32,23 @@
             11, 241, 38, 11, 140, 72, 217, 34, 214, 128, 175, 241, 151, 73, 197, 227,
         ]);
 
+        let keyspace_metadata = KeyspaceMetadata {
+            replication_factor: 3,
+        };
+
+        let (keyspaces_tx, mut keyspaces_rx): (KeyspaceChanTx, KeyspaceChanRx) =
+            watch::channel(HashMap::new());
+
+        keyspaces_tx
+            .send(
+                [("demo_ks".to_string(), keyspace_metadata)]
+                    .into_iter()
+                    .collect(),
+            )
+            .unwrap();
+
+        router.update_keyspaces(&mut keyspaces_rx).await;
+
         router
             .add_prepared_result(id.clone(), prepared_metadata().clone())
             .await;
@@ -64,7 +83,6 @@
             "rack1",
             &Version::V4,
             &mut rng,
-            3,
         )
         .await
         .unwrap()
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
index 2932d9fe5..fae5fd981 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
@@ -19,7 +19,7 @@ impl TokenMap {
     }
 
     /// Returns nodes starting at given token and going in the direction of replicas.
-    pub fn iter_replica_nodes(
+    pub fn iter_replica_nodes_capped(
         &self,
         token: Murmur3Token,
         replica_count: usize,
@@ -98,7 +98,7 @@ fn verify_tokens(node_host_ids: &[Uuid], token: Murmur3Token) {
         let token_map = TokenMap::new(prepare_nodes().as_slice());
         let nodes = token_map
-            .iter_replica_nodes(token, node_host_ids.len())
+            .iter_replica_nodes_capped(token, node_host_ids.len())
             .collect_vec();
 
         assert_eq!(nodes, node_host_ids);
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
index 170d54a7c..d47594ed0 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
@@ -1,4 +1,6 @@
 use super::node::{CassandraNode, ConnectionFactory};
+use super::node_pool::KeyspaceMetadata;
+use super::KeyspaceChanTx;
 use crate::frame::cassandra::{parse_statement_single, Tracing};
 use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
 use crate::message::{Message, MessageValue};
@@ -9,6 +11,7 @@ use cassandra_protocol::frame::events::{StatusChangeType, TopologyChangeType};
 use cassandra_protocol::frame::message_register::BodyReqRegister;
 use cassandra_protocol::token::Murmur3Token;
 use cassandra_protocol::{frame::Version, query::QueryParams};
+use std::collections::HashMap;
 use std::net::SocketAddr;
 use tokio::sync::mpsc::unbounded_channel;
 use tokio::sync::{mpsc, oneshot, watch};
@@ -21,13 +24,21 @@ pub struct TaskConnectionInfo {
 pub fn create_topology_task(
     nodes_tx: watch::Sender<Vec<CassandraNode>>,
+    keyspaces_tx: KeyspaceChanTx,
     mut connection_info_rx: mpsc::Receiver<TaskConnectionInfo>,
     data_center: String,
 ) {
     tokio::spawn(async move {
         while let Some(mut connection_info) = connection_info_rx.recv().await {
             let mut attempts = 0;
-            match topology_task_process(&nodes_tx, &mut connection_info, &data_center).await {
+            match topology_task_process(
+                &nodes_tx,
+                &keyspaces_tx,
+                &mut connection_info,
+                &data_center,
+            )
+            .await
+            {
                 Err(err) => {
                     tracing::error!("topology task failed, retrying, error was: {err:?}");
                     attempts += 1;
@@ -47,6 +58,7 @@ pub fn create_topology_task(
 
 async fn topology_task_process(
     nodes_tx: &watch::Sender<Vec<CassandraNode>>,
+    keyspaces_tx: &KeyspaceChanTx,
     connection_info: &mut TaskConnectionInfo,
     data_center: &str,
 ) -> Result<()> {
@@ -62,11 +74,16 @@ async fn topology_task_process(
 
     let version = connection_info.connection_factory.get_version()?;
 
-    let mut nodes = fetch_current_nodes(&connection, connection_info, data_center).await?;
+    let mut nodes = fetch_current_nodes(&connection, connection_info, data_center, version).await?;
     if let Err(watch::error::SendError(_)) = nodes_tx.send(nodes.clone()) {
         return Ok(());
     }
 
+    let mut keyspaces = system_keyspaces::query(&connection, data_center, version).await?;
+    if let Err(watch::error::SendError(_)) = keyspaces_tx.send(keyspaces.clone()) {
+        return Ok(());
+    }
+
     register_for_topology_and_status_events(&connection, version).await?;
 
     loop {
@@ -91,6 +108,7 @@
                                     &connection,
                                     connection_info,
                                     data_center,
+                                    version,
                                 )
                                 .await?;
 
@@ -106,9 +124,21 @@
                                 }
 
                                 nodes = new_nodes;
+
+                                if let Err(watch::error::SendError(_)) =
+                                    nodes_tx.send(nodes.clone())
+                                {
+                                    return Ok(());
+                                }
                             }
                             TopologyChangeType::RemovedNode => {
-                                nodes.retain(|node| node.address != topology.addr)
+                                nodes.retain(|node| node.address != topology.addr);
+
+                                if let Err(watch::error::SendError(_)) =
+                                    nodes_tx.send(nodes.clone())
+                                {
+                                    return Ok(());
+                                }
                             }
                         },
                         ServerEvent::StatusChange(status) => {
@@ -120,17 +150,28 @@
                                }
                            }
                        }
+                        if let Err(watch::error::SendError(_)) =
+                            nodes_tx.send(nodes.clone())
+                        {
+                            return Ok(());
+                        }
+                    }
+                    ServerEvent::SchemaChange(_change) => {
+                        keyspaces =
+                            system_keyspaces::query(&connection, data_center, version)
+                                .await?;
+                        if let Err(watch::error::SendError(_)) =
+                            keyspaces_tx.send(keyspaces.clone())
+                        {
+                            return Ok(());
+                        }
+                    }
-                    event => tracing::error!("Unexpected event: {:?}", event),
                }
            }
        }
        None => return Err(anyhow!("topology control connection was closed")),
    }
-    if let Err(watch::error::SendError(_)) = nodes_tx.send(nodes.clone()) {
-        return Ok(());
-    }
     }
 }
@@ -150,6 +191,7 @@ async fn register_for_topology_and_status_events(
                 events: vec![
                     SimpleServerEvent::TopologyChange,
                     SimpleServerEvent::StatusChange,
+                    SimpleServerEvent::SchemaChange,
                 ],
             }),
         })),
@@ -171,10 +213,11 @@ async fn fetch_current_nodes(
     connection: &CassandraConnection,
     connection_info: &TaskConnectionInfo,
     data_center: &str,
+    version: Version,
 ) -> Result<Vec<CassandraNode>> {
     let (new_nodes, more_nodes) = tokio::join!(
-        system_local::query(connection, data_center, connection_info.address),
-        system_peers::query(connection, data_center)
+        system_local::query(connection, data_center, connection_info.address, version),
+        system_peers::query(connection, data_center, version)
     );
 
     let mut new_nodes = new_nodes?;
@@ -183,6 +226,135 @@
     Ok(new_nodes)
 }
 
+mod system_keyspaces {
+    use super::*;
+    use std::str::FromStr;
+
+    pub async fn query(
+        connection: &CassandraConnection,
+        data_center: &str,
+        version: Version,
+    ) -> Result<HashMap<String, KeyspaceMetadata>> {
+        let (tx, rx) = oneshot::channel();
+
+        connection.send(
+            Message::from_frame(Frame::Cassandra(CassandraFrame {
+                version,
+                stream_id: 0,
+                tracing: Tracing::Request(false),
+                warnings: vec![],
+                operation: CassandraOperation::Query {
+                    query: Box::new(parse_statement_single(
+                        "SELECT keyspace_name, replication FROM system_schema.keyspaces",
+                    )),
+
+                    params: Box::new(QueryParams::default()),
+                },
+            })),
+            tx,
+        )?;
+
+        let response = rx.await?.response?;
+        into_keyspaces(response, data_center)
+    }
+
+    fn into_keyspaces(
+        mut response: Message,
+        data_center: &str,
+    ) -> Result<HashMap<String, KeyspaceMetadata>> {
+        if let Some(Frame::Cassandra(frame)) = response.frame() {
+            match &mut frame.operation {
+                CassandraOperation::Result(CassandraResult::Rows { rows, .. }) => rows
+                    .drain(..)
+                    .map(|row| build_keyspace(row, data_center))
+                    .collect(),
+                operation => Err(anyhow!(
+                    "keyspace query returned unexpected cassandra operation: {:?}",
+                    operation
+                )),
+            }
+        } else {
+            Err(anyhow!("Failed to parse keyspace query response"))
+        }
+    }
+
+    pub fn build_keyspace(
+        mut row: Vec<MessageValue>,
+        data_center: &str,
+    ) -> Result<(String, KeyspaceMetadata)> {
+        let metadata = if let Some(MessageValue::Map(mut replication_strategy)) = row.pop() {
+            let strategy_name: String = match replication_strategy
+                .remove(&MessageValue::Varchar("class".into()))
+                .ok_or_else(|| anyhow!("replication strategy map should have a 'class' field",))?
+            {
+                MessageValue::Varchar(name) => name,
+                _ => return Err(anyhow!("'class' field should be a varchar")),
+            };
+
+            match strategy_name.as_str() {
+                "org.apache.cassandra.locator.SimpleStrategy" | "SimpleStrategy" => {
+                    let rf_str: String =
+                        match replication_strategy.remove(&MessageValue::Varchar("replication_factor".into())).ok_or_else(||
+                            anyhow!("SimpleStrategy in replication strategy map does not have a replication factor")
+                        )?{
+                            MessageValue::Varchar(rf) => rf,
+                            _ => return Err(anyhow!("SimpleStrategy replication factor should be a varchar "))
+                        };
+
+                    let replication_factor: usize = usize::from_str(&rf_str).map_err(|_| {
+                        anyhow!("Could not parse replication factor as an integer",)
+                    })?;
+
+                    KeyspaceMetadata { replication_factor }
+                }
+                "org.apache.cassandra.locator.NetworkTopologyStrategy"
+                | "NetworkTopologyStrategy" => {
+                    let data_center_rf = match replication_strategy
+                        .remove(&MessageValue::Varchar(data_center.into()))
+                    {
+                        Some(MessageValue::Varchar(rf_str)) => {
+                            usize::from_str(&rf_str).map_err(|_| {
+                                anyhow!("Could not parse replication factor as an integer",)
+                            })?
+                        }
+                        Some(_other) => {
+                            return Err(anyhow!(
+                                "NetworkTopologyStrategy replication factor should be a varchar"
+                            ))
+                        }
+                        None => 0,
+                    };
+
+                    KeyspaceMetadata {
+                        replication_factor: data_center_rf,
+                    }
+                }
+                "org.apache.cassandra.locator.LocalStrategy" | "LocalStrategy" => {
+                    KeyspaceMetadata {
+                        replication_factor: 1,
+                    }
+                }
+                _ => {
+                    tracing::warn!("Unrecognised replication strategy: {strategy_name:?}");
+                    KeyspaceMetadata {
+                        replication_factor: 1,
+                    }
+                }
+            }
+        } else {
+            return Err(anyhow!("replication strategy should be a map"));
+        };
+
+        let name = if let Some(MessageValue::Varchar(name)) = row.pop() {
+            name
+        } else {
+            return Err(anyhow!("system_schema_keyspaces.name should be a varchar"));
+        };
+
+        Ok((name, metadata))
+    }
+}
+
 mod system_local {
     use super::*;
 
@@ -190,11 +362,12 @@
     pub async fn query(
         connection: &CassandraConnection,
         data_center: &str,
         address: SocketAddr,
+        version: Version,
     ) -> Result<Vec<CassandraNode>> {
         let (tx, rx) = oneshot::channel();
         connection.send(
             Message::from_frame(Frame::Cassandra(CassandraFrame {
-                version: Version::V4,
+                version,
                 stream_id: 1,
                 tracing: Tracing::Request(false),
                 warnings: vec![],
@@ -276,11 +449,12 @@ mod system_peers {
     pub async fn query(
         connection: &CassandraConnection,
         data_center: &str,
+        version: Version,
     ) -> Result<Vec<CassandraNode>> {
         let (tx, rx) = oneshot::channel();
         connection.send(
             Message::from_frame(Frame::Cassandra(CassandraFrame {
-                version: Version::V4,
+                version,
                 stream_id: 0,
                 tracing: Tracing::Request(false),
                 warnings: vec![],
@@ -300,7 +474,7 @@
         let (tx, rx) = oneshot::channel();
         connection.send(
             Message::from_frame(Frame::Cassandra(CassandraFrame {
-                version: Version::V4,
+                version,
                 stream_id: 0,
                 tracing: Tracing::Request(false),
                 warnings: vec![],
@@ -413,3 +587,75 @@
         }
     }
 }
+
+#[cfg(test)]
+mod test_system_keyspaces {
+    use super::*;
+
+    #[test]
+    fn test_simple() {
+        let row = vec![
+            MessageValue::Varchar("test".into()),
+            MessageValue::Map(
+                vec![
+                    (
+                        MessageValue::Varchar("class".into()),
+                        MessageValue::Varchar("org.apache.cassandra.locator.SimpleStrategy".into()),
+                    ),
+                    (
+                        MessageValue::Varchar("replication_factor".into()),
+                        MessageValue::Varchar("2".into()),
+                    ),
+                ]
+                .into_iter()
+                .collect(),
+            ),
+        ];
+
+        let result = system_keyspaces::build_keyspace(row, "dc1").unwrap();
+        assert_eq!(
+            result,
+            (
+                "test".into(),
+                KeyspaceMetadata {
+                    replication_factor: 2
+                }
+            )
+        )
+    }
+
+    #[test]
+    fn test_network() {
+        let row = vec![
+            MessageValue::Varchar("test".into()),
+            MessageValue::Map(
+                vec![
+                    (
+                        MessageValue::Varchar("class".into()),
+                        MessageValue::Varchar(
+                            "org.apache.cassandra.locator.NetworkTopologyStrategy".into(),
+                        ),
+                    ),
+                    (
+                        MessageValue::Varchar("dc1".into()),
+                        MessageValue::Varchar("3".into()),
+                    ),
+                ]
+                .into_iter()
+                .collect(),
+            ),
+        ];
+
+        let result = system_keyspaces::build_keyspace(row, "dc1").unwrap();
+
+        assert_eq!(
+            result,
+            (
+                "test".into(),
+                KeyspaceMetadata {
+                    replication_factor: 3
+                }
+            )
+        )
+    }
+}
diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
index cd978cebc..42f3bc25b 100644
--- a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
+++ b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
@@ -6,6 +6,7 @@ use shotover_proxy::transforms::cassandra::sink_cluster::{
     node::{CassandraNode, ConnectionFactory},
     topology::{create_topology_task, TaskConnectionInfo},
 };
+use std::collections::HashMap;
 use std::time::Duration;
 use tokio::sync::{mpsc, watch};
 use tokio::time::timeout;
@@ -17,6 +18,7 @@ pub mod single_rack_v4;
 pub async fn run_topology_task(ca_path: Option<&str>, port: Option) -> Vec<CassandraNode> {
     let port = port.unwrap_or(9042);
     let (nodes_tx, mut nodes_rx) = watch::channel(vec![]);
+    let (keyspaces_tx, _keyspaces_rx) = watch::channel(HashMap::new());
     let (task_handshake_tx, task_handshake_rx) = mpsc::channel(1);
     let tls = ca_path.map(|ca_path| {
         TlsConnector::new(TlsConnectorConfig {
@@ -33,7 +35,7 @@ pub async fn run_topology_task(ca_path: Option<&str>, port: Option) -> Vec<
         connection_factory.push_handshake_message(message);
     }
 
-    create_topology_task(nodes_tx, task_handshake_rx, "dc1".to_string());
+    create_topology_task(nodes_tx, keyspaces_tx, task_handshake_rx, "dc1".to_string());
 
     // Give the handshake task a hardcoded handshake.
     // Normally the handshake is the handshake that the client gave shotover.
diff --git a/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs b/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs
index 48d070177..cde171e64 100644
--- a/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs
+++ b/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs
@@ -1,8 +1,8 @@
-use futures::Future;
-
+use crate::helpers::cassandra::CassandraDriver;
 use crate::helpers::cassandra::{
     assert_query_result, assert_rows, run_query, CassandraConnection, ResultValue,
 };
+use futures::Future;
 
 async fn delete(session: &CassandraConnection) {
     let prepared = session
         .await;
 
     assert_eq!(
-        session.execute_prepared(&prepared, 1).await,
+        session.execute_prepared(&prepared, Some(1)).await,
         Vec::<Vec<ResultValue>>::new()
     );
@@ -28,17 +28,17 @@ async fn insert(session: &CassandraConnection) {
         .await;
 
     assert_eq!(
-        session.execute_prepared(&prepared, 1).await,
+        session.execute_prepared(&prepared, Some(1)).await,
         Vec::<Vec<ResultValue>>::new()
     );
 
     assert_eq!(
-        session.execute_prepared(&prepared, 2).await,
+        session.execute_prepared(&prepared, Some(2)).await,
         Vec::<Vec<ResultValue>>::new()
     );
 
     assert_eq!(
-        session.execute_prepared(&prepared, 2).await,
+        session.execute_prepared(&prepared, Some(2)).await,
         Vec::<Vec<ResultValue>>::new()
     );
 }
@@ -48,7 +48,7 @@ async fn select(session: &CassandraConnection) {
         .prepare("SELECT id FROM test_prepare_statements.table_1 WHERE id = ?")
         .await;
 
-    let result_rows = session.execute_prepared(&prepared, 1).await;
+    let result_rows = session.execute_prepared(&prepared, Some(1)).await;
 
     assert_rows(result_rows, &[&[ResultValue::Int(1)]]);
 }
@@ -68,11 +68,11 @@ async fn select_cross_connection(
     let connection_after = connection_creator().await;
 
     assert_rows(
-        connection_before.execute_prepared(&prepared, 1).await,
+        connection_before.execute_prepared(&prepared, Some(1)).await,
         &[&[ResultValue::Int(1), ResultValue::Int(1)]],
     );
     assert_rows(
-        connection_after.execute_prepared(&prepared, 1).await,
+        connection_after.execute_prepared(&prepared, Some(1)).await,
         &[&[ResultValue::Int(1), ResultValue::Int(1)]],
     );
 }
@@ -89,7 +89,7 @@ async fn use_statement(session: &CassandraConnection) {
 
     // observe query completing against the original keyspace without errors
     assert_eq!(
-        session.execute_prepared(&prepared, 358).await,
+        session.execute_prepared(&prepared, Some(358)).await,
         Vec::<Vec<ResultValue>>::new()
     );
 
@@ -119,4 +119,10 @@ where
     select_cross_connection(session, connection_creator).await;
     delete(session).await;
     use_statement(session).await;
+
+    if session.is(&[CassandraDriver::Scylla, CassandraDriver::CdrsTokio]) {
+        let cql = "SELECT * FROM system.local WHERE key = 'local'";
+        let prepared = session.prepare(cql).await;
+        session.execute_prepared(&prepared, None).await;
+    }
 }
diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs
index 61be0a114..f5b0de47f 100644
--- a/shotover-proxy/tests/helpers/cassandra.rs
+++ b/shotover-proxy/tests/helpers/cassandra.rs
@@ -48,21 +48,21 @@ pub enum PreparedQuery {
 impl PreparedQuery {
     #[cfg(feature = "cassandra-cpp-driver-tests")]
-    fn as_datastax(&self) -> &PreparedStatementCpp {
+    pub fn as_datastax(&self) -> &PreparedStatementCpp {
         match self {
             PreparedQuery::Datastax(p) => p,
             _ => panic!("Not PreparedQuery::Datastax"),
         }
     }
 
-    fn as_cdrs(&self) -> &CdrsTokioPreparedQuery {
+    pub fn as_cdrs(&self) -> &CdrsTokioPreparedQuery {
         match self {
             PreparedQuery::CdrsTokio(p) => p,
             _ => panic!("Not PreparedQuery::CdrsTokio"),
         }
     }
 
-    fn as_scylla(&self) -> &PreparedStatementScylla {
+    pub fn as_scylla(&self) -> &PreparedStatementScylla {
         match self {
             PreparedQuery::Scylla(s) => s,
             _ => panic!("Not PreparedQuery::Scylla"),
         }
@@ -83,7 +83,7 @@ fn cpp_error_to_cdrs(code: CassErrorCode, message: String) -> ErrorBody {
 }
 
 #[allow(dead_code)]
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, Eq, PartialEq)]
 pub enum CassandraDriver {
     #[cfg(feature = "cassandra-cpp-driver-tests")]
     Datastax,
@@ -197,6 +197,16 @@ impl CassandraConnection {
         }
     }
 
+    #[allow(dead_code)]
+    pub fn is(&self, drivers: &[CassandraDriver]) -> bool {
+        match self {
+            Self::CdrsTokio { .. } => drivers.contains(&CassandraDriver::CdrsTokio),
+            #[cfg(feature = "cassandra-cpp-driver-tests")]
+            Self::Datastax { .. } => drivers.contains(&CassandraDriver::Datastax),
+            Self::Scylla { .. } => drivers.contains(&CassandraDriver::Scylla),
+        }
+    }
+
     #[cfg(feature = "cassandra-cpp-driver-tests")]
     #[allow(dead_code)]
     pub fn as_datastax(&self) -> &DatastaxSession {
@@ -461,13 +471,15 @@ impl CassandraConnection {
     pub async fn execute_prepared(
         &self,
         prepared_query: &PreparedQuery,
-        value: i32,
+        value: Option<i32>,
     ) -> Vec<Vec<ResultValue>> {
         match self {
             #[cfg(feature = "cassandra-cpp-driver-tests")]
             Self::Datastax { session, .. } => {
                 let mut statement = prepared_query.as_datastax().bind();
-                statement.bind_int32(0, value).unwrap();
+                if let Some(value) = value {
+                    statement.bind_int32(0, value).unwrap();
+                }
                 statement.set_tracing(true).unwrap();
                 match session.execute(&statement).await {
                     Ok(result) => result
@@ -481,9 +493,12 @@
             }
             Self::CdrsTokio { session, .. } => {
                 let statement = prepared_query.as_cdrs();
-                let query_params = QueryParamsBuilder::new()
-                    .with_values(query_values!(value))
-                    .build();
+
+                let mut builder = QueryParamsBuilder::new();
+                if let Some(value) = value {
+                    builder = builder.with_values(query_values!(value));
+                }
+                let query_params = builder.build();
 
                 let params = StatementParams {
                     query_params,
@@ -504,7 +519,11 @@
             }
             Self::Scylla { session, .. } => {
                 let statement = prepared_query.as_scylla();
-                let response = session.execute(statement, (value,)).await.unwrap();
+                let response = if let Some(value) = value {
+                    session.execute(statement, (value,)).await.unwrap()
+                } else {
+                    session.execute(statement, ()).await.unwrap()
+                };
                 match response.rows {
                     Some(rows) => rows