diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
index 1d30a9433..ccc4789d7 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
@@ -24,9 +24,10 @@ use futures::StreamExt;
 use itertools::Itertools;
 use metrics::{register_counter, Counter};
 use node::{CassandraNode, ConnectionFactory};
-use node_pool::{GetReplicaErr, NodePool};
+use node_pool::{GetReplicaErr, KeyspaceMetadata, NodePool};
 use rand::prelude::*;
 use serde::Deserialize;
+use std::collections::HashMap;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::time::Duration;
 use tokio::sync::{mpsc, oneshot, watch};
@@ -42,6 +43,9 @@ mod test_router;
 mod token_map;
 pub mod topology;
 
+pub type KeyspaceChanTx = watch::Sender<HashMap<String, KeyspaceMetadata>>;
+pub type KeyspaceChanRx = watch::Receiver<HashMap<String, KeyspaceMetadata>>;
+
 #[derive(Deserialize, Debug, Clone)]
 pub struct CassandraSinkClusterConfig {
     /// contact points must be within the specified data_center and rack.
@@ -118,6 +122,7 @@ pub struct CassandraSinkCluster {
     /// Additionally any changes to nodes_rx are observed and copied over.
     pool: NodePool,
     nodes_rx: watch::Receiver<Vec<CassandraNode>>,
+    keyspaces_rx: KeyspaceChanRx,
     rng: SmallRng,
     task_handshake_tx: mpsc::Sender<TaskConnectionInfo>,
 }
@@ -143,6 +148,7 @@ impl Clone for CassandraSinkCluster {
             // Because the self.nodes_rx is always copied from the original nodes_rx created before any node lists were sent,
             // once a single node list has been sent all new connections will immediately recognize it as a change.
             nodes_rx: self.nodes_rx.clone(),
+            keyspaces_rx: self.keyspaces_rx.clone(),
             rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
             task_handshake_tx: self.task_handshake_tx.clone(),
         }
@@ -163,10 +169,14 @@ impl CassandraSinkCluster {
         let receive_timeout = timeout.map(Duration::from_secs);
 
         let (local_nodes_tx, local_nodes_rx) = watch::channel(vec![]);
+        let (keyspaces_tx, keyspaces_rx): (KeyspaceChanTx, KeyspaceChanRx) =
+            watch::channel(HashMap::new());
+
         let (task_handshake_tx, task_handshake_rx) = mpsc::channel(1);
 
         create_topology_task(
             local_nodes_tx,
+            keyspaces_tx,
             task_handshake_rx,
             local_shotover_node.data_center.clone(),
         );
@@ -192,6 +202,7 @@ impl CassandraSinkCluster {
             local_shotover_node,
             pool: NodePool::new(vec![]),
             nodes_rx: local_nodes_rx,
+            keyspaces_rx,
             rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
             task_handshake_tx,
         }
@@ -236,6 +247,10 @@ impl CassandraSinkCluster {
             }
         }
 
+        if self.keyspaces_rx.has_changed()? {
+            self.pool.update_keyspaces(&mut self.keyspaces_rx).await;
+        }
+
         let tables_to_rewrite: Vec<TableToRewrite> = messages
             .iter_mut()
             .enumerate()
@@ -369,7 +384,6 @@ impl CassandraSinkCluster {
                 &self.local_shotover_node.rack,
                 &metadata.version,
                 &mut self.rng,
-                1,
             )
             .await
         {
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
index 1a0c6c3af..4ec0938c4 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
@@ -17,12 +17,12 @@ use uuid::Uuid;
 pub struct CassandraNode {
     pub address: SocketAddr,
     pub rack: String,
-
-    #[derivative(Debug = "ignore")]
-    pub tokens: Vec<Murmur3Token>,
     pub outbound: Option<CassandraConnection>,
     pub host_id: Uuid,
     pub is_up: bool,
+
+    #[derivative(Debug = "ignore")]
+    pub tokens: Vec<Murmur3Token>,
 }
 
 impl CassandraNode {
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
index b09f50cd6..c4fb41f60 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/node_pool.rs
@@ -1,6 +1,7 @@
+use super::node::CassandraNode;
 use super::routing_key::calculate_routing_key;
 use super::token_map::TokenMap;
-use crate::transforms::cassandra::sink_cluster::node::CassandraNode;
+use super::KeyspaceChanRx;
 use anyhow::{anyhow, Error, Result};
 use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
 use cassandra_protocol::frame::message_result::PreparedMetadata;
@@ -19,9 +20,15 @@ pub enum GetReplicaErr {
     Other(Error),
 }
 
+#[derive(Debug, Clone)]
+pub struct KeyspaceMetadata {
+    pub replication_factor: usize,
+}
+
 #[derive(Debug)]
 pub struct NodePool {
     prepared_metadata: Arc<RwLock<HashMap<CBytesShort, PreparedMetadata>>>,
+    keyspace_metadata: Arc<RwLock<HashMap<String, KeyspaceMetadata>>>,
     token_map: TokenMap,
     nodes: Vec<CassandraNode>,
     prev_idx: usize,
@@ -31,6 +38,7 @@ impl Clone for NodePool {
     fn clone(&self) -> Self {
         Self {
             prepared_metadata: self.prepared_metadata.clone(),
+            keyspace_metadata: self.keyspace_metadata.clone(),
             token_map: TokenMap::new(&[]),
             nodes: vec![],
             prev_idx: 0,
@@ -44,6 +52,7 @@ impl NodePool {
             token_map: TokenMap::new(nodes.as_slice()),
             nodes,
             prepared_metadata: Arc::new(RwLock::new(HashMap::new())),
+            keyspace_metadata: Arc::new(RwLock::new(HashMap::new())),
             prev_idx: 0,
         }
     }
@@ -70,6 +79,12 @@ impl NodePool {
         self.token_map = TokenMap::new(self.nodes.as_slice());
     }
 
+    pub async fn update_keyspaces(&mut self, keyspaces_rx: &mut KeyspaceChanRx) {
+        let updated_keyspaces = keyspaces_rx.borrow_and_update().clone();
+        let mut write_keyspaces = self.keyspace_metadata.write().await;
+        *write_keyspaces = updated_keyspaces;
+    }
+
     pub async fn add_prepared_result(&mut self, id: CBytesShort, metadata: PreparedMetadata) {
         let mut write_lock = self.prepared_metadata.write().await;
         write_lock.insert(id, metadata);
@@ -121,7 +136,6 @@ impl NodePool {
         rack: &str,
         version: &Version,
         rng: &mut SmallRng,
-        rf: usize, // TODO this parameter should be removed
     ) -> Result<Vec<CassandraNode>, GetReplicaErr> {
         let metadata = {
             let read_lock = self.prepared_metadata.read().await;
@@ -131,19 +145,35 @@ impl NodePool {
                 .clone()
         };
 
+        let keyspace = {
+            let read_lock = self.keyspace_metadata.read().await;
+            read_lock
+                .get(
+                    &metadata
+                        .global_table_spec
+                        .as_ref()
+                        .ok_or_else(|| GetReplicaErr::Other(anyhow!("prepared metadata is missing a global table spec")))?
+                        .ks_name,
+                )
+                .ok_or(GetReplicaErr::NoMetadata)?
+                .clone()
+        };
+
         let routing_key = calculate_routing_key(
             &metadata.pk_indexes,
             execute.query_parameters.values.as_ref().ok_or_else(|| {
-                GetReplicaErr::Other(anyhow!("Execute body does not have query paramters"))
+                GetReplicaErr::Other(anyhow!("Execute body does not have query parameters"))
             })?,
             *version,
         )
         .unwrap();
 
-        // TODO this should use the keyspace info to properly select the replica count
         let replica_host_ids = self
             .token_map
-            .iter_replica_nodes(Murmur3Token::generate(&routing_key), rf)
+            .iter_replica_nodes_capped(
+                Murmur3Token::generate(&routing_key),
+                keyspace.replication_factor,
+            )
             .collect::<Vec<_>>();
 
         let (dc_replicas, rack_replicas) = self
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
index bc8f14378..7973e2995 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/test_router.rs
@@ -1,8 +1,9 @@
 #[cfg(test)]
 mod test_token_aware_router {
-    use super::super::node_pool::NodePool;
+    use super::super::node_pool::{KeyspaceMetadata, NodePool};
     use super::super::routing_key::calculate_routing_key;
     use crate::transforms::cassandra::sink_cluster::node::CassandraNode;
+    use crate::transforms::cassandra::sink_cluster::{KeyspaceChanRx, KeyspaceChanTx};
     use cassandra_protocol::consistency::Consistency::One;
     use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
     use cassandra_protocol::frame::message_result::PreparedMetadata;
@@ -18,6 +19,7 @@ mod test_token_aware_router {
     use rand::prelude::*;
     use std::collections::HashMap;
     use std::net::SocketAddr;
+    use tokio::sync::watch;
     use uuid::Uuid;
 
     #[tokio::test]
@@ -30,6 +32,23 @@ mod test_token_aware_router {
             11, 241, 38, 11, 140, 72, 217, 34, 214, 128, 175, 241, 151, 73, 197, 227,
         ]);
 
+        let keyspace_metadata = KeyspaceMetadata {
+            replication_factor: 3,
+        };
+
+        let (keyspaces_tx, mut keyspaces_rx): (KeyspaceChanTx, KeyspaceChanRx) =
+            watch::channel(HashMap::new());
+
+        keyspaces_tx
+            .send(
+                [("demo_ks".to_string(), keyspace_metadata)]
+                    .into_iter()
+                    .collect(),
+            )
+            .unwrap();
+
+        router.update_keyspaces(&mut keyspaces_rx).await;
+
         router
             .add_prepared_result(id.clone(), prepared_metadata().clone())
             .await;
@@ -64,7 +83,6 @@ mod test_token_aware_router {
             "rack1",
             &Version::V4,
             &mut rng,
-            3,
         )
         .await
         .unwrap()
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
index 2932d9fe5..57ec69ad8 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/token_map.rs
@@ -19,7 +19,7 @@ impl TokenMap {
     }
 
     /// Returns nodes starting at given token and going in the direction of replicas.
-    pub fn iter_replica_nodes(
+    pub fn iter_replica_nodes_capped(
         &self,
         token: Murmur3Token,
         replica_count: usize,
@@ -30,6 +30,14 @@ impl TokenMap {
             .take(replica_count)
             .map(|(_, node)| *node)
     }
+
+    // pub fn iter_replica_nodes(&self, token: Murmur3Token) -> impl Iterator<Item = Uuid> + '_ {
+    //     self.token_ring
+    //         .range(token..)
+    //         .chain(self.token_ring.iter())
+    //         .take(self.token_ring.len())
+    //         .map(|(_, node)| *node)
+    // }
 }
 
 #[cfg(test)]
@@ -47,7 +55,7 @@ mod test_token_map {
         vec![
             CassandraNode::new(
                 "127.0.0.1:9042".parse().unwrap(),
-                "rack1".into(),
+                "dc1".into(),
                 vec![
                     Murmur3Token::new(-2),
                     Murmur3Token::new(-1),
@@ -57,13 +65,13 @@ mod test_token_map {
                     Murmur3Token::new(0),
                 ],
                 NODE_1,
             ),
             CassandraNode::new(
                 "127.0.0.1:9043".parse().unwrap(),
-                "rack1".into(),
+                "dc1".into(),
                 vec![Murmur3Token::new(20)],
                 NODE_2,
             ),
             CassandraNode::new(
                 "127.0.0.1:9044".parse().unwrap(),
-                "rack1".into(),
+                "dc1".into(),
                 vec![
                     Murmur3Token::new(2),
                     Murmur3Token::new(1),
@@ -98,7 +106,7 @@ mod test_token_map {
     fn verify_tokens(node_host_ids: &[Uuid], token: Murmur3Token) {
         let token_map = TokenMap::new(prepare_nodes().as_slice());
         let nodes = token_map
-            .iter_replica_nodes(token, node_host_ids.len())
+            .iter_replica_nodes_capped(token, node_host_ids.len())
             .collect_vec();
 
         assert_eq!(nodes, node_host_ids);
diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
index 170d54a7c..5c6f25231 100644
--- a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
+++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
@@ -1,4 +1,6 @@
 use super::node::{CassandraNode, ConnectionFactory};
+use super::node_pool::KeyspaceMetadata;
+use super::KeyspaceChanTx;
 use crate::frame::cassandra::{parse_statement_single, Tracing};
 use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
 use crate::message::{Message, MessageValue};
@@ -9,6 +11,8 @@ use cassandra_protocol::frame::events::{StatusChangeType, TopologyChangeType};
 use cassandra_protocol::frame::message_register::BodyReqRegister;
 use cassandra_protocol::token::Murmur3Token;
 use cassandra_protocol::{frame::Version, query::QueryParams};
+use itertools::Itertools;
+use std::collections::HashMap;
 use std::net::SocketAddr;
 use tokio::sync::mpsc::unbounded_channel;
 use tokio::sync::{mpsc, oneshot, watch};
@@ -21,13 +25,21 @@ pub struct TaskConnectionInfo {
 }
 
 pub fn create_topology_task(
     nodes_tx: watch::Sender<Vec<CassandraNode>>,
+    keyspaces_tx: KeyspaceChanTx,
     mut connection_info_rx: mpsc::Receiver<TaskConnectionInfo>,
     data_center: String,
 ) {
     tokio::spawn(async move {
         while let Some(mut connection_info) = connection_info_rx.recv().await {
             let mut attempts = 0;
-            match topology_task_process(&nodes_tx, &mut connection_info, &data_center).await {
+            match topology_task_process(
+                &nodes_tx,
+                &keyspaces_tx,
+                &mut connection_info,
+                &data_center,
+            )
+            .await
+            {
                 Err(err) => {
                     tracing::error!("topology task failed, retrying, error was: {err:?}");
                     attempts += 1;
@@ -47,6 +59,7 @@ pub fn create_topology_task(
 
 async fn topology_task_process(
     nodes_tx: &watch::Sender<Vec<CassandraNode>>,
+    keyspaces_tx: &KeyspaceChanTx,
     connection_info: &mut TaskConnectionInfo,
     data_center: &str,
 ) -> Result<()> {
@@ -67,9 +80,14 @@ async fn topology_task_process(
         return Ok(());
     }
 
+    let mut keyspaces = system_keyspaces::query(&connection, data_center).await?;
+    if let Err(watch::error::SendError(_)) = keyspaces_tx.send(keyspaces.clone()) {
+        return Ok(());
+    }
+
     register_for_topology_and_status_events(&connection, version).await?;
 
-    loop {
+    'listen: loop {
         // Wait for events to come in from the cassandra node.
         // If all the nodes receivers are closed then immediately stop listening and shutdown the task
         let pushed_messages = tokio::select! {
@@ -106,9 +124,21 @@ async fn topology_task_process(
                             }
 
                             nodes = new_nodes;
+
+                            if let Err(watch::error::SendError(_)) =
+                                nodes_tx.send(nodes.clone())
+                            {
+                                break 'listen;
+                            }
                         }
                         TopologyChangeType::RemovedNode => {
-                            nodes.retain(|node| node.address != topology.addr)
+                            nodes.retain(|node| node.address != topology.addr);
+
+                            if let Err(watch::error::SendError(_)) =
+                                nodes_tx.send(nodes.clone())
+                            {
+                                break 'listen;
+                            }
                         }
                     },
                     ServerEvent::StatusChange(status) => {
@@ -120,18 +150,31 @@ async fn topology_task_process(
                                 }
                             }
                         }
+                        if let Err(watch::error::SendError(_)) =
+                            nodes_tx.send(nodes.clone())
+                        {
+                            break 'listen;
+                        }
+                    }
+                    ServerEvent::SchemaChange(change) => {
+                        tracing::warn!("{:?}", change);
+                        keyspaces =
+                            system_keyspaces::query(&connection, data_center).await?;
+                        if let Err(watch::error::SendError(_)) =
+                            keyspaces_tx.send(keyspaces.clone())
+                        {
+                            break 'listen;
+                        }
                     }
-                    event => tracing::error!("Unexpected event: {:?}", event),
                 }
             }
         }
         None => return Err(anyhow!("topology control connection was closed")),
         }
-        if let Err(watch::error::SendError(_)) = nodes_tx.send(nodes.clone()) {
-            return Ok(());
-        }
     }
+
+    Ok(())
 }
 
 async fn register_for_topology_and_status_events(
@@ -150,6 +193,7 @@ async fn register_for_topology_and_status_events(
                     events: vec![
                         SimpleServerEvent::TopologyChange,
                         SimpleServerEvent::StatusChange,
+                        SimpleServerEvent::SchemaChange,
                     ],
                 }),
             })),
@@ -183,6 +227,174 @@ async fn fetch_current_nodes(
     Ok(new_nodes)
 }
 
+enum ReplicationStrategy {
+    SimpleStrategy {
+        replication_factor: usize,
+    },
+    NetworkTopologyStrategy {
+        datacenter_replication_factor: HashMap<String, usize>,
+    },
+    Other,
+}
+
+mod system_keyspaces {
+    use super::*;
+    use serde_json::{Map, Value as JsonValue};
+    use std::str::FromStr;
+
+    pub async fn query(
+        connection: &CassandraConnection,
+        data_center: &str,
+    ) -> Result<HashMap<String, KeyspaceMetadata>> {
+        let (tx, rx) = oneshot::channel();
+
+        connection.send(
+            Message::from_frame(Frame::Cassandra(CassandraFrame {
+                version: Version::V4,
+                stream_id: 0,
+                tracing: Tracing::Request(false),
+                warnings: vec![],
+                operation: CassandraOperation::Query {
+                    query: Box::new(parse_statement_single(
+                        "SELECT keyspace_name, toJson(replication) AS replication FROM system_schema.keyspaces",
+                    )),
+                    params: Box::new(QueryParams::default()),
+                },
+            })),
+            tx,
+        )?;
+
+        let response = rx.await?.response?;
+        into_keyspaces(response, data_center)
+    }
+
+    fn into_keyspaces(
+        mut response: Message,
+        data_center: &str,
+    ) -> Result<HashMap<String, KeyspaceMetadata>> {
+        if let Some(Frame::Cassandra(frame)) = response.frame() {
+            match &mut frame.operation {
+                CassandraOperation::Result(CassandraResult::Rows { rows, .. }) => rows
+                    .iter_mut()
+                    .map(|row| build_keyspace(row, data_center))
+                    .try_collect(),
+                operation => Err(anyhow!(
+                    "keyspace query returned unexpected cassandra operation: {:?}",
+                    operation
+                )),
+            }
+        } else {
+            Err(anyhow!("Failed to parse keyspace query response"))
+        }
+    }
+
+    fn build_keyspace(
+        row: &mut Vec<MessageValue>,
+        data_center: &str,
+    ) -> Result<(String, KeyspaceMetadata)> {
+        let metadata = if let Some(MessageValue::Varchar(string)) = row.pop() {
+            let replication: JsonValue = serde_json::from_str(&string).map_err(|error| {
+                anyhow!(format!(
+                    "Error parsing replication. Error: {} Replication: {}",
+                    error, string
+                ))
+            })?;
+
+            let replication_strategy = match replication {
+                JsonValue::Object(properties) => build_replication_strategy(properties)?,
+                _ => {
+                    return Err(anyhow!(format!(
+                        "Error parsing replication strategy: {}",
+                        replication
+                    )))
+                }
+            };
+
+            let replication_factor = match replication_strategy {
+                ReplicationStrategy::SimpleStrategy { replication_factor } => replication_factor,
+                ReplicationStrategy::NetworkTopologyStrategy {
+                    datacenter_replication_factor,
+                } => *datacenter_replication_factor.get(data_center).unwrap_or(&0),
+                _ => 0,
+            };
+
+            KeyspaceMetadata { replication_factor }
+        } else {
+            return Err(anyhow!(
+                "system_schema.keyspaces.replication is not a varchar"
+            ));
+        };
+
+        let name = if let Some(MessageValue::Varchar(name)) = row.pop() {
+            name
+        } else {
+            return Err(anyhow!(
+                "system_schema.keyspaces.keyspace_name is not a varchar"
+            ));
+        };
+
+        Ok((name, metadata))
+    }
+
+    fn build_replication_strategy(
+        mut properties: Map<String, JsonValue>,
+    ) -> Result<ReplicationStrategy> {
+        match properties.remove("class") {
+            Some(JsonValue::String(class)) => Ok(match class.as_str() {
+                "org.apache.cassandra.locator.SimpleStrategy" | "SimpleStrategy" => {
+                    ReplicationStrategy::SimpleStrategy {
+                        replication_factor: extract_replication_factor(
+                            properties.get("replication_factor"),
+                        )?,
+                    }
+                }
+                "org.apache.cassandra.locator.NetworkTopologyStrategy"
+                | "NetworkTopologyStrategy" => ReplicationStrategy::NetworkTopologyStrategy {
+                    datacenter_replication_factor: extract_datacenter_replication_factor(
+                        properties,
+                    )?,
+                },
+                _ => ReplicationStrategy::Other,
+            }),
+            _ => Err(anyhow!("Missing replication strategy class")),
+        }
+    }
+
+    fn extract_datacenter_replication_factor(
+        properties: Map<String, JsonValue>,
+    ) -> Result<HashMap<String, usize>> {
+        properties
+            .into_iter()
+            .map(|(key, replication_factor)| {
+                extract_replication_factor(Some(&replication_factor))
+                    .map(move |replication_factor| (key, replication_factor))
+            })
+            .try_collect()
+    }
+
+    fn extract_replication_factor(value: Option<&JsonValue>) -> Result<usize> {
+        match value {
+            Some(JsonValue::String(replication_factor)) => {
+                let result = if let Some(slash) = replication_factor.find('/') {
+                    usize::from_str(&replication_factor[..slash])
+                } else {
+                    usize::from_str(replication_factor)
+                };
+
+                result.map_err(|error| {
+                    anyhow!(format!(
+                        "Failed to parse replication_factor ('{}'): {}",
+                        replication_factor, error
+                    ))
+                })
+            }
+            _ => Err(anyhow!("Missing replication factor")),
+        }
+    }
+}
+
 mod system_local {
     use super::*;
 
diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
index 6f3d756fd..402f7edc2 100644
--- a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
+++ b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
@@ -6,6 +6,7 @@ use shotover_proxy::transforms::cassandra::sink_cluster::{
     node::{CassandraNode, ConnectionFactory},
     topology::{create_topology_task, TaskConnectionInfo},
 };
+use std::collections::HashMap;
 use tokio::sync::{mpsc, watch};
 
 pub mod multi_rack;
@@ -15,6 +16,7 @@ pub mod single_rack_v4;
 pub async fn run_topology_task(ca_path: Option<&str>, port: Option<u16>) -> Vec<CassandraNode> {
     let port = port.unwrap_or(9042);
     let (nodes_tx, mut nodes_rx) = watch::channel(vec![]);
+    let (keyspaces_tx, _keyspaces_rx) = watch::channel(HashMap::new());
     let (task_handshake_tx, task_handshake_rx) = mpsc::channel(1);
     let tls = ca_path.map(|ca_path| {
         TlsConnector::new(TlsConnectorConfig {
@@ -31,7 +33,7 @@ pub async fn run_topology_task(ca_path: Option<&str>, port: Option<u16>) -> Vec<CassandraNode> {
         connection_factory.push_handshake_message(message);
     }
 
-    create_topology_task(nodes_tx, task_handshake_rx, "dc1".to_string());
+    create_topology_task(nodes_tx, keyspaces_tx, task_handshake_rx, "dc1".to_string());
 
     // Give the handshake task a hardcoded handshake.
     // Normally this is the handshake that the client gave shotover.
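
Reviewer notes — the three sketches below are illustrations, not part of the patch.

The keyspace metadata flow added above is a plain tokio::watch fan-out: the topology task publishes a full snapshot, and each CassandraSinkCluster clone polls has_changed() before routing and copies only on change. A minimal standalone sketch of that pattern, with KeyspaceMetadata re-declared for self-containment and "demo_ks" as a hypothetical keyspace:

```rust
use std::collections::HashMap;
use tokio::sync::watch;

#[derive(Debug, Clone)]
struct KeyspaceMetadata {
    replication_factor: usize,
}

type KeyspaceChanTx = watch::Sender<HashMap<String, KeyspaceMetadata>>;
type KeyspaceChanRx = watch::Receiver<HashMap<String, KeyspaceMetadata>>;

#[tokio::main]
async fn main() {
    let (tx, mut rx): (KeyspaceChanTx, KeyspaceChanRx) = watch::channel(HashMap::new());

    // Producer side (the topology task): publish a complete snapshot on every change.
    tx.send(
        [(
            "demo_ks".to_string(),
            KeyspaceMetadata { replication_factor: 3 },
        )]
        .into_iter()
        .collect(),
    )
    .unwrap();

    // Consumer side (the transform): poll cheaply, clone only when something changed.
    if rx.has_changed().unwrap() {
        let snapshot = rx.borrow_and_update().clone();
        assert_eq!(snapshot["demo_ks"].replication_factor, 3);
    }
}
```

Because a watch channel keeps only the latest value, a consumer that wakes up late still sees a consistent snapshot rather than a backlog of intermediate states.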
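
The new system_keyspaces module derives replica counts from the JSON returned by `SELECT ... toJson(replication)`. Assuming the usual shapes of that column (a "class" field plus either "replication_factor" or one entry per datacenter, all values encoded as strings), here is a condensed sketch of the mapping; it ignores the transient-replication "n/m" form that the patch additionally strips:

```rust
use serde_json::{Map, Value};

// Returns the effective replication factor for `data_center`, or None if the
// strategy is unrecognized or the keyspace does not replicate to that DC.
fn replication_factor(replication_json: &str, data_center: &str) -> Option<usize> {
    let replication: Map<String, Value> = serde_json::from_str(replication_json).ok()?;
    match replication.get("class")?.as_str()? {
        "SimpleStrategy" | "org.apache.cassandra.locator.SimpleStrategy" => {
            replication.get("replication_factor")?.as_str()?.parse().ok()
        }
        "NetworkTopologyStrategy" | "org.apache.cassandra.locator.NetworkTopologyStrategy" => {
            // NTS stores one factor per datacenter, keyed by DC name.
            replication.get(data_center)?.as_str()?.parse().ok()
        }
        _ => None,
    }
}

fn main() {
    // Hypothetical values; "dc1" matches the data_center used in the tests above.
    let json = r#"{"class": "NetworkTopologyStrategy", "dc1": "3"}"#;
    assert_eq!(replication_factor(json, "dc1"), Some(3));
    assert_eq!(replication_factor(json, "dc2"), None);
}
```

The patch itself maps an unknown strategy or missing datacenter to a replication factor of 0 rather than an error, which keeps routing conservative for keyspaces it cannot reason about.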
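
Lastly, iter_replica_nodes_capped is a wrapping walk over a BTreeMap token ring: start at the first token at or after the routing token, continue from the ring start when the range runs out, and stop after replication_factor owners. A toy sketch with i64 tokens and u32 node ids standing in for Murmur3Token and host Uuids; the `.min(ring.len())` cap is an extra guard added here for illustration, not taken from the patch:

```rust
use std::collections::BTreeMap;

// Walk the ring from `token`, wrapping around, and collect the first
// `replica_count` token owners.
fn replicas(ring: &BTreeMap<i64, u32>, token: i64, replica_count: usize) -> Vec<u32> {
    ring.range(token..)
        .chain(ring.iter())
        .take(replica_count.min(ring.len()))
        .map(|(_, node)| *node)
        .collect()
}

fn main() {
    // Tokens mirror the test fixture above: node 1 owns -2/-1/0, node 3 owns 1/2,
    // node 2 owns 20.
    let ring = BTreeMap::from([(-2, 1), (-1, 1), (0, 1), (1, 3), (2, 3), (20, 2)]);

    // A token past the highest entry wraps back to the start of the ring.
    assert_eq!(replicas(&ring, 21, 2), vec![1, 1]);
    assert_eq!(replicas(&ring, 1, 3), vec![3, 3, 2]);
}
```

As in the real TokenMap, consecutive tokens owned by the same node yield that node more than once, which is why get_replica_node_in_dc still filters the resulting host ids against the datacenter and rack.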