diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index e75bcd2af..d8d858697 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -24,10 +24,12 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, oneshot, RwLock}; +use topology::{create_topology_task, TaskConnectionInfo}; use uuid::Uuid; use version_compare::Cmp; pub mod node; +pub mod topology; #[derive(Deserialize, Debug, Clone)] pub struct CassandraSinkClusterConfig { @@ -720,160 +722,6 @@ enum RewriteTableTy { Peers, } -pub fn create_topology_task( - nodes: Arc>>, - mut handshake_rx: mpsc::Receiver, - data_center: String, -) { - tokio::spawn(async move { - while let Some(handshake) = handshake_rx.recv().await { - let mut attempts = 0; - while let Err(err) = topology_task_process(&nodes, &handshake, &data_center).await { - tracing::error!("topology task failed, retrying, error was: {err:?}"); - attempts += 1; - if attempts > 3 { - // 3 attempts have failed, lets try a new handshake - break; - } - } - - // Sleep for an hour. - // TODO: This is a crude way to ensure we dont overload the transforms with too many topology changes. - // This will be replaced with: - // * the task subscribes to events - // * the transforms request a reload when they hit connection errors - tokio::time::sleep(std::time::Duration::from_secs(60 * 60)).await; - } - }); -} - -async fn topology_task_process( - nodes: &Arc>>, - handshake: &TaskConnectionInfo, - data_center: &str, -) -> Result<()> { - let outbound = handshake - .connection_factory - .new_connection(handshake.address) - .await?; - - let (peers_tx, peers_rx) = oneshot::channel(); - outbound.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - stream_id: 0, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Query { - query: Box::new(parse_statement_single( - "SELECT peer, rack, data_center, tokens FROM system.peers", - )), - params: Box::new(QueryParams::default()), - }, - })), - peers_tx, - )?; - - let (local_tx, local_rx) = oneshot::channel(); - outbound.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - stream_id: 1, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Query { - query: Box::new(parse_statement_single( - "SELECT broadcast_address, rack, data_center, tokens FROM system.local", - )), - params: Box::new(QueryParams::default()), - }, - })), - local_tx, - )?; - - let (new_nodes, more_nodes) = tokio::join!( - async { system_peers_into_nodes(peers_rx.await?.response?, data_center) }, - async { system_peers_into_nodes(local_rx.await?.response?, data_center) } - ); - let mut new_nodes = new_nodes?; - new_nodes.extend(more_nodes?); - - let mut write_lock = nodes.write().await; - let expensive_drop = std::mem::replace(&mut *write_lock, new_nodes); - - // Make sure to drop write_lock before the expensive_drop which will have to perform many deallocations. - std::mem::drop(write_lock); - std::mem::drop(expensive_drop); - - Ok(()) -} - -fn system_peers_into_nodes( - mut response: Message, - config_data_center: &str, -) -> Result> { - if let Some(Frame::Cassandra(frame)) = response.frame() { - match &mut frame.operation { - CassandraOperation::Result(CassandraResult::Rows { - value: MessageValue::Rows(rows), - .. - }) => rows - .iter_mut() - .filter(|row| { - if let Some(MessageValue::Varchar(data_center)) = row.get(2) { - data_center == config_data_center - } else { - false - } - }) - .map(|row| { - if row.len() != 4 { - return Err(anyhow!("expected 4 columns but was {}", row.len())); - } - - let tokens = if let Some(MessageValue::List(list)) = row.pop() { - list.into_iter() - .map::, _>(|x| match x { - MessageValue::Varchar(a) => Ok(a), - _ => Err(anyhow!("tokens value not a varchar")), - }) - .collect::>>()? - } else { - return Err(anyhow!("tokens not a list")); - }; - let _data_center = row.pop(); - let rack = if let Some(MessageValue::Varchar(value)) = row.pop() { - value - } else { - return Err(anyhow!("rack not a varchar")); - }; - let address = if let Some(MessageValue::Inet(value)) = row.pop() { - value - } else { - return Err(anyhow!("address not an inet")); - }; - - Ok(CassandraNode { - address, - rack, - _tokens: tokens, - outbound: None, - }) - }) - .collect(), - operation => Err(anyhow!( - "system.peers returned unexpected cassandra operation: {:?}", - operation - )), - } - } else { - Err(anyhow!( - "Failed to parse system.peers response {:?}", - response - )) - } -} - fn is_use_statement(request: &mut Message) -> bool { if let Some(Frame::Cassandra(frame)) = request.frame() { if let CassandraOperation::Query { query, .. } = &mut frame.operation { @@ -1031,9 +879,3 @@ impl Transform for CassandraSinkCluster { .set_pushed_messages_tx(pushed_messages_tx); } } - -#[derive(Debug)] -pub struct TaskConnectionInfo { - pub connection_factory: ConnectionFactory, - pub address: SocketAddr, -} diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs new file mode 100644 index 000000000..0ccb0e144 --- /dev/null +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs @@ -0,0 +1,170 @@ +use super::node::{CassandraNode, ConnectionFactory}; +use crate::frame::cassandra::parse_statement_single; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; +use crate::message::{Message, MessageValue}; +use anyhow::{anyhow, Result}; +use cassandra_protocol::frame::Version; +use cassandra_protocol::query::QueryParams; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, RwLock}; + +#[derive(Debug)] +pub struct TaskConnectionInfo { + pub connection_factory: ConnectionFactory, + pub address: SocketAddr, +} + +pub fn create_topology_task( + nodes: Arc>>, + mut handshake_rx: mpsc::Receiver, + data_center: String, +) { + tokio::spawn(async move { + while let Some(handshake) = handshake_rx.recv().await { + let mut attempts = 0; + while let Err(err) = topology_task_process(&nodes, &handshake, &data_center).await { + tracing::error!("topology task failed, retrying, error was: {err:?}"); + attempts += 1; + if attempts > 3 { + // 3 attempts have failed, lets try a new handshake + break; + } + } + + // Sleep for an hour. + // TODO: This is a crude way to ensure we dont overload the transforms with too many topology changes. + // This will be replaced with: + // * the task subscribes to events + // * the transforms request a reload when they hit connection errors + tokio::time::sleep(std::time::Duration::from_secs(60 * 60)).await; + } + }); +} + +async fn topology_task_process( + nodes: &Arc>>, + handshake: &TaskConnectionInfo, + data_center: &str, +) -> Result<()> { + let outbound = handshake + .connection_factory + .new_connection(handshake.address) + .await?; + + let (peers_tx, peers_rx) = oneshot::channel(); + outbound.send( + Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + stream_id: 0, + tracing_id: None, + warnings: vec![], + operation: CassandraOperation::Query { + query: Box::new(parse_statement_single( + "SELECT peer, rack, data_center, tokens FROM system.peers", + )), + params: Box::new(QueryParams::default()), + }, + })), + peers_tx, + )?; + + let (local_tx, local_rx) = oneshot::channel(); + outbound.send( + Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + stream_id: 1, + tracing_id: None, + warnings: vec![], + operation: CassandraOperation::Query { + query: Box::new(parse_statement_single( + "SELECT broadcast_address, rack, data_center, tokens FROM system.local", + )), + params: Box::new(QueryParams::default()), + }, + })), + local_tx, + )?; + + let (new_nodes, more_nodes) = tokio::join!( + async { system_peers_into_nodes(peers_rx.await?.response?, data_center) }, + async { system_peers_into_nodes(local_rx.await?.response?, data_center) } + ); + let mut new_nodes = new_nodes?; + new_nodes.extend(more_nodes?); + + let mut write_lock = nodes.write().await; + let expensive_drop = std::mem::replace(&mut *write_lock, new_nodes); + + // Make sure to drop write_lock before the expensive_drop which will have to perform many deallocations. + std::mem::drop(write_lock); + std::mem::drop(expensive_drop); + + Ok(()) +} + +fn system_peers_into_nodes( + mut response: Message, + config_data_center: &str, +) -> Result> { + if let Some(Frame::Cassandra(frame)) = response.frame() { + match &mut frame.operation { + CassandraOperation::Result(CassandraResult::Rows { + value: MessageValue::Rows(rows), + .. + }) => rows + .iter_mut() + .filter(|row| { + if let Some(MessageValue::Varchar(data_center)) = row.get(2) { + data_center == config_data_center + } else { + false + } + }) + .map(|row| { + if row.len() != 4 { + return Err(anyhow!("expected 4 columns but was {}", row.len())); + } + + let tokens = if let Some(MessageValue::List(list)) = row.pop() { + list.into_iter() + .map::, _>(|x| match x { + MessageValue::Varchar(a) => Ok(a), + _ => Err(anyhow!("tokens value not a varchar")), + }) + .collect::>>()? + } else { + return Err(anyhow!("tokens not a list")); + }; + let _data_center = row.pop(); + let rack = if let Some(MessageValue::Varchar(value)) = row.pop() { + value + } else { + return Err(anyhow!("rack not a varchar")); + }; + let address = if let Some(MessageValue::Inet(value)) = row.pop() { + value + } else { + return Err(anyhow!("address not an inet")); + }; + + Ok(CassandraNode { + address, + rack, + _tokens: tokens, + outbound: None, + }) + }) + .collect(), + operation => Err(anyhow!( + "system.peers returned unexpected cassandra operation: {:?}", + operation + )), + } + } else { + Err(anyhow!( + "Failed to parse system.peers response {:?}", + response + )) + } +} diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster.rs b/shotover-proxy/tests/cassandra_int_tests/cluster.rs index 03166bb09..444aa54c8 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster.rs @@ -3,9 +3,8 @@ use shotover_proxy::frame::{CassandraFrame, CassandraOperation, Frame}; use shotover_proxy::message::Message; use shotover_proxy::tls::{TlsConnector, TlsConnectorConfig}; use shotover_proxy::transforms::cassandra::sink_cluster::{ - create_topology_task, node::{CassandraNode, ConnectionFactory}, - TaskConnectionInfo, + topology::{create_topology_task, TaskConnectionInfo}, }; use std::sync::Arc; use tokio::sync::{mpsc, RwLock};