diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 550cc3775..d1fef5fc1 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -43,7 +43,7 @@ pub type Messages = Vec; /// Usually a message is received and starts off containing just raw bytes (or possibly raw bytes + frame) /// This can be immediately sent off to the destination without any processing cost. /// -/// However if a transform wants to query the contents of the message it must call `Message::frame()q which will cause the raw bytes to be processed into a raw bytes + Frame. +/// However if a transform wants to query the contents of the message it must call `Message::frame()` which will cause the raw bytes to be processed into a raw bytes + Frame. /// The first call to frame has an expensive one time cost. /// /// The transform may also go one step further and modify the message's Frame + call `Message::invalidate_cache()`. diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster.rs index 792269b0b..4ce9b018c 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster.rs @@ -3,15 +3,12 @@ use crate::codec::cassandra::CassandraCodec; use crate::error::ChainResponse; use crate::frame::cassandra::parse_statement_single; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use crate::message::{IntSize, Message, MessageValue, Messages}; +use crate::message::{Message, MessageValue, Messages}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::util::Response; use crate::transforms::{Transform, Transforms, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; -use cassandra_protocol::frame::message_result::{ - ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, RowsMetadataFlags, TableSpec, -}; use cassandra_protocol::frame::Version; use cassandra_protocol::query::QueryParams; use cql3_parser::cassandra_statement::CassandraStatement; @@ -32,6 +29,9 @@ use version_compare::Cmp; #[derive(Deserialize, Debug, Clone)] pub struct CassandraSinkClusterConfig { + /// contact points must be within the specified data_center and rack. + /// If this is not followed, shotover's invariants will still be upheld but shotover will communicate with a + /// node outside of the specified data_center and rack. pub first_contact_points: Vec, pub data_center: String, pub rack: String, @@ -176,7 +176,40 @@ impl CassandraSinkCluster { .collect(); for table_to_rewrite in tables_to_rewrite.iter().rev() { - messages.remove(table_to_rewrite.index); + let mut stream_id = 0; + let mut restart = true; + while restart { + restart = false; + for message in &mut messages { + if let Some(Frame::Cassandra(frame)) = message.frame() { + if stream_id == frame.stream_id { + match stream_id.checked_add(1) { + Some(new_stream_id) => stream_id = new_stream_id, + None => return Err(anyhow!("Ran out of stream_ids")), + } + restart = true; + } + } + } + } + + if let RewriteTableTy::Local = table_to_rewrite.ty { + messages.insert( + table_to_rewrite.index+1, + Message::from_frame(Frame::Cassandra(CassandraFrame { + version: table_to_rewrite.version, + stream_id, + tracing_id: None, + warnings: vec![], + operation: CassandraOperation::Query { + query: Box::new(parse_statement_single( + "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers", + )), + params: Box::new(QueryParams::default()), + }, + })), + ); + } } // Create the initial connection. @@ -225,9 +258,10 @@ impl CassandraSinkCluster { let (return_chan_tx, return_chan_rx) = oneshot::channel(); if self.local_nodes.is_empty() || !self.init_handshake_complete - // DDL statements must be routed through the handshake connection so that later system.local queries are directed to the same node (they also go to the handshake connection) - // They must be the same node so that schema_version changes appear immediately in system.local + // system.local and system.peers must be routed to the same node otherwise the system.local node will be amongst the system.peers nodes and a node will be missing + // DDL statements and system.local must be routed through the same connection, so that schema_version changes appear immediately in system.local || is_ddl_statement(&mut message) + || self.is_system_local_or_peers(&mut message) { self.init_handshake_connection.as_mut().unwrap() } else if is_use_statement(&mut message) { @@ -305,10 +339,8 @@ impl CassandraSinkCluster { } for table_to_rewrite in tables_to_rewrite { - responses.insert( - table_to_rewrite.index, - self.rewrite_table(table_to_rewrite, local_addr).await?, - ); + self.rewrite_table(table_to_rewrite, local_addr, &mut responses) + .await?; } Ok(responses) @@ -330,9 +362,8 @@ impl CassandraSinkCluster { return Some(TableToRewrite { index, ty, - stream_id: cassandra.stream_id, - selects: select.columns.clone(), version: cassandra.version, + selects: select.columns.clone(), }); } } @@ -344,320 +375,184 @@ impl CassandraSinkCluster { &mut self, table: TableToRewrite, local_addr: SocketAddr, - ) -> Result { - let version = table.version; - let stream_id = table.stream_id; - let table_name = match table.ty { - RewriteTableTy::Local => "local".into(), - RewriteTableTy::Peers => "peers".into(), - }; - let (rows, col_specs) = match table.ty { - RewriteTableTy::Local => self.rewrite_table_local(table, local_addr).await?, - RewriteTableTy::Peers => self.rewrite_table_peers(table).await?, - }; - - Ok(Message::from_frame(Frame::Cassandra(CassandraFrame { - version, - stream_id, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Result(CassandraResult::Rows { - value: MessageValue::Rows(rows), - // TODO: A bunch of these are just hardcoded with the assumption that the client didnt request any exotic features. - // We should implement them eventually but I cant imagine drivers would bother using them when querying the topology. - // TODO: flags should be removed upstream and just derived from the other fields at serialization time. - metadata: Box::new(RowsMetadata { - flags: RowsMetadataFlags::GLOBAL_TABLE_SPACE, - columns_count: col_specs.len() as i32, - paging_state: None, - new_metadata_id: None, - global_table_spec: Some(TableSpec { - ks_name: "system".into(), - table_name, - }), - col_specs, - }), - }), - }))) - } - - async fn rewrite_table_peers( - &mut self, - table: TableToRewrite, - ) -> Result<(Vec>, Vec)> { - let mut col_specs = vec![]; - let rows = vec![]; - - let peer_ident = Identifier::Unquoted("peer".into()); - let data_center_ident = Identifier::Unquoted("data_center".into()); - let host_id_ident = Identifier::Unquoted("host_id".into()); - let preferred_ip = Identifier::Unquoted("preferred_ip".into()); - let rack_ident = Identifier::Unquoted("rack".into()); - let release_version_ident = Identifier::Unquoted("release_version".into()); - let rpc_address_ident = Identifier::Unquoted("rpc_address".into()); - let schema_version_ident = Identifier::Unquoted("schema_version".into()); - let tokens_ident = Identifier::Unquoted("tokens".into()); - - for select in table.selects { - match select { - SelectElement::Star => { - col_specs.extend([ - meta("peer".into(), ColType::Inet), - meta("data_center".into(), ColType::Varchar), - meta("host_id".into(), ColType::Uuid), - meta("preferred_ip".into(), ColType::Inet), - meta("rack".into(), ColType::Varchar), - meta("release_version".into(), ColType::Varchar), - meta("rpc_address".into(), ColType::Inet), - meta("schema_version".into(), ColType::Uuid), - meta_tokens("tokens".into()), - ]); - } - SelectElement::Function(name) => { - // TODO: 90% sure SelectElement::Function(name) is not actually a name but the entire function call stuffed in a name LOL - col_specs.push(meta(name.to_string(), ColType::Varchar)); - } - SelectElement::Column(name) => { - if name.name == peer_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == data_center_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == host_id_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Uuid)); - } else if name.name == preferred_ip { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == rack_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == rpc_address_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == release_version_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == schema_version_ident { - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Uuid)); - } else if name.name == tokens_ident { - col_specs.push(meta_tokens(name.alias_or_name().to_string())); - } else { - tracing::error!("Unknown system.peer column requested {:?}", name) + responses: &mut Vec, + ) -> Result<()> { + match table.ty { + RewriteTableTy::Local => { + if table.index + 1 < responses.len() { + let peers_response = responses.remove(table.index + 1); + if let Some(local_response) = responses.get_mut(table.index) { + self.rewrite_table_local(table, local_addr, local_response, peers_response) + .await?; + local_response.invalidate_cache(); } } } + RewriteTableTy::Peers => { + if let Some(peers_response) = responses.get_mut(table.index) { + self.rewrite_table_peers(peers_response).await?; + peers_response.invalidate_cache(); + } + } } + Ok(()) + } + + async fn rewrite_table_peers(&mut self, peers_response: &mut Message) -> Result<()> { // TODO: generate rows for shotover peers // the current implementation will at least direct all traffic through shotover - - // TODO: schema_version and gossip_generation should be obtained by querying all nodes - - Ok((rows, col_specs)) + if let Some(Frame::Cassandra(frame)) = peers_response.frame() { + if let CassandraOperation::Result(CassandraResult::Rows { + value: MessageValue::Rows(rows), + .. + }) = &mut frame.operation + { + rows.clear(); + } + Ok(()) + } else { + Err(anyhow!( + "Failed to parse system.local response {:?}", + peers_response + )) + } } async fn rewrite_table_local( &mut self, table: TableToRewrite, local_address: SocketAddr, - ) -> Result<(Vec>, Vec)> { - let outbound = self.init_handshake_connection.as_ref().unwrap(); - let (peers_tx, peers_rx) = oneshot::channel(); - outbound.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - stream_id: 0, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Query { - query: Box::new(parse_statement_single( - "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers", - )), - params: Box::new(QueryParams::default()), - }, - })), - peers_tx, - )?; - - let (local_tx, local_rx) = oneshot::channel(); - outbound.send( - Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - stream_id: 1, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Query { - query: Box::new(parse_statement_single( - "SELECT cql_version, release_version, cluster_name, partitioner, schema_version, gossip_generation, tokens FROM system.local", - )), - params: Box::new(QueryParams::default()), - }, - })), - local_tx, - )?; - - let (peers, local) = tokio::join!( - async { parse_system_peers(peers_rx.await?.response?, &self.data_center, &self.rack) }, - async { parse_system_local(local_rx.await?.response?) }, - ); - let mut local = local?; - let mut peers = peers?; - - for peer in &mut peers { - local.tokens.append(&mut peer.tokens); - if let Ok(Cmp::Lt) = - version_compare::compare(&peer.release_version, &local.release_version) - { - std::mem::swap(&mut local.release_version, &mut peer.release_version); - } - } - local.tokens.sort(); - for peer in &peers { - if local.schema_version != peer.schema_version { - local.schema_version = Uuid::new_v4(); - break; + local_response: &mut Message, + peers_response: Message, + ) -> Result<()> { + let peers = parse_system_peers(peers_response, &self.data_center, &self.rack)?; + + let mut release_version_alias = "release_version"; + let mut tokens_alias = "tokens"; + let mut schema_version_alias = "schema_version"; + let mut broadcast_address_alias = "broadcast_address"; + let mut listen_address_alias = "listen_address"; + let mut host_id_alias = "host_id"; + let mut rack_alias = "rack"; + let mut data_center_alias = "data_center_alias"; + for select in &table.selects { + if let SelectElement::Column(column) = select { + if let Some(alias) = &column.alias { + let alias = match alias { + Identifier::Unquoted(alias) => alias, + Identifier::Quoted(alias) => alias, + }; + if column.name == Identifier::Unquoted("release_version".to_string()) { + release_version_alias = alias; + } else if column.name == Identifier::Unquoted("tokens".to_string()) { + tokens_alias = alias; + } else if column.name == Identifier::Unquoted("schema_version".to_string()) { + schema_version_alias = alias; + } else if column.name == Identifier::Unquoted("broadcast_address".to_string()) { + broadcast_address_alias = alias; + } else if column.name == Identifier::Unquoted("listen_address".to_string()) { + listen_address_alias = alias; + } else if column.name == Identifier::Unquoted("host_id".to_string()) { + host_id_alias = alias; + } else if column.name == Identifier::Unquoted("rack".to_string()) { + rack_alias = alias; + } else if column.name == Identifier::Unquoted("data_center".to_string()) { + data_center_alias = alias; + } + } } } - let version = match table.version { - Version::V3 => "3".to_string(), - Version::V4 => "4".to_string(), - Version::V5 => "5".to_string(), - }; - - let key_ident = Identifier::Unquoted("key".into()); - let bootstrapped_ident = Identifier::Unquoted("bootstrapped".into()); - let broadcast_address_ident = Identifier::Unquoted("broadcast_address".into()); - let cluster_name_ident = Identifier::Unquoted("cluster_name".into()); - let cql_version_ident = Identifier::Unquoted("cql_version".into()); - let data_center_ident = Identifier::Unquoted("data_center".into()); - let gossip_generation_ident = Identifier::Unquoted("gossip_generation".into()); - let host_id_ident = Identifier::Unquoted("host_id".into()); - let listen_address_ident = Identifier::Unquoted("listen_address".into()); - let native_protocol_version_ident = Identifier::Unquoted("native_protocol_version".into()); - let partitioner_ident = Identifier::Unquoted("partitioner".into()); - let rack_ident = Identifier::Unquoted("rack".into()); - let release_version_ident = Identifier::Unquoted("release_version".into()); - let rpc_address_ident = Identifier::Unquoted("rpc_address".into()); - let schema_version_ident = Identifier::Unquoted("schema_version".into()); - let tokens_ident = Identifier::Unquoted("tokens".into()); - let truncated_at_ident = Identifier::Unquoted("truncated_at".into()); - let mut col_specs = vec![]; - let mut row = vec![]; - for select in table.selects { - match select { - SelectElement::Star => { - row.extend([ - MessageValue::Varchar("local".into()), - MessageValue::Varchar("COMPLETED".into()), - MessageValue::Inet(local_address.ip()), - MessageValue::Varchar(local.cluster_name.clone()), - MessageValue::Varchar(local.cql_version.clone()), - MessageValue::Varchar(self.data_center.clone()), - MessageValue::Integer(local.gossip_generation, IntSize::I32), - MessageValue::Uuid(self.host_id), - MessageValue::Inet(local_address.ip()), - MessageValue::Varchar(version.clone()), - MessageValue::Varchar(local.partitioner.clone()), - MessageValue::Varchar(self.rack.clone()), - MessageValue::Varchar(local.release_version.clone()), - MessageValue::Inet("0.0.0.0".parse().unwrap()), - MessageValue::Uuid(local.schema_version), - MessageValue::List(local.tokens.clone()), - MessageValue::Null, - ]); - col_specs.extend([ - meta("key".into(), ColType::Varchar), - meta("bootstrapped".into(), ColType::Varchar), - meta("broadcast_address".into(), ColType::Inet), - //meta("broadcast_port".into(), ColType::Int), - meta("cluster_name".into(), ColType::Varchar), - meta("cql_version".into(), ColType::Varchar), - meta("data_center".into(), ColType::Varchar), - meta("gossip_generation".into(), ColType::Int), - meta("host_id".into(), ColType::Uuid), - meta("listen_address".into(), ColType::Inet), - //meta("listen_port".into(), ColType::Int), - meta("native_protocol_version".into(), ColType::Varchar), - meta("partitioner".into(), ColType::Varchar), - meta("rack".into(), ColType::Varchar), - meta("release_version".into(), ColType::Varchar), - meta("rpc_address".into(), ColType::Inet), - //meta("rpc_port".into(), ColType::Int), - meta("schema_version".into(), ColType::Uuid), - meta_tokens("tokens".into()), - meta_truncated_at("truncated_at".into()), - ]); - } - SelectElement::Function(name) => { - row.push(MessageValue::Varchar( - "ERROR: Functions are not supported by shotover".into(), - )); - // TODO: 90% sure SelectElement::Function(name) is not actually a name but the entire function call stuffed in a name LOL - col_specs.push(meta(name.to_string(), ColType::Varchar)); - } - SelectElement::Column(name) => { - if name.name == key_ident { - row.push(MessageValue::Varchar("local".into())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == bootstrapped_ident { - row.push(MessageValue::Varchar("COMPLETED".into())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == broadcast_address_ident { - row.push(MessageValue::Inet(local_address.ip())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == cluster_name_ident { - row.push(MessageValue::Varchar(local.cluster_name.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == cql_version_ident { - row.push(MessageValue::Varchar(local.cql_version.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == data_center_ident { - row.push(MessageValue::Varchar(self.data_center.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == gossip_generation_ident { - row.push(MessageValue::Integer(local.gossip_generation, IntSize::I32)); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Int)); - } else if name.name == host_id_ident { - row.push(MessageValue::Uuid(self.host_id)); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Uuid)); - } else if name.name == listen_address_ident { - row.push(MessageValue::Inet(local_address.ip())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == native_protocol_version_ident { - row.push(MessageValue::Varchar(version.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == partitioner_ident { - row.push(MessageValue::Varchar(local.partitioner.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == rack_ident { - row.push(MessageValue::Varchar(self.rack.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == release_version_ident { - row.push(MessageValue::Varchar(local.release_version.clone())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Varchar)); - } else if name.name == rpc_address_ident { - row.push(MessageValue::Inet("0.0.0.0".parse().unwrap())); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Inet)); - } else if name.name == schema_version_ident { - row.push(MessageValue::Uuid(local.schema_version)); - col_specs.push(meta(name.alias_or_name().to_string(), ColType::Uuid)); - } else if name.name == tokens_ident { - row.push(MessageValue::List(local.tokens.clone())); - col_specs.push(meta_tokens(name.alias_or_name().to_string())); - } else if name.name == truncated_at_ident { - row.push(MessageValue::Null); - col_specs.push(meta_truncated_at(name.alias_or_name().to_string())); - } else { - tracing::error!("Unknown system.local column requested {:?}", name) + if let Some(Frame::Cassandra(frame)) = local_response.frame() { + if let CassandraOperation::Result(CassandraResult::Rows { + value: MessageValue::Rows(rows), + metadata, + }) = &mut frame.operation + { + for row in rows { + for (col, col_meta) in row.iter_mut().zip(metadata.col_specs.iter()) { + if col_meta.name == release_version_alias { + if let MessageValue::Varchar(release_version) = col { + for peer in &peers { + if let Ok(Cmp::Lt) = version_compare::compare( + &peer.release_version, + &release_version, + ) { + *release_version = peer.release_version.clone(); + } + } + } + } else if col_meta.name == tokens_alias { + if let MessageValue::List(tokens) = col { + for peer in &peers { + tokens.extend(peer.tokens.iter().cloned()); + } + tokens.sort(); + } + } else if col_meta.name == schema_version_alias { + if let MessageValue::Uuid(schema_version) = col { + for peer in &peers { + if schema_version != &peer.schema_version { + *schema_version = Uuid::new_v4(); + break; + } + } + } + } else if col_meta.name == broadcast_address_alias + || col_meta.name == listen_address_alias + { + if let MessageValue::Inet(address) = col { + *address = local_address.ip(); + } + } else if col_meta.name == host_id_alias { + if let MessageValue::Uuid(host_id) = col { + *host_id = self.host_id; + } + } else if col_meta.name == rack_alias { + if let MessageValue::Varchar(rack) = col { + if rack != &self.rack { + *rack = self.rack.clone(); + tracing::warn!("A contact point node is not in the configured rack, this node will receive traffic from outside of its rack"); + } + } + } else if col_meta.name == data_center_alias { + if let MessageValue::Varchar(data_center) = col { + if data_center != &self.data_center { + *data_center = self.data_center.clone(); + tracing::warn!("A contact point node is not in the configured data_center, this node will receive traffic from outside of its data_center"); + } + } + } } } } + Ok(()) + } else { + Err(anyhow!( + "Failed to parse system.local response {:?}", + local_response + )) } + } - Ok((vec![row], col_specs)) + // TODO: handle use statement state + fn is_system_local_or_peers(&self, request: &mut Message) -> bool { + if let Some(Frame::Cassandra(frame)) = request.frame() { + if let CassandraOperation::Query { query, .. } = &mut frame.operation { + if let CassandraStatement::Select(select) = query.as_ref() { + return self.local_table == select.table_name + || self.peer_table == select.table_name; + } + } + } + false } } struct TableToRewrite { index: usize, ty: RewriteTableTy, - stream_id: i16, version: Version, selects: Vec, } @@ -667,51 +562,6 @@ enum RewriteTableTy { Peers, } -fn meta(name: String, col_type: ColType) -> ColSpec { - ColSpec { - name, - table_spec: None, - col_type: ColTypeOption { - id: col_type, - value: None, - }, - } -} - -fn meta_tokens(name: String) -> ColSpec { - ColSpec { - name, - table_spec: None, - col_type: ColTypeOption { - id: ColType::Set, - value: Some(ColTypeOptionValue::CSet(Box::new(ColTypeOption { - id: ColType::Varchar, - value: None, - }))), - }, - } -} - -fn meta_truncated_at(name: String) -> ColSpec { - ColSpec { - name, - table_spec: None, - col_type: ColTypeOption { - id: ColType::Map, - value: Some(ColTypeOptionValue::CMap( - Box::new(ColTypeOption { - id: ColType::Uuid, - value: None, - }), - Box::new(ColTypeOption { - id: ColType::Blob, - value: None, - }), - )), - }, - } -} - pub fn create_topology_task( tls: Option, nodes: Arc>>, @@ -921,16 +771,6 @@ fn is_use_statement_successful(response: Option>) -> bool { false } -struct SystemLocal { - schema_version: Uuid, - gossip_generation: i64, - tokens: Vec, - partitioner: String, - cluster_name: String, - release_version: String, - cql_version: String, -} - struct SystemPeer { tokens: Vec, schema_version: Uuid, @@ -1006,90 +846,6 @@ fn parse_system_peers( } } -fn parse_system_local(mut response: Message) -> Result { - if let Some(Frame::Cassandra(frame)) = response.frame() { - match &mut frame.operation { - CassandraOperation::Result(CassandraResult::Rows { - value: MessageValue::Rows(rows), - .. - }) => { - if rows.len() > 1 { - tracing::error!("system.local returned more than one row"); - } - if let Some(row) = rows.first_mut() { - if row.len() != 7 { - return Err(anyhow!("expected 7 columns but was {}", row.len())); - } - - let tokens = if let Some(MessageValue::List(value)) = row.pop() { - value - } else { - return Err(anyhow!("tokens not a list")); - }; - - let gossip_generation = if let Some(MessageValue::Integer(value, _)) = row.pop() - { - value - } else { - return Err(anyhow!("gossip_generation not an int")); - }; - - let schema_version = if let Some(MessageValue::Uuid(value)) = row.pop() { - value - } else { - return Err(anyhow!("schema_version not a uuid")); - }; - - let partitioner = if let Some(MessageValue::Varchar(value)) = row.pop() { - value - } else { - return Err(anyhow!("partitioner not a varchar")); - }; - - let cluster_name = if let Some(MessageValue::Varchar(value)) = row.pop() { - value - } else { - return Err(anyhow!("cluster_name not a varchar")); - }; - - let release_version = if let Some(MessageValue::Varchar(value)) = row.pop() { - value - } else { - return Err(anyhow!("release_version not a varchar")); - }; - - let cql_version = if let Some(MessageValue::Varchar(value)) = row.pop() { - value - } else { - return Err(anyhow!("cql_version not a varchar")); - }; - - Ok(SystemLocal { - schema_version, - gossip_generation, - tokens, - partitioner, - cluster_name, - release_version, - cql_version, - }) - } else { - Err(anyhow!("system.local returned no rows")) - } - } - operation => Err(anyhow!( - "system.local returned unexpected cassandra operation: {:?}", - operation - )), - } - } else { - Err(anyhow!( - "Failed to parse system.local response {:?}", - response - )) - } -} - #[async_trait] impl Transform for CassandraSinkCluster { async fn transform<'a>(&'a mut self, message_wrapper: Wrapper<'a>) -> ChainResponse { diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster.rs b/shotover-proxy/tests/cassandra_int_tests/cluster.rs index 47ec94001..2a46cac46 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster.rs @@ -9,8 +9,8 @@ use tokio::sync::{mpsc, RwLock}; async fn test_rewrite_system_local(connection: &CassandraConnection) { assert_query_result(connection, "SELECT * FROM system.peers;", &[]).await; - assert_query_result(connection, "SELECT *, peer FROM system.peers;", &[]).await; assert_query_result(connection, "SELECT peer FROM system.peers;", &[]).await; + assert_query_result(connection, "SELECT peer, peer FROM system.peers;", &[]).await; let star_results = [ ResultValue::Varchar("local".into()), @@ -19,6 +19,7 @@ async fn test_rewrite_system_local(connection: &CassandraConnection) { ResultValue::Varchar("TestCluster".into()), ResultValue::Varchar("3.4.4".into()), ResultValue::Varchar("dc1".into()), + // gossip_generation is non deterministic cant assert on it ResultValue::Any, ResultValue::Uuid("2dd022d6-2937-4754-89d6-02d2933a8f7a".parse().unwrap()), ResultValue::Inet("127.0.0.1".parse().unwrap()), @@ -27,6 +28,9 @@ async fn test_rewrite_system_local(connection: &CassandraConnection) { ResultValue::Varchar("rack1".into()), ResultValue::Varchar("3.11.13".into()), ResultValue::Inet("0.0.0.0".parse().unwrap()), + // schema_version is non deterministic so we cant assert on it. + ResultValue::Any, + // thrift_version isnt used anymore so I dont really care what it maps to ResultValue::Any, // Unfortunately token generation appears to be non-deterministic but we can at least assert that // there are 128 tokens per node @@ -34,22 +38,22 @@ async fn test_rewrite_system_local(connection: &CassandraConnection) { ResultValue::Map(vec![]), ]; + let all_columns = + "key, bootstrapped, broadcast_address, cluster_name, cql_version, data_center, + gossip_generation, host_id, listen_address, native_protocol_version, partitioner, rack, + release_version, rpc_address, schema_version, thrift_version, tokens, truncated_at"; + assert_query_result(connection, "SELECT * FROM system.local;", &[&star_results]).await; assert_query_result( connection, - "SELECT *, key FROM system.local;", - &[&[ - star_results.as_slice(), - [ResultValue::Varchar("local".into())].as_slice(), - ] - .concat()], + &format!("SELECT {all_columns} FROM system.local;"), + &[&star_results], ) .await; assert_query_result( connection, - "SELECT key, bootstrapped, broadcast_address, cluster_name, cql_version, data_center, gossip_generation, host_id, - listen_address, native_protocol_version, partitioner, rack, release_version, rpc_address, schema_version, tokens, truncated_at FROM system.local;", - &[&star_results], + &format!("SELECT {all_columns}, {all_columns} FROM system.local;"), + &[&[star_results.as_slice(), star_results.as_slice()].concat()], ) .await; } diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 108a55c8c..089a1f930 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -255,7 +255,10 @@ impl ResultValue { ValueType::DOUBLE => ResultValue::Double(value.get_f64().unwrap().into()), ValueType::DURATION => ResultValue::Duration(value.get_bytes().unwrap().to_vec()), ValueType::FLOAT => ResultValue::Float(value.get_f32().unwrap().into()), - ValueType::INET => ResultValue::Inet(value.get_inet().unwrap().to_string()), + ValueType::INET => value + .get_inet() + .map(|x| ResultValue::Inet(x.to_string())) + .unwrap_or_else(|_| ResultValue::Inet("NULL address".to_string())), ValueType::SMALL_INT => ResultValue::SmallInt(value.get_i16().unwrap()), ValueType::TIME => ResultValue::Time(value.get_bytes().unwrap().to_vec()), ValueType::TIMESTAMP => ResultValue::Timestamp(value.get_i64().unwrap()),