From 0dd3abefa28e1196768b3d5be8fb59e2acd7286a Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Tue, 13 Sep 2022 17:13:58 +1000 Subject: [PATCH] CassandraSinkCluster keyspace based routing - handle use statements --- shotover-proxy/src/codec/cassandra.rs | 95 ++++++++++++++++++- shotover-proxy/src/frame/cassandra.rs | 13 +-- shotover-proxy/src/message/mod.rs | 14 +-- .../src/transforms/cassandra/connection.rs | 16 ++-- .../transforms/cassandra/sink_cluster/mod.rs | 1 - .../cluster_single_rack_v4.rs | 6 +- 6 files changed, 110 insertions(+), 35 deletions(-) diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index b14bad621..50a6ee58a 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -1,12 +1,16 @@ -use crate::frame::cassandra::CassandraOperation; +use crate::frame::cassandra::{CassandraMetadata, CassandraOperation}; use crate::frame::{CassandraFrame, Frame, MessageType}; -use crate::message::{Encodable, Message, Messages}; +use crate::message::{Encodable, Message, Messages, Metadata}; use crate::server::CodecReadError; use anyhow::{anyhow, Result}; use bytes::{Buf, BufMut, BytesMut}; use cassandra_protocol::compression::Compression; use cassandra_protocol::frame::message_error::{AdditionalErrorInfo, ErrorBody}; -use cassandra_protocol::frame::{CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Version}; +use cassandra_protocol::frame::{ + CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Opcode, Version, +}; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::Identifier; use tokio_util::codec::{Decoder, Encoder}; use tracing::info; @@ -14,6 +18,7 @@ use tracing::info; pub struct CassandraCodec { compressor: Compression, messages: Vec, + current_use_keyspace: Option, } impl Default for CassandraCodec { @@ -27,6 +32,7 @@ impl CassandraCodec { CassandraCodec { compressor: Compression::None, messages: vec![], + current_use_keyspace: None, } } } @@ -65,8 +71,23 @@ impl Decoder for CassandraCodec { return Err(reject_protocol_version(version.into())); } - self.messages - .push(Message::from_bytes(bytes.freeze(), MessageType::Cassandra)); + let mut message = Message::from_bytes(bytes.freeze(), MessageType::Cassandra); + + if let Ok(Metadata::Cassandra(CassandraMetadata { + opcode: Opcode::Query | Opcode::Batch, + .. + })) = message.metadata() + { + if let Some(keyspace) = get_use_keyspace(&mut message) { + self.current_use_keyspace = Some(keyspace); + } + + if let Some(keyspace) = &self.current_use_keyspace { + set_default_keyspace(&mut message, keyspace); + } + } + + self.messages.push(message); } Err(CheckEnvelopeSizeError::NotEnoughBytes) => { if self.messages.is_empty() || src.remaining() != 0 { @@ -89,6 +110,70 @@ impl Decoder for CassandraCodec { } } +fn get_use_keyspace(message: &mut Message) -> Option { + if let Some(Frame::Cassandra(frame)) = message.frame() { + if let CassandraOperation::Query { query, .. } = &mut frame.operation { + if let CassandraStatement::Use(keyspace) = query.as_ref() { + return Some(keyspace.clone()); + } + } + } + None +} + +fn set_default_keyspace(message: &mut Message, keyspace: &Identifier) { + // TODO: rewrite Operation::Prepared in the same way + if let Some(Frame::Cassandra(frame)) = message.frame() { + for query in frame.operation.queries() { + let name = match query { + CassandraStatement::AlterMaterializedView(x) => &mut x.name, + CassandraStatement::AlterTable(x) => &mut x.name, + CassandraStatement::AlterType(x) => &mut x.name, + CassandraStatement::CreateAggregate(x) => &mut x.name, + CassandraStatement::CreateFunction(x) => &mut x.name, + CassandraStatement::CreateIndex(x) => &mut x.table, + CassandraStatement::CreateMaterializedView(x) => &mut x.name, + CassandraStatement::CreateTable(x) => &mut x.name, + CassandraStatement::CreateTrigger(x) => &mut x.name, + CassandraStatement::CreateType(x) => &mut x.name, + CassandraStatement::Delete(x) => &mut x.table_name, + CassandraStatement::DropAggregate(x) => &mut x.name, + CassandraStatement::DropFunction(x) => &mut x.name, + CassandraStatement::DropIndex(x) => &mut x.name, + CassandraStatement::DropMaterializedView(x) => &mut x.name, + CassandraStatement::DropTable(x) => &mut x.name, + CassandraStatement::DropTrigger(x) => &mut x.name, + CassandraStatement::DropType(x) => &mut x.name, + CassandraStatement::Insert(x) => &mut x.table_name, + CassandraStatement::Select(x) => &mut x.table_name, + CassandraStatement::Truncate(name) => name, + CassandraStatement::Update(x) => &mut x.table_name, + CassandraStatement::AlterKeyspace(_) + | CassandraStatement::AlterRole(_) + | CassandraStatement::AlterUser(_) + | CassandraStatement::ApplyBatch + | CassandraStatement::CreateKeyspace(_) + | CassandraStatement::CreateRole(_) + | CassandraStatement::CreateUser(_) + | CassandraStatement::DropRole(_) + | CassandraStatement::DropUser(_) + | CassandraStatement::Grant(_) + | CassandraStatement::ListRoles(_) + | CassandraStatement::Revoke(_) + | CassandraStatement::DropKeyspace(_) + | CassandraStatement::ListPermissions(_) + | CassandraStatement::Use(_) + | CassandraStatement::Unknown(_) => { + return; + } + }; + if name.keyspace.is_none() { + name.keyspace = Some(keyspace.clone()); + } + } + } +} + /// If the client tried to use a protocol that we dont support then we need to reject it. /// The rejection process is sending back an error and then closing the connection. fn reject_protocol_version(version: u8) -> CodecReadError { diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 4db6b7a0c..f57f6b7e6 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -69,6 +69,7 @@ pub mod raw_frame { version: frame.version, stream_id: frame.stream_id, tracing_id: frame.tracing_id, + opcode: frame.opcode, }) } @@ -83,20 +84,13 @@ pub mod raw_frame { _ => nonzero!(1u32), }) } - - pub(crate) fn get_opcode(bytes: &[u8]) -> Result { - if bytes.len() < 9 { - bail!("Cassandra frame too short, needs at least 9 bytes for header"); - } - let opcode = Opcode::try_from(bytes[4])?; - Ok(opcode) - } } -pub(crate) struct CassandraMetadata { +pub struct CassandraMetadata { pub version: Version, pub stream_id: StreamId, pub tracing_id: Option, + pub opcode: Opcode, // missing `warnings` field because we are not using it currently } @@ -117,6 +111,7 @@ impl CassandraFrame { version: self.version, stream_id: self.stream_id, tracing_id: self.tracing_id, + opcode: self.operation.to_opcode(), } } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 4232a0d55..a905c9a05 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -32,7 +32,7 @@ use std::net::IpAddr; use std::num::NonZeroU32; use uuid::Uuid; -enum Metadata { +pub enum Metadata { Cassandra(CassandraMetadata), Redis, None, @@ -171,16 +171,6 @@ impl Message { } } - /// Only use for messages read straight from the socket - /// that are definitely in an unparsed state - /// (haven't passed through any transforms where they might have been parsed or modified) - pub(crate) fn as_raw_bytes(&self) -> Option<&Bytes> { - match self.inner.as_ref().unwrap() { - MessageInner::RawBytes { bytes, .. } => Some(bytes), - _ => None, - } - } - /// Batch messages have a cell count of 1 cell per inner message. /// Cell count is determined as follows: /// * Regular message - 1 cell @@ -270,7 +260,7 @@ impl Message { } /// Get metadata for this `Message` - fn metadata(&self) -> Result { + pub fn metadata(&self) -> Result { match self.inner.as_ref().unwrap() { MessageInner::RawBytes { bytes, diff --git a/shotover-proxy/src/transforms/cassandra/connection.rs b/shotover-proxy/src/transforms/cassandra/connection.rs index cf91784b6..6be60700d 100644 --- a/shotover-proxy/src/transforms/cassandra/connection.rs +++ b/shotover-proxy/src/transforms/cassandra/connection.rs @@ -1,6 +1,6 @@ use crate::codec::cassandra::CassandraCodec; -use crate::frame::cassandra; -use crate::message::Message; +use crate::frame::cassandra::CassandraMetadata; +use crate::message::{Message, Metadata}; use crate::tls::TlsConnector; use crate::transforms::util::Response; use crate::transforms::Messages; @@ -139,7 +139,7 @@ async fn rx_process_fallible( match response { Ok(response) => { for m in response { - if let Ok(Opcode::Event) = cassandra::raw_frame::get_opcode(m.as_raw_bytes().unwrap()) { + if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() { if let Some(ref pushed_messages_tx) = pushed_messages_tx { pushed_messages_tx.send(vec![m]).unwrap(); } @@ -223,10 +223,12 @@ pub async fn receive_message( response: Ok(message), .. } => { - if let Some(raw_bytes) = message.as_raw_bytes() { - if let Ok(Opcode::Error) = cassandra::raw_frame::get_opcode(raw_bytes) { - failed_requests.increment(1); - } + if let Ok(Metadata::Cassandra(CassandraMetadata { + opcode: Opcode::Error, + .. + })) = message.metadata() + { + failed_requests.increment(1); } Ok(message) } diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index 07bb8563e..da2eac8d1 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -721,7 +721,6 @@ impl CassandraSinkCluster { } } - // TODO: handle use statement state fn is_system_query(&self, request: &mut Message) -> bool { if let Some(Frame::Cassandra(frame)) = request.frame() { if let CassandraOperation::Query { query, .. } = &mut frame.operation { diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster_single_rack_v4.rs b/shotover-proxy/tests/cassandra_int_tests/cluster_single_rack_v4.rs index 7d91a2c78..e2dedd817 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster_single_rack_v4.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster_single_rack_v4.rs @@ -1,5 +1,5 @@ use crate::cassandra_int_tests::cluster::run_topology_task; -use crate::helpers::cassandra::{assert_query_result, CassandraConnection, ResultValue}; +use crate::helpers::cassandra::{assert_query_result, run_query, CassandraConnection, ResultValue}; use std::net::SocketAddr; async fn test_rewrite_system_peers(connection: &CassandraConnection) { @@ -22,6 +22,10 @@ async fn test_rewrite_system_peers(connection: &CassandraConnection) { async fn test_rewrite_system_peers_v2(connection: &CassandraConnection) { let all_columns = "peer, peer_port, data_center, host_id, native_address, native_port, preferred_ip, preferred_port, rack, release_version, schema_version, tokens"; assert_query_result(connection, "SELECT * FROM system.peers_v2;", &[]).await; + + run_query(connection, "USE system;").await; + assert_query_result(connection, "SELECT * FROM peers_v2;", &[]).await; + assert_query_result( connection, &format!("SELECT {all_columns} FROM system.peers_v2;"),