Skip to content

Commit

Permalink
flag fix
Browse files Browse the repository at this point in the history
  • Loading branch information
conorbros committed Oct 20, 2022
1 parent f800046 commit b8f695d
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 95 deletions.
13 changes: 4 additions & 9 deletions shotover-proxy/benches/benches/chain.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use bytes::Bytes;
use cassandra_protocol::{
compression::Compression,
consistency::Consistency,
frame::{Flags, Version},
query::QueryParams,
compression::Compression, consistency::Consistency, frame::Version, query::QueryParams,
};
use criterion::{criterion_group, BatchSize, Criterion};
use hex_literal::hex;
use shotover_proxy::frame::cassandra::parse_statement_single;
use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing};
use shotover_proxy::frame::RedisFrame;
use shotover_proxy::frame::{CassandraFrame, CassandraOperation, Frame, MessageType};
use shotover_proxy::message::{Message, QueryType};
Expand Down Expand Up @@ -233,9 +230,8 @@ fn criterion_benchmark(c: &mut Criterion) {
vec![Message::from_bytes(
CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
Expand Down Expand Up @@ -341,9 +337,8 @@ fn cassandra_parsed_query(query: &str) -> Wrapper {
Wrapper::new_with_chain_name(
vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(query)),
Expand Down
13 changes: 4 additions & 9 deletions shotover-proxy/benches/benches/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ use bytes::BytesMut;
use cassandra_protocol::frame::message_result::{
ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, RowsMetadataFlags, TableSpec,
};
use cassandra_protocol::{
frame::{Flags, Version},
query::QueryParams,
};
use cassandra_protocol::{frame::Version, query::QueryParams};
use criterion::{black_box, criterion_group, BatchSize, Criterion};
use shotover_proxy::codec::cassandra::CassandraCodec;
use shotover_proxy::frame::cassandra::parse_statement_single;
use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing};
use shotover_proxy::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use shotover_proxy::message::{IntSize, Message, MessageValue};
use tokio_util::codec::Encoder;
Expand All @@ -20,9 +17,8 @@ fn criterion_benchmark(c: &mut Criterion) {
{
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 1,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single("SELECT * FROM system.local;")),
Expand All @@ -48,9 +44,8 @@ fn criterion_benchmark(c: &mut Criterion) {
{
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Response(None),
warnings: vec![],
operation: CassandraOperation::Result(peers_v2_result()),
}))];
Expand Down
32 changes: 12 additions & 20 deletions shotover-proxy/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::frame::cassandra::{CassandraMetadata, CassandraOperation};
use crate::frame::cassandra::{CassandraMetadata, CassandraOperation, Tracing};
use crate::frame::{CassandraFrame, Frame, MessageType};
use crate::message::{Encodable, Message, Messages, Metadata};
use crate::server::CodecReadError;
Expand All @@ -7,7 +7,7 @@ use bytes::{Buf, BufMut, BytesMut};
use cassandra_protocol::compression::Compression;
use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType};
use cassandra_protocol::frame::{
CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Flags, Opcode, Version,
CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Opcode, Version,
};
use cql3_parser::cassandra_statement::CassandraStatement;
use cql3_parser::common::Identifier;
Expand Down Expand Up @@ -185,13 +185,12 @@ fn reject_protocol_version(version: u8) -> CodecReadError {
CodecReadError::RespondAndThenCloseConnection(vec![Message::from_frame(Frame::Cassandra(
CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 0,
operation: CassandraOperation::Error(ErrorBody {
message: "Invalid or unsupported protocol version".into(),
ty: ErrorType::Protocol,
}),
tracing_id: None,
tracing: Tracing::Response(None),
warnings: vec![],
},
))])
Expand Down Expand Up @@ -225,7 +224,7 @@ impl Encoder<Messages> for CassandraCodec {
mod cassandra_protocol_tests {
use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra::{
parse_statement_single, CassandraFrame, CassandraOperation, CassandraResult,
parse_statement_single, CassandraFrame, CassandraOperation, CassandraResult, Tracing,
};
use crate::frame::Frame;
use crate::message::Message;
Expand All @@ -236,7 +235,7 @@ mod cassandra_protocol_tests {
ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, RowsMetadataFlags,
TableSpec,
};
use cassandra_protocol::frame::{Flags, Version};
use cassandra_protocol::frame::Version;
use cassandra_protocol::query::QueryParams;
use hex_literal::hex;
use tokio_util::codec::{Decoder, Encoder};
Expand Down Expand Up @@ -282,12 +281,11 @@ mod cassandra_protocol_tests {
let bytes = hex!("0400000001000000160001000b43514c5f56455253494f4e0005332e302e30");
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
operation: CassandraOperation::Startup(vec![
0, 1, 0, 11, 67, 81, 76, 95, 86, 69, 82, 83, 73, 79, 78, 0, 5, 51, 46, 48, 46, 48,
]),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
}))];
test_frame_codec_roundtrip(&mut codec, &bytes, messages);
Expand All @@ -299,10 +297,9 @@ mod cassandra_protocol_tests {
let bytes = hex!("040000000500000000");
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
operation: CassandraOperation::Options(vec![]),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
}))];
test_frame_codec_roundtrip(&mut codec, &bytes, messages);
Expand All @@ -314,10 +311,9 @@ mod cassandra_protocol_tests {
let bytes = hex!("840000000200000000");
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
operation: CassandraOperation::Ready(vec![]),
stream_id: 0,
tracing_id: None,
tracing: Tracing::Response(None),
warnings: vec![],
}))];
test_frame_codec_roundtrip(&mut codec, &bytes, messages);
Expand All @@ -332,7 +328,6 @@ mod cassandra_protocol_tests {
);
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
operation: CassandraOperation::Register(BodyReqRegister {
events: vec![
SimpleServerEvent::TopologyChange,
Expand All @@ -341,7 +336,7 @@ mod cassandra_protocol_tests {
],
}),
stream_id: 1,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
}))];
test_frame_codec_roundtrip(&mut codec, &bytes, messages);
Expand All @@ -358,7 +353,6 @@ mod cassandra_protocol_tests {
);
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
operation: CassandraOperation::Result(CassandraResult::Rows {
rows: vec![],
metadata: Box::new(RowsMetadata {
Expand Down Expand Up @@ -450,7 +444,7 @@ mod cassandra_protocol_tests {
}),
}),
stream_id: 2,
tracing_id: None,
tracing: Tracing::Response(None),
warnings: vec![],
}))];
test_frame_codec_roundtrip(&mut codec, &bytes, messages);
Expand All @@ -466,9 +460,8 @@ mod cassandra_protocol_tests {

let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 3,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
Expand All @@ -490,9 +483,8 @@ mod cassandra_protocol_tests {

let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
flags: Flags::default(),
stream_id: 3,
tracing_id: None,
tracing: Tracing::Request(false),
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
Expand Down
74 changes: 58 additions & 16 deletions shotover-proxy/src/frame/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use uuid::Uuid;

/// Functions for operations on an unparsed Cassandra frame
pub mod raw_frame {
use super::{CassandraMetadata, RawCassandraFrame};
use super::{CassandraMetadata, RawCassandraFrame, Tracing};
use anyhow::{anyhow, bail, Result};
use cassandra_protocol::{compression::Compression, frame::Opcode};
use nonzero_ext::nonzero;
Expand All @@ -66,12 +66,11 @@ pub mod raw_frame {
let frame = RawCassandraFrame::from_buffer(bytes, Compression::None)
.map_err(|e| anyhow!("{e:?}"))?
.envelope;

let tracing = Tracing::from_frame(&frame)?;
Ok(CassandraMetadata {
version: frame.version,
flags: frame.flags,
stream_id: frame.stream_id,
tracing_id: frame.tracing_id,
tracing,
opcode: frame.opcode,
})
}
Expand All @@ -92,18 +91,61 @@ pub mod raw_frame {
pub struct CassandraMetadata {
pub version: Version,
pub stream_id: StreamId,
pub tracing_id: Option<Uuid>,
pub tracing: Tracing,
pub opcode: Opcode,
pub flags: Flags,
// missing `warnings` field because we are not using it currently
}

#[derive(PartialEq, Debug, Clone, Copy)]
pub enum Tracing {
Request(bool),
Response(Option<Uuid>),
}

impl Tracing {
fn enabled(&self) -> bool {
match self {
Self::Request(enabled) => *enabled,
Self::Response(uuid) => uuid.is_some(),
}
}
}

impl From<Tracing> for Option<Uuid> {
fn from(tracing: Tracing) -> Self {
match tracing {
Tracing::Request(_) => None,
Tracing::Response(uuid) => uuid,
}
}
}

impl Tracing {
fn from_frame(frame: &RawCassandraFrame) -> Result<Self> {
match frame.direction {
Direction::Request => Ok(Self::Request(frame.flags.contains(Flags::TRACING))),
Direction::Response => {
if frame.tracing_id.is_none() && frame.flags.contains(Flags::TRACING) {
return Err(anyhow!("Frame has no tracing_id but tracing was set"));
}

if frame.tracing_id.is_some() && !frame.flags.contains(Flags::TRACING) {
return Err(anyhow!(
"Frame has a tracing_id but tracing flag was not set"
));
}

Ok(Self::Response(frame.tracing_id))
}
}
}
}

#[derive(PartialEq, Debug, Clone)]
pub struct CassandraFrame {
pub version: Version,
pub flags: Flags,
pub stream_id: StreamId,
pub tracing_id: Option<Uuid>,
pub tracing: Tracing,
pub warnings: Vec<String>,
/// Contains the message body
pub operation: CassandraOperation,
Expand All @@ -114,9 +156,8 @@ impl CassandraFrame {
pub(crate) fn metadata(&self) -> CassandraMetadata {
CassandraMetadata {
version: self.version,
flags: self.flags,
stream_id: self.stream_id,
tracing_id: self.tracing_id,
tracing: self.tracing,
opcode: self.operation.to_opcode(),
}
}
Expand All @@ -136,6 +177,8 @@ impl CassandraFrame {
let frame = RawCassandraFrame::from_buffer(&bytes, Compression::None)
.map_err(|e| anyhow!("{e:?}"))?
.envelope;

let tracing = Tracing::from_frame(&frame)?;
let operation = match frame.opcode {
Opcode::Query => {
if let RequestBody::Query(body) = frame.request_body()? {
Expand Down Expand Up @@ -296,9 +339,8 @@ impl CassandraFrame {

Ok(CassandraFrame {
version: frame.version,
flags: frame.flags,
stream_id: frame.stream_id,
tracing_id: frame.tracing_id,
tracing,
warnings: frame.warnings,
operation,
})
Expand All @@ -313,9 +355,9 @@ impl CassandraFrame {
}

pub fn encode(self) -> RawCassandraFrame {
let mut flags = Flags::empty();
let mut flags = Flags::default();
flags.set(Flags::WARNING, !self.warnings.is_empty());
flags.set(Flags::TRACING, self.tracing_id.is_some());
flags.set(Flags::TRACING, self.tracing.enabled());

RawCassandraFrame {
direction: self.operation.to_direction(),
Expand All @@ -324,7 +366,7 @@ impl CassandraFrame {
opcode: self.operation.to_opcode(),
stream_id: self.stream_id,
body: self.operation.into_body(self.version),
tracing_id: self.tracing_id,
tracing_id: self.tracing.into(),
warnings: self.warnings,
}
}
Expand Down Expand Up @@ -695,7 +737,7 @@ pub struct CassandraBatch {
impl Display for CassandraFrame {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "{} stream:{}", self.version, self.stream_id)?;
if let Some(tracing_id) = self.tracing_id {
if let Tracing::Request(tracing_id) = self.tracing {
write!(f, " tracing_id:{}", tracing_id)?;
}
if !self.warnings.is_empty() {
Expand Down
Loading

0 comments on commit b8f695d

Please sign in to comment.