Skip to content

Commit

Permalink
split codecs into builders, encoders and decoders (#1029)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 20, 2023
1 parent 8ffd548 commit d048005
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 107 deletions.
11 changes: 6 additions & 5 deletions shotover-proxy/benches/benches/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use cassandra_protocol::frame::message_result::{
};
use cassandra_protocol::frame::Version;
use criterion::{black_box, criterion_group, BatchSize, Criterion};
use shotover_proxy::codec::cassandra::CassandraCodec;
use shotover_proxy::codec::cassandra::CassandraCodecBuilder;
use shotover_proxy::codec::CodecBuilder;
use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing};
use shotover_proxy::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use shotover_proxy::message::Message;
Expand All @@ -27,14 +28,14 @@ fn criterion_benchmark(c: &mut Criterion) {
},
}))];

let mut codec = CassandraCodec::new();
let (_, mut encoder) = CassandraCodecBuilder::new().build();

group.bench_function("encode_cassandra_system.local_query", |b| {
b.iter_batched(
|| messages.clone(),
|messages| {
let mut bytes = BytesMut::new();
codec.encode(messages, &mut bytes).unwrap();
encoder.encode(messages, &mut bytes).unwrap();
black_box(bytes)
},
BatchSize::SmallInput,
Expand All @@ -51,14 +52,14 @@ fn criterion_benchmark(c: &mut Criterion) {
operation: CassandraOperation::Result(peers_v2_result()),
}))];

let mut codec = CassandraCodec::new();
let (_, mut encoder) = CassandraCodecBuilder::new().build();

