diff --git a/shotover-proxy/benches/benches/chain.rs b/shotover-proxy/benches/benches/chain.rs index ad916112e..56378dd5f 100644 --- a/shotover-proxy/benches/benches/chain.rs +++ b/shotover-proxy/benches/benches/chain.rs @@ -1,13 +1,10 @@ use bytes::Bytes; use cassandra_protocol::{ - compression::Compression, - consistency::Consistency, - frame::{Flags, Version}, - query::QueryParams, + compression::Compression, consistency::Consistency, frame::Version, query::QueryParams, }; use criterion::{criterion_group, BatchSize, Criterion}; use hex_literal::hex; -use shotover_proxy::frame::cassandra::parse_statement_single; +use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing}; use shotover_proxy::frame::RedisFrame; use shotover_proxy::frame::{CassandraFrame, CassandraOperation, Frame, MessageType}; use shotover_proxy::message::{Message, QueryType}; @@ -233,9 +230,8 @@ fn criterion_benchmark(c: &mut Criterion) { vec![Message::from_bytes( CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 0, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( @@ -341,9 +337,8 @@ fn cassandra_parsed_query(query: &str) -> Wrapper { Wrapper::new_with_chain_name( vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 0, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single(query)), diff --git a/shotover-proxy/benches/benches/codec.rs b/shotover-proxy/benches/benches/codec.rs index e5a412997..5ce1bef0a 100644 --- a/shotover-proxy/benches/benches/codec.rs +++ b/shotover-proxy/benches/benches/codec.rs @@ -2,13 +2,10 @@ use bytes::BytesMut; use cassandra_protocol::frame::message_result::{ ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, RowsMetadataFlags, TableSpec, }; -use cassandra_protocol::{ - frame::{Flags, Version}, - query::QueryParams, -}; +use cassandra_protocol::{frame::Version, query::QueryParams}; use criterion::{black_box, criterion_group, BatchSize, Criterion}; use shotover_proxy::codec::cassandra::CassandraCodec; -use shotover_proxy::frame::cassandra::parse_statement_single; +use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing}; use shotover_proxy::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use shotover_proxy::message::{IntSize, Message, MessageValue}; use tokio_util::codec::Encoder; @@ -20,9 +17,8 @@ fn criterion_benchmark(c: &mut Criterion) { { let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 1, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single("SELECT * FROM system.local;")), @@ -48,9 +44,8 @@ fn criterion_benchmark(c: &mut Criterion) { { let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 0, - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], operation: CassandraOperation::Result(peers_v2_result()), }))]; diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 75a9a6035..ed46163bb 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -1,4 +1,4 @@ -use crate::frame::cassandra::{CassandraMetadata, CassandraOperation}; +use crate::frame::cassandra::{CassandraMetadata, CassandraOperation, Tracing}; use crate::frame::{CassandraFrame, Frame, MessageType}; use crate::message::{Encodable, Message, Messages, Metadata}; use crate::server::CodecReadError; @@ -7,7 +7,7 @@ use bytes::{Buf, BufMut, BytesMut}; use cassandra_protocol::compression::Compression; use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType}; use cassandra_protocol::frame::{ - CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Flags, Opcode, Version, + CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Opcode, Version, }; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Identifier; @@ -185,13 +185,12 @@ fn reject_protocol_version(version: u8) -> CodecReadError { CodecReadError::RespondAndThenCloseConnection(vec![Message::from_frame(Frame::Cassandra( CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 0, operation: CassandraOperation::Error(ErrorBody { message: "Invalid or unsupported protocol version".into(), ty: ErrorType::Protocol, }), - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], }, ))]) @@ -225,7 +224,7 @@ impl Encoder for CassandraCodec { mod cassandra_protocol_tests { use crate::codec::cassandra::CassandraCodec; use crate::frame::cassandra::{ - parse_statement_single, CassandraFrame, CassandraOperation, CassandraResult, + parse_statement_single, CassandraFrame, CassandraOperation, CassandraResult, Tracing, }; use crate::frame::Frame; use crate::message::Message; @@ -236,7 +235,7 @@ mod cassandra_protocol_tests { ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, RowsMetadataFlags, TableSpec, }; - use cassandra_protocol::frame::{Flags, Version}; + use cassandra_protocol::frame::Version; use cassandra_protocol::query::QueryParams; use hex_literal::hex; use tokio_util::codec::{Decoder, Encoder}; @@ -282,12 +281,11 @@ mod cassandra_protocol_tests { let bytes = hex!("0400000001000000160001000b43514c5f56455253494f4e0005332e302e30"); let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), operation: CassandraOperation::Startup(vec![ 0, 1, 0, 11, 67, 81, 76, 95, 86, 69, 82, 83, 73, 79, 78, 0, 5, 51, 46, 48, 46, 48, ]), stream_id: 0, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], }))]; test_frame_codec_roundtrip(&mut codec, &bytes, messages); @@ -299,10 +297,9 @@ mod cassandra_protocol_tests { let bytes = hex!("040000000500000000"); let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), operation: CassandraOperation::Options(vec![]), stream_id: 0, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], }))]; test_frame_codec_roundtrip(&mut codec, &bytes, messages); @@ -314,10 +311,9 @@ mod cassandra_protocol_tests { let bytes = hex!("840000000200000000"); let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), operation: CassandraOperation::Ready(vec![]), stream_id: 0, - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], }))]; test_frame_codec_roundtrip(&mut codec, &bytes, messages); @@ -332,7 +328,6 @@ mod cassandra_protocol_tests { ); let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), operation: CassandraOperation::Register(BodyReqRegister { events: vec![ SimpleServerEvent::TopologyChange, @@ -341,7 +336,7 @@ mod cassandra_protocol_tests { ], }), stream_id: 1, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], }))]; test_frame_codec_roundtrip(&mut codec, &bytes, messages); @@ -358,7 +353,6 @@ mod cassandra_protocol_tests { ); let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), operation: CassandraOperation::Result(CassandraResult::Rows { rows: vec![], metadata: Box::new(RowsMetadata { @@ -450,7 +444,7 @@ mod cassandra_protocol_tests { }), }), stream_id: 2, - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], }))]; test_frame_codec_roundtrip(&mut codec, &bytes, messages); @@ -466,9 +460,8 @@ mod cassandra_protocol_tests { let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 3, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( @@ -490,9 +483,8 @@ mod cassandra_protocol_tests { let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 3, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index c36112cf1..2b34f4769 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -40,7 +40,7 @@ use uuid::Uuid; /// Functions for operations on an unparsed Cassandra frame pub mod raw_frame { - use super::{CassandraMetadata, RawCassandraFrame}; + use super::{CassandraMetadata, RawCassandraFrame, Tracing}; use anyhow::{anyhow, bail, Result}; use cassandra_protocol::{compression::Compression, frame::Opcode}; use nonzero_ext::nonzero; @@ -66,12 +66,11 @@ pub mod raw_frame { let frame = RawCassandraFrame::from_buffer(bytes, Compression::None) .map_err(|e| anyhow!("{e:?}"))? .envelope; - + let tracing = Tracing::from_frame(&frame)?; Ok(CassandraMetadata { version: frame.version, - flags: frame.flags, stream_id: frame.stream_id, - tracing_id: frame.tracing_id, + tracing, opcode: frame.opcode, }) } @@ -92,18 +91,61 @@ pub mod raw_frame { pub struct CassandraMetadata { pub version: Version, pub stream_id: StreamId, - pub tracing_id: Option, + pub tracing: Tracing, pub opcode: Opcode, - pub flags: Flags, // missing `warnings` field because we are not using it currently } +#[derive(PartialEq, Debug, Clone, Copy)] +pub enum Tracing { + Request(bool), + Response(Option), +} + +impl Tracing { + fn enabled(&self) -> bool { + match self { + Self::Request(enabled) => *enabled, + Self::Response(uuid) => uuid.is_some(), + } + } +} + +impl From for Option { + fn from(tracing: Tracing) -> Self { + match tracing { + Tracing::Request(_) => None, + Tracing::Response(uuid) => uuid, + } + } +} + +impl Tracing { + fn from_frame(frame: &RawCassandraFrame) -> Result { + match frame.direction { + Direction::Request => Ok(Self::Request(frame.flags.contains(Flags::TRACING))), + Direction::Response => { + if frame.tracing_id.is_none() && frame.flags.contains(Flags::TRACING) { + return Err(anyhow!("Frame has no tracing_id but tracing was set")); + } + + if frame.tracing_id.is_some() && !frame.flags.contains(Flags::TRACING) { + return Err(anyhow!( + "Frame has a tracing_id but tracing flag was not set" + )); + } + + Ok(Self::Response(frame.tracing_id)) + } + } + } +} + #[derive(PartialEq, Debug, Clone)] pub struct CassandraFrame { pub version: Version, - pub flags: Flags, pub stream_id: StreamId, - pub tracing_id: Option, + pub tracing: Tracing, pub warnings: Vec, /// Contains the message body pub operation: CassandraOperation, @@ -114,9 +156,8 @@ impl CassandraFrame { pub(crate) fn metadata(&self) -> CassandraMetadata { CassandraMetadata { version: self.version, - flags: self.flags, stream_id: self.stream_id, - tracing_id: self.tracing_id, + tracing: self.tracing, opcode: self.operation.to_opcode(), } } @@ -136,6 +177,8 @@ impl CassandraFrame { let frame = RawCassandraFrame::from_buffer(&bytes, Compression::None) .map_err(|e| anyhow!("{e:?}"))? .envelope; + + let tracing = Tracing::from_frame(&frame)?; let operation = match frame.opcode { Opcode::Query => { if let RequestBody::Query(body) = frame.request_body()? { @@ -296,9 +339,8 @@ impl CassandraFrame { Ok(CassandraFrame { version: frame.version, - flags: frame.flags, stream_id: frame.stream_id, - tracing_id: frame.tracing_id, + tracing, warnings: frame.warnings, operation, }) @@ -313,9 +355,9 @@ impl CassandraFrame { } pub fn encode(self) -> RawCassandraFrame { - let mut flags = Flags::empty(); + let mut flags = Flags::default(); flags.set(Flags::WARNING, !self.warnings.is_empty()); - flags.set(Flags::TRACING, self.tracing_id.is_some()); + flags.set(Flags::TRACING, self.tracing.enabled()); RawCassandraFrame { direction: self.operation.to_direction(), @@ -324,7 +366,7 @@ impl CassandraFrame { opcode: self.operation.to_opcode(), stream_id: self.stream_id, body: self.operation.into_body(self.version), - tracing_id: self.tracing_id, + tracing_id: self.tracing.into(), warnings: self.warnings, } } @@ -695,7 +737,7 @@ pub struct CassandraBatch { impl Display for CassandraFrame { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{} stream:{}", self.version, self.stream_id)?; - if let Some(tracing_id) = self.tracing_id { + if let Tracing::Request(tracing_id) = self.tracing { write!(f, " tracing_id:{}", tracing_id)?; } if !self.warnings.is_empty() { diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index cbf96f805..c7a29a4e6 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -215,13 +215,12 @@ impl Message { )), Frame::Cassandra(frame) => Frame::Cassandra(CassandraFrame { version: frame.version, - flags: frame.flags, stream_id: frame.stream_id, operation: CassandraOperation::Error(ErrorBody { message: "Message was filtered out by shotover".into(), ty: ErrorType::Server, }), - tracing_id: frame.tracing_id, + tracing: frame.tracing, warnings: vec![], }), Frame::None => Frame::None, @@ -245,13 +244,12 @@ impl Message { } Metadata::Cassandra(frame) => Frame::Cassandra(CassandraFrame { version: frame.version, - flags: frame.flags, stream_id: frame.stream_id, operation: CassandraOperation::Error(ErrorBody { message: error, ty: ErrorType::Server, }), - tracing_id: frame.tracing_id, + tracing: frame.tracing, warnings: vec![], }), Metadata::None => Frame::None, @@ -293,9 +291,8 @@ impl Message { Frame::Cassandra(CassandraFrame { version: metadata.version, - flags: metadata.flags, stream_id: metadata.stream_id, - tracing_id: metadata.tracing_id, + tracing: metadata.tracing, warnings: vec![], operation: body, }) diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 9cc0a2761..efc89fc60 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -135,22 +135,21 @@ fn rewrite_port(message: &mut Message, column_names: &[Identifier], new_port: u1 #[cfg(test)] mod test { use super::*; - use crate::frame::cassandra::parse_statement_single; + use crate::frame::cassandra::{parse_statement_single, Tracing}; use crate::frame::CassandraFrame; use crate::transforms::cassandra::peers_rewrite::CassandraResult::Rows; use cassandra_protocol::consistency::Consistency; use cassandra_protocol::frame::message_result::{ ColSpec, ColType, ColTypeOption, RowsMetadata, RowsMetadataFlags, TableSpec, }; - use cassandra_protocol::frame::{Flags, Version}; + use cassandra_protocol::frame::Version; use cassandra_protocol::query::QueryParams; fn create_query_message(query: &str) -> Message { Message::from_frame(Frame::Cassandra(CassandraFrame { - flags: Flags::default(), version: Version::V4, stream_id: 0, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single(query)), @@ -172,9 +171,8 @@ mod test { fn create_response_message(col_specs: &[ColSpec], rows: Vec>) -> Message { Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 0, - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], operation: CassandraOperation::Result(Rows { rows, diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index 0479e22e5..81a22d492 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::cassandra::{parse_statement_single, CassandraMetadata}; +use crate::frame::cassandra::{parse_statement_single, CassandraMetadata, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::{IntSize, Message, MessageValue, Messages}; use crate::tls::{TlsConnector, TlsConnectorConfig}; @@ -12,7 +12,7 @@ use cassandra_protocol::events::ServerEvent; use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType, UnpreparedError}; use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned; use cassandra_protocol::frame::message_result::PreparedMetadata; -use cassandra_protocol::frame::{Flags, Opcode, Version}; +use cassandra_protocol::frame::{Opcode, Version}; use cassandra_protocol::query::QueryParams; use cassandra_protocol::types::CBytesShort; use cql3_parser::cassandra_statement::CassandraStatement; @@ -196,9 +196,8 @@ fn create_query(messages: &Messages, query: &str, version: Version) -> Result Option<(&BodyReqExecuteOwned, C if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Execute(execute_body), version, - flags, stream_id, - tracing_id, + tracing, .. })) = message.frame() { @@ -896,9 +893,8 @@ fn get_execute_message(message: &mut Message) -> Option<(&BodyReqExecuteOwned, C execute_body, CassandraMetadata { version: *version, - flags: *flags, stream_id: *stream_id, - tracing_id: *tracing_id, + tracing: *tracing, opcode: Opcode::Execute, }, )); diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs index ca1ed4944..170d54a7c 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs @@ -1,5 +1,5 @@ use super::node::{CassandraNode, ConnectionFactory}; -use crate::frame::cassandra::parse_statement_single; +use crate::frame::cassandra::{parse_statement_single, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::{Message, MessageValue}; use crate::transforms::cassandra::connection::CassandraConnection; @@ -8,10 +8,7 @@ use cassandra_protocol::events::{ServerEvent, SimpleServerEvent}; use cassandra_protocol::frame::events::{StatusChangeType, TopologyChangeType}; use cassandra_protocol::frame::message_register::BodyReqRegister; use cassandra_protocol::token::Murmur3Token; -use cassandra_protocol::{ - frame::{Flags, Version}, - query::QueryParams, -}; +use cassandra_protocol::{frame::Version, query::QueryParams}; use std::net::SocketAddr; use tokio::sync::mpsc::unbounded_channel; use tokio::sync::{mpsc, oneshot, watch}; @@ -147,8 +144,7 @@ async fn register_for_topology_and_status_events( Message::from_frame(Frame::Cassandra(CassandraFrame { version, stream_id: 0, - flags: Flags::default(), - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Register(BodyReqRegister { events: vec![ @@ -199,9 +195,8 @@ mod system_local { connection.send( Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 1, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( @@ -287,8 +282,7 @@ mod system_peers { Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, stream_id: 0, - flags: Flags::default(), - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( @@ -308,8 +302,7 @@ mod system_peers { Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, stream_id: 0, - flags: Flags::default(), - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Query { query: Box::new(parse_statement_single( diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster.rs b/shotover-proxy/tests/cassandra_int_tests/cluster.rs index 65d478779..719156a6f 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster.rs @@ -1,5 +1,5 @@ -use cassandra_protocol::frame::{Flags, Version}; -use shotover_proxy::frame::{CassandraFrame, CassandraOperation, Frame}; +use cassandra_protocol::frame::Version; +use shotover_proxy::frame::{cassandra::Tracing, CassandraFrame, CassandraOperation, Frame}; use shotover_proxy::message::Message; use shotover_proxy::tls::{TlsConnector, TlsConnectorConfig}; use shotover_proxy::transforms::cassandra::sink_cluster::{ @@ -47,17 +47,15 @@ fn create_handshake() -> Vec { vec![ Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 64, - tracing_id: None, + tracing: Tracing::Request(false), warnings: vec![], operation: CassandraOperation::Startup(b"\0\x01\0\x0bCQL_VERSION\0\x053.0.0".to_vec()), })), Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, - flags: Flags::default(), stream_id: 128, - tracing_id: None, + tracing: Tracing::Response(None), warnings: vec![], operation: CassandraOperation::AuthResponse( b"\0\0\0\x14\0cassandra\0cassandra".to_vec(),