group.bench_function("encode_cassandra_system.local_result", |b| {
b.iter_batched(
|| messages.clone(),
|messages| {
let mut bytes = BytesMut::new();
codec.encode(messages, &mut bytes).unwrap();
encoder.encode(messages, &mut bytes).unwrap();
black_box(bytes)
},
BatchSize::SmallInput,
Expand Down
118 changes: 67 additions & 51 deletions shotover-proxy/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::{Codec, CodecReadError};
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::cassandra::{CassandraMetadata, CassandraOperation, Tracing};
use crate::frame::{CassandraFrame, Frame, MessageType};
use crate::message::{Encodable, Message, Messages, Metadata};
Expand All @@ -17,40 +17,44 @@ use std::sync::RwLock;
use tokio_util::codec::{Decoder, Encoder};
use tracing::info;

#[derive(Debug, Clone)]
pub struct CassandraCodec {
compression: Arc<RwLock<Compression>>,
messages: Vec<Message>,
current_use_keyspace: Option<Identifier>,
}
#[derive(Clone, Default)]
pub struct CassandraCodecBuilder {}

impl Default for CassandraCodec {
fn default() -> Self {
CassandraCodec::new()
impl CassandraCodecBuilder {
pub fn new() -> Self {
Self::default()
}
}

impl CassandraCodec {
pub fn new() -> CassandraCodec {
CassandraCodec {
compression: Arc::new(RwLock::new(Compression::None)),
messages: vec![],
current_use_keyspace: None,
}
impl CodecBuilder for CassandraCodecBuilder {
type Decoder = CassandraDecoder;
type Encoder = CassandraEncoder;
fn build(&self) -> (CassandraDecoder, CassandraEncoder) {
let compression = Arc::new(RwLock::new(Compression::None));
(
CassandraDecoder::new(compression.clone()),
CassandraEncoder::new(compression),
)
}
}

impl Codec for CassandraCodec {
fn clone_without_state(&self) -> Self {
Self {
compression: Arc::new(RwLock::new(Compression::None)),
messages: self.messages.clone(),
current_use_keyspace: self.current_use_keyspace.clone(),
pub struct CassandraDecoder {
compression: Arc<RwLock<Compression>>,
messages: Vec<Message>,
current_use_keyspace: Option<Identifier>,
}

impl CassandraDecoder {
pub fn new(compression: Arc<RwLock<Compression>>) -> CassandraDecoder {
CassandraDecoder {
compression,
messages: vec![],
current_use_keyspace: None,
}
}
}

impl CassandraCodec {
impl CassandraDecoder {
fn check_compression(&mut self, bytes: &BytesMut) -> Result<bool> {
if bytes.len() < 9 {
return Err(anyhow!("Not enough bytes for cassandra frame"));
Expand All @@ -66,28 +70,28 @@ impl CassandraCodec {
..
} = CassandraFrame::from_bytes(bytes.clone().freeze(), Compression::None)?
{
self.set_compression(&startup);
set_compression(&mut self.compression, &startup);
};
}

Ok(compressed)
}
}

fn set_compression(&mut self, startup: &BodyReqStartup) {
if let Some(compression) = startup.map.get("COMPRESSION") {
let mut write = self.compression.as_ref().write().unwrap();
fn set_compression(compression_state: &mut Arc<RwLock<Compression>>, startup: &BodyReqStartup) {
if let Some(compression) = startup.map.get("COMPRESSION") {
let mut write = compression_state.write().unwrap();

*write = match compression.as_str() {
"snappy" | "SNAPPY" => Compression::Snappy,
"lz4" | "LZ4" => Compression::Lz4,
"" | "none" | "NONE" => Compression::None,
_ => panic!(),
};
}
*write = match compression.as_str() {
"snappy" | "SNAPPY" => Compression::Snappy,
"lz4" | "LZ4" => Compression::Lz4,
"" | "none" | "NONE" => Compression::None,
_ => panic!(),
};
}
}

impl Decoder for CassandraCodec {
impl Decoder for CassandraDecoder {
type Item = Messages;
type Error = CodecReadError;

Expand Down Expand Up @@ -247,7 +251,17 @@ fn reject_protocol_version(version: u8) -> CodecReadError {
))])
}

impl Encoder<Messages> for CassandraCodec {
pub struct CassandraEncoder {
compression: Arc<RwLock<Compression>>,
}

impl CassandraEncoder {
pub fn new(compression: Arc<RwLock<Compression>>) -> CassandraEncoder {
CassandraEncoder { compression }
}
}

impl Encoder<Messages> for CassandraEncoder {
type Error = anyhow::Error;

fn encode(
Expand All @@ -271,7 +285,7 @@ impl Encoder<Messages> for CassandraCodec {
..
} = CassandraFrame::from_bytes(bytes.clone(), Compression::None)?
{
self.set_compression(&startup);
set_compression(&mut self.compression, &startup);
};
}
}
Expand All @@ -285,7 +299,7 @@ impl Encoder<Messages> for CassandraCodec {
..
}) = &frame
{
self.set_compression(startup);
set_compression(&mut self.compression, startup);
};

let buffer = frame.into_cassandra().unwrap().encode(compression);
Expand All @@ -304,7 +318,8 @@ impl Encoder<Messages> for CassandraCodec {

#[cfg(test)]
mod cassandra_protocol_tests {
use crate::codec::cassandra::CassandraCodec;
use crate::codec::cassandra::CassandraCodecBuilder;
use crate::codec::CodecBuilder;
use crate::frame::cassandra::{
parse_statement_single, CassandraFrame, CassandraOperation, CassandraResult, Tracing,
};
Expand All @@ -324,12 +339,13 @@ mod cassandra_protocol_tests {
use tokio_util::codec::{Decoder, Encoder};

fn test_frame_codec_roundtrip(
codec: &mut CassandraCodec,
codec: &mut CassandraCodecBuilder,
raw_frame: &[u8],
expected_messages: Vec<Message>,
) {
let (mut decoder, mut encoder) = codec.build();
// test decode
let decoded_messages = codec
let decoded_messages = decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
Expand All @@ -346,21 +362,21 @@ mod cassandra_protocol_tests {
// test encode round trip - parsed messages
{
let mut dest = BytesMut::new();
codec.encode(parsed_messages, &mut dest).unwrap();
encoder.encode(parsed_messages, &mut dest).unwrap();
assert_eq!(raw_frame, &dest.to_vec());
}

// test encode round trip - raw messages
{
let mut dest = BytesMut::new();
codec.encode(decoded_messages, &mut dest).unwrap();
encoder.encode(decoded_messages, &mut dest).unwrap();
assert_eq!(raw_frame, &dest.to_vec());
}
}

#[test]
fn test_codec_startup() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let mut startup_body: HashMap<String, String> = HashMap::new();
startup_body.insert("CQL_VERSION".into(), "3.0.0".into());
let bytes = hex!("0400000001000000160001000b43514c5f56455253494f4e0005332e302e30");
Expand All @@ -376,7 +392,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_options() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!("040000000500000000");
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
Expand All @@ -390,7 +406,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_ready() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!("840000000200000000");
let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
Expand All @@ -404,7 +420,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_register() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!(
"040000010b000000310003000f544f504f4c4f47595f4348414e4745
000d5354415455535f4348414e4745000d534348454d415f4348414e4745"
Expand All @@ -427,7 +443,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_result() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!(
"840000020800000099000000020000000100000009000673797374656
d000570656572730004706565720010000b646174615f63656e746572000d0007686f73745f6964000c000c70726566
Expand Down Expand Up @@ -535,7 +551,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_query_select() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!(
"0400000307000000350000002e53454c454354202a2046524f4d20737973
74656d2e6c6f63616c205748455245206b6579203d20276c6f63616c27000100"
Expand All @@ -558,7 +574,7 @@ mod cassandra_protocol_tests {

#[test]
fn test_codec_query_insert() {
let mut codec = CassandraCodec::new();
let mut codec = CassandraCodecBuilder::new();
let bytes = hex!(
"0400000307000000330000002c494e5345525420494e544f207379737465
6d2e666f6f2028626172292056414c554553202827626172322729000100"
Expand Down
12 changes: 9 additions & 3 deletions shotover-proxy/src/codec/kafka.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
use crate::codec::{Codec, CodecReadError};
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::MessageType;
use crate::message::{Encodable, Message, Messages, ProtocolType};
use anyhow::Result;
use bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

impl CodecBuilder for KafkaCodec {
type Decoder = KafkaCodec;
type Encoder = KafkaCodec;
fn build(&self) -> (KafkaCodec, KafkaCodec) {
(KafkaCodec::new(), KafkaCodec::new())
}
}

#[derive(Debug, Clone)]
pub struct KafkaCodec {
messages: Messages,
Expand All @@ -16,8 +24,6 @@ impl Default for KafkaCodec {
}
}

impl Codec for KafkaCodec {}

impl KafkaCodec {
pub fn new() -> KafkaCodec {
KafkaCodec { messages: vec![] }
Expand Down
17 changes: 8 additions & 9 deletions shotover-proxy/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,15 @@ impl From<std::io::Error> for CodecReadError {
}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait CodecReadHalf: Decoder<Item = Messages, Error = CodecReadError> + Clone + Send {}
impl<T: Decoder<Item = Messages, Error = CodecReadError> + Clone + Send> CodecReadHalf for T {}
pub trait DecoderHalf: Decoder<Item = Messages, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Messages, Error = CodecReadError> + Send> DecoderHalf for T {}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait CodecWriteHalf: Encoder<Messages, Error = anyhow::Error> + Clone + Send {}
impl<T: Encoder<Messages, Error = anyhow::Error> + Clone + Send> CodecWriteHalf for T {}
pub trait EncoderHalf: Encoder<Messages, Error = anyhow::Error> + Send {}
impl<T: Encoder<Messages, Error = anyhow::Error> + Send> EncoderHalf for T {}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait Codec: CodecReadHalf + CodecWriteHalf + Sized + Clone {
fn clone_without_state(&self) -> Self {
self.clone()
}
pub trait CodecBuilder: Clone + Send {
type Decoder: DecoderHalf;
type Encoder: EncoderHalf;
fn build(&self) -> (Self::Decoder, Self::Encoder);
}
12 changes: 9 additions & 3 deletions shotover-proxy/src/codec/redis.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::{Codec, CodecReadError};
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::RedisFrame;
use crate::frame::{Frame, MessageType};
use crate::message::{Encodable, Message, Messages, QueryType};
Expand All @@ -8,6 +8,14 @@ use redis_protocol::resp2::prelude::decode_mut;
use redis_protocol::resp2::prelude::encode_bytes;
use tokio_util::codec::{Decoder, Encoder};

impl CodecBuilder for RedisCodec {
type Decoder = RedisCodec;
type Encoder = RedisCodec;
fn build(&self) -> (RedisCodec, RedisCodec) {
(RedisCodec::new(), RedisCodec::new())
}
}

#[derive(Debug, Clone)]
pub struct RedisCodec {
messages: Messages,
Expand Down Expand Up @@ -36,8 +44,6 @@ impl Default for RedisCodec {
}
}

impl Codec for RedisCodec {}

impl RedisCodec {
pub fn new() -> RedisCodec {
RedisCodec { messages: vec![] }
Expand Down
Loading

0 comments on commit d048005

Please sign in to comment.