From 188eb55f3320e6cdbfca2470fe0024e957f892c3 Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Thu, 16 Feb 2023 16:18:15 +1100 Subject: [PATCH] Move MessageValue out of message/mod.rs (#1044) --- shotover-proxy/benches/benches/codec.rs | 3 +- shotover-proxy/src/frame/cassandra.rs | 3 +- shotover-proxy/src/lib.rs | 1 + shotover-proxy/src/message/mod.rs | 386 +----------------- shotover-proxy/src/message_value.rs | 384 +++++++++++++++++ .../src/transforms/cassandra/peers_rewrite.rs | 3 +- .../transforms/cassandra/sink_cluster/mod.rs | 3 +- .../cassandra/sink_cluster/topology.rs | 3 +- .../src/transforms/protect/crypto.rs | 2 +- shotover-proxy/src/transforms/protect/mod.rs | 2 +- 10 files changed, 400 insertions(+), 390 deletions(-) create mode 100644 shotover-proxy/src/message_value.rs diff --git a/shotover-proxy/benches/benches/codec.rs b/shotover-proxy/benches/benches/codec.rs index b699c80b2..738260f14 100644 --- a/shotover-proxy/benches/benches/codec.rs +++ b/shotover-proxy/benches/benches/codec.rs @@ -7,7 +7,8 @@ use criterion::{black_box, criterion_group, BatchSize, Criterion}; use shotover_proxy::codec::cassandra::CassandraCodec; use shotover_proxy::frame::cassandra::{parse_statement_single, Tracing}; use shotover_proxy::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use shotover_proxy::message::{IntSize, Message, MessageValue}; +use shotover_proxy::message::Message; +use shotover_proxy::message_value::{IntSize, MessageValue}; use tokio_util::codec::Encoder; fn criterion_benchmark(c: &mut Criterion) { diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 29d8379ad..7deb4305c 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,4 +1,5 @@ -use crate::message::{serialize_len, serialize_with_length_prefix, MessageValue, QueryType}; +use crate::message::QueryType; +use crate::message_value::{serialize_len, serialize_with_length_prefix, MessageValue}; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; diff --git a/shotover-proxy/src/lib.rs b/shotover-proxy/src/lib.rs index c802133e8..af30866ce 100644 --- a/shotover-proxy/src/lib.rs +++ b/shotover-proxy/src/lib.rs @@ -27,6 +27,7 @@ pub mod config; pub mod error; pub mod frame; pub mod message; +pub mod message_value; mod observability; pub mod runner; mod server; diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 4f78f4bb7..4c88e8d40 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -2,36 +2,16 @@ use crate::codec::redis::redis_query_type; use crate::frame::cassandra::Tracing; use crate::frame::{ cassandra, - cassandra::{to_cassandra_type, CassandraMetadata, CassandraOperation}, + cassandra::{CassandraMetadata, CassandraOperation}, }; use crate::frame::{CassandraFrame, Frame, MessageType, RedisFrame}; use anyhow::{anyhow, Result}; -use bigdecimal::BigDecimal; use bytes::{Buf, Bytes}; use bytes_utils::Str; -use cassandra_protocol::frame::Serialize as FrameSerialize; -use cassandra_protocol::types::CInt; -use cassandra_protocol::{ - frame::{ - message_error::{ErrorBody, ErrorType}, - message_result::{ColSpec, ColTypeOption}, - Version, - }, - types::{ - cassandra_type::{wrapper_fn, CassandraType}, - CBytes, - }, -}; -use cql3_parser::common::Operand; +use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType}; use nonzero_ext::nonzero; -use num::BigInt; -use ordered_float::OrderedFloat; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; -use std::io::{Cursor, Write}; -use std::net::IpAddr; +use serde::Deserialize; use std::num::NonZeroU32; -use uuid::Uuid; pub enum Metadata { Cassandra(CassandraMetadata), @@ -408,363 +388,3 @@ pub enum QueryType { SchemaChange, PubSubMessage, } - -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] -pub enum MessageValue { - Null, - #[serde(with = "my_bytes")] - Bytes(Bytes), - Ascii(String), - Strings(String), - Integer(i64, IntSize), - Double(OrderedFloat), - Float(OrderedFloat), - Boolean(bool), - Inet(IpAddr), - List(Vec), - Set(BTreeSet), - Map(BTreeMap), - Varint(BigInt), - Decimal(BigDecimal), - Date(i32), - Timestamp(i64), - Duration(Duration), - Timeuuid(Uuid), - Varchar(String), - Uuid(Uuid), - Time(i64), - Counter(i64), - Tuple(Vec), - Udt(BTreeMap), -} - -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] -pub enum IntSize { - I64, // BigInt - I32, // Int - I16, // Smallint - I8, // Tinyint -} - -// TODO: This is tailored directly to cassandras Duration and will need to be adjusted once we add another protocol that uses it -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] -pub struct Duration { - pub months: i32, - pub days: i32, - pub nanoseconds: i64, -} - -impl From<&Operand> for MessageValue { - fn from(operand: &Operand) -> Self { - MessageValue::create_element(to_cassandra_type(operand)) - } -} - -impl From for MessageValue { - fn from(f: RedisFrame) -> Self { - match f { - RedisFrame::SimpleString(s) => { - MessageValue::Strings(String::from_utf8_lossy(&s).to_string()) - } - RedisFrame::Error(e) => MessageValue::Strings(e.to_string()), - RedisFrame::Integer(i) => MessageValue::Integer(i, IntSize::I64), - RedisFrame::BulkString(b) => MessageValue::Bytes(b), - RedisFrame::Array(a) => { - MessageValue::List(a.iter().cloned().map(MessageValue::from).collect()) - } - RedisFrame::Null => MessageValue::Null, - } - } -} - -impl From<&RedisFrame> for MessageValue { - fn from(f: &RedisFrame) -> Self { - match f.clone() { - RedisFrame::SimpleString(s) => { - MessageValue::Strings(String::from_utf8_lossy(s.as_ref()).to_string()) - } - RedisFrame::Error(e) => MessageValue::Strings(e.to_string()), - RedisFrame::Integer(i) => MessageValue::Integer(i, IntSize::I64), - RedisFrame::BulkString(b) => MessageValue::Bytes(b), - RedisFrame::Array(a) => { - MessageValue::List(a.iter().cloned().map(MessageValue::from).collect()) - } - RedisFrame::Null => MessageValue::Null, - } - } -} - -impl From for RedisFrame { - fn from(value: MessageValue) -> RedisFrame { - match value { - MessageValue::Null => RedisFrame::Null, - MessageValue::Bytes(b) => RedisFrame::BulkString(b), - MessageValue::Strings(s) => RedisFrame::SimpleString(s.into()), - MessageValue::Integer(i, _) => RedisFrame::Integer(i), - MessageValue::Float(f) => RedisFrame::SimpleString(f.to_string().into()), - MessageValue::Boolean(b) => RedisFrame::Integer(i64::from(b)), - MessageValue::Inet(i) => RedisFrame::SimpleString(i.to_string().into()), - MessageValue::List(l) => RedisFrame::Array(l.into_iter().map(|v| v.into()).collect()), - MessageValue::Ascii(_a) => todo!(), - MessageValue::Double(_d) => todo!(), - MessageValue::Set(_s) => todo!(), - MessageValue::Map(_) => todo!(), - MessageValue::Varint(_v) => todo!(), - MessageValue::Decimal(_d) => todo!(), - MessageValue::Date(_date) => todo!(), - MessageValue::Timestamp(_timestamp) => todo!(), - MessageValue::Timeuuid(_timeuuid) => todo!(), - MessageValue::Varchar(_v) => todo!(), - MessageValue::Uuid(_uuid) => todo!(), - MessageValue::Time(_t) => todo!(), - MessageValue::Counter(_c) => todo!(), - MessageValue::Tuple(_) => todo!(), - MessageValue::Udt(_) => todo!(), - MessageValue::Duration(_) => todo!(), - } - } -} - -impl MessageValue { - pub fn value_byte_string(string: String) -> MessageValue { - MessageValue::Bytes(Bytes::from(string)) - } - - pub fn value_byte_str(str: &'static str) -> MessageValue { - MessageValue::Bytes(Bytes::from(str)) - } - - pub fn build_value_from_cstar_col_type( - version: Version, - spec: &ColSpec, - data: &CBytes, - ) -> MessageValue { - let cassandra_type = MessageValue::into_cassandra_type(version, &spec.col_type, data); - MessageValue::create_element(cassandra_type) - } - - fn into_cassandra_type( - version: Version, - col_type: &ColTypeOption, - data: &CBytes, - ) -> CassandraType { - let wrapper = wrapper_fn(&col_type.id); - wrapper(data, col_type, version).unwrap() - } - - fn create_element(element: CassandraType) -> MessageValue { - match element { - CassandraType::Ascii(a) => MessageValue::Ascii(a), - CassandraType::Bigint(b) => MessageValue::Integer(b, IntSize::I64), - CassandraType::Blob(b) => MessageValue::Bytes(b.into_vec().into()), - CassandraType::Boolean(b) => MessageValue::Boolean(b), - CassandraType::Counter(c) => MessageValue::Counter(c), - CassandraType::Decimal(d) => { - let big_decimal = BigDecimal::new(d.unscaled, d.scale.into()); - MessageValue::Decimal(big_decimal) - } - CassandraType::Double(d) => MessageValue::Double(d.into()), - CassandraType::Float(f) => MessageValue::Float(f.into()), - CassandraType::Int(c) => MessageValue::Integer(c.into(), IntSize::I32), - CassandraType::Timestamp(t) => MessageValue::Timestamp(t), - CassandraType::Uuid(u) => MessageValue::Uuid(u), - CassandraType::Varchar(v) => MessageValue::Varchar(v), - CassandraType::Varint(v) => MessageValue::Varint(v), - CassandraType::Timeuuid(t) => MessageValue::Timeuuid(t), - CassandraType::Inet(i) => MessageValue::Inet(i), - CassandraType::Date(d) => MessageValue::Date(d), - CassandraType::Time(d) => MessageValue::Time(d), - CassandraType::Duration(d) => MessageValue::Duration(Duration { - months: d.months(), - days: d.days(), - nanoseconds: d.nanoseconds(), - }), - CassandraType::Smallint(d) => MessageValue::Integer(d.into(), IntSize::I16), - CassandraType::Tinyint(d) => MessageValue::Integer(d.into(), IntSize::I8), - CassandraType::List(list) => { - let value_list = list.into_iter().map(MessageValue::create_element).collect(); - MessageValue::List(value_list) - } - CassandraType::Map(map) => MessageValue::Map( - map.into_iter() - .map(|(key, value)| { - ( - MessageValue::create_element(key), - MessageValue::create_element(value), - ) - }) - .collect(), - ), - CassandraType::Set(set) => { - MessageValue::Set(set.into_iter().map(MessageValue::create_element).collect()) - } - CassandraType::Udt(udt) => { - let values = udt - .into_iter() - .map(|(key, element)| (key, MessageValue::create_element(element))) - .collect(); - MessageValue::Udt(values) - } - CassandraType::Tuple(tuple) => { - let value_list = tuple - .into_iter() - .map(MessageValue::create_element) - .collect(); - MessageValue::Tuple(value_list) - } - CassandraType::Null => MessageValue::Null, - _ => unreachable!(), - } - } - - pub fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec>) { - match self { - MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(), - MessageValue::Bytes(b) => serialize_bytes(cursor, b), - MessageValue::Strings(s) => serialize_bytes(cursor, s.as_bytes()), - MessageValue::Integer(x, size) => match size { - IntSize::I64 => serialize_bytes(cursor, &(*x).to_be_bytes()), - IntSize::I32 => serialize_bytes(cursor, &(*x as i32).to_be_bytes()), - IntSize::I16 => serialize_bytes(cursor, &(*x as i16).to_be_bytes()), - IntSize::I8 => serialize_bytes(cursor, &(*x as i8).to_be_bytes()), - }, - MessageValue::Float(f) => serialize_bytes(cursor, &f.into_inner().to_be_bytes()), - MessageValue::Boolean(b) => serialize_bytes(cursor, &[*b as u8]), - MessageValue::List(l) => serialize_list(cursor, l), - MessageValue::Inet(i) => match i { - IpAddr::V4(ip) => serialize_bytes(cursor, &ip.octets()), - IpAddr::V6(ip) => serialize_bytes(cursor, &ip.octets()), - }, - MessageValue::Ascii(a) => serialize_bytes(cursor, a.as_bytes()), - MessageValue::Double(d) => serialize_bytes(cursor, &d.into_inner().to_be_bytes()), - MessageValue::Set(s) => serialize_set(cursor, s), - MessageValue::Map(m) => serialize_map(cursor, m), - MessageValue::Varint(v) => serialize_bytes(cursor, &v.to_signed_bytes_be()), - MessageValue::Decimal(d) => { - let (unscaled, scale) = d.as_bigint_and_exponent(); - serialize_bytes( - cursor, - &cassandra_protocol::types::decimal::Decimal { - unscaled, - scale: scale as i32, - } - .serialize_to_vec(Version::V4), - ); - } - MessageValue::Date(d) => serialize_bytes(cursor, &d.to_be_bytes()), - MessageValue::Timestamp(t) => serialize_bytes(cursor, &t.to_be_bytes()), - MessageValue::Duration(d) => { - // TODO: Either this function should be made fallible or Duration should have validated setters - serialize_bytes( - cursor, - &cassandra_protocol::types::duration::Duration::new( - d.months, - d.days, - d.nanoseconds, - ) - .unwrap() - .serialize_to_vec(Version::V4), - ); - } - MessageValue::Timeuuid(t) => serialize_bytes(cursor, t.as_bytes()), - MessageValue::Varchar(v) => serialize_bytes(cursor, v.as_bytes()), - MessageValue::Uuid(u) => serialize_bytes(cursor, u.as_bytes()), - MessageValue::Time(t) => serialize_bytes(cursor, &t.to_be_bytes()), - MessageValue::Counter(c) => serialize_bytes(cursor, &c.to_be_bytes()), - MessageValue::Tuple(t) => serialize_list(cursor, t), - MessageValue::Udt(u) => serialize_stringmap(cursor, u), - } - } -} - -pub(crate) fn serialize_with_length_prefix( - cursor: &mut Cursor<&mut Vec>, - serializer: impl FnOnce(&mut Cursor<&mut Vec>), -) { - // write dummy length - let length_start = cursor.position(); - let bytes_start = length_start + 4; - serialize_len(cursor, 0); - - // perform serialization - serializer(cursor); - - // overwrite dummy length with actual length of serialized bytes - let bytes_len = cursor.position() - bytes_start; - cursor.get_mut()[length_start as usize..bytes_start as usize] - .copy_from_slice(&(bytes_len as CInt).to_be_bytes()); -} - -pub fn serialize_len(cursor: &mut Cursor<&mut Vec>, len: usize) { - let len = len as CInt; - cursor.write_all(&len.to_be_bytes()).unwrap(); -} - -fn serialize_bytes(cursor: &mut Cursor<&mut Vec>, bytes: &[u8]) { - serialize_len(cursor, bytes.len()); - cursor.write_all(bytes).unwrap(); -} - -fn serialize_list(cursor: &mut Cursor<&mut Vec>, values: &[MessageValue]) { - serialize_with_length_prefix(cursor, |cursor| { - serialize_len(cursor, values.len()); - - for value in values { - value.cassandra_serialize(cursor); - } - }); -} - -#[allow(clippy::mutable_key_type)] -fn serialize_set(cursor: &mut Cursor<&mut Vec>, values: &BTreeSet) { - serialize_with_length_prefix(cursor, |cursor| { - serialize_len(cursor, values.len()); - - for value in values { - value.cassandra_serialize(cursor); - } - }); -} - -fn serialize_stringmap(cursor: &mut Cursor<&mut Vec>, values: &BTreeMap) { - serialize_with_length_prefix(cursor, |cursor| { - serialize_len(cursor, values.len()); - - for (key, value) in values.iter() { - serialize_bytes(cursor, key.as_bytes()); - value.cassandra_serialize(cursor); - } - }); -} - -#[allow(clippy::mutable_key_type)] -fn serialize_map(cursor: &mut Cursor<&mut Vec>, values: &BTreeMap) { - serialize_with_length_prefix(cursor, |cursor| { - serialize_len(cursor, values.len()); - - for (key, value) in values.iter() { - key.cassandra_serialize(cursor); - value.cassandra_serialize(cursor); - } - }); -} - -mod my_bytes { - use bytes::Bytes; - use serde::{Deserialize, Deserializer, Serializer}; - - pub fn serialize(val: &Bytes, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_bytes(val) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let val: Vec = Deserialize::deserialize(deserializer)?; - Ok(Bytes::from(val)) - } -} diff --git a/shotover-proxy/src/message_value.rs b/shotover-proxy/src/message_value.rs new file mode 100644 index 000000000..458378756 --- /dev/null +++ b/shotover-proxy/src/message_value.rs @@ -0,0 +1,384 @@ +use crate::frame::cassandra::to_cassandra_type; +use crate::frame::RedisFrame; +use bigdecimal::BigDecimal; +use bytes::Bytes; +use cassandra_protocol::frame::Serialize as FrameSerialize; +use cassandra_protocol::types::CInt; +use cassandra_protocol::{ + frame::{ + message_result::{ColSpec, ColTypeOption}, + Version, + }, + types::{ + cassandra_type::{wrapper_fn, CassandraType}, + CBytes, + }, +}; +use cql3_parser::common::Operand; +use num::BigInt; +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, BTreeSet}; +use std::io::{Cursor, Write}; +use std::net::IpAddr; +use uuid::Uuid; + +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] +pub enum MessageValue { + Null, + #[serde(with = "my_bytes")] + Bytes(Bytes), + Ascii(String), + Strings(String), + Integer(i64, IntSize), + Double(OrderedFloat), + Float(OrderedFloat), + Boolean(bool), + Inet(IpAddr), + List(Vec), + Set(BTreeSet), + Map(BTreeMap), + Varint(BigInt), + Decimal(BigDecimal), + Date(i32), + Timestamp(i64), + Duration(Duration), + Timeuuid(Uuid), + Varchar(String), + Uuid(Uuid), + Time(i64), + Counter(i64), + Tuple(Vec), + Udt(BTreeMap), +} + +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] +pub enum IntSize { + I64, // BigInt + I32, // Int + I16, // Smallint + I8, // Tinyint +} + +// TODO: This is tailored directly to cassandras Duration and will need to be adjusted once we add another protocol that uses it +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)] +pub struct Duration { + pub months: i32, + pub days: i32, + pub nanoseconds: i64, +} + +impl From<&Operand> for MessageValue { + fn from(operand: &Operand) -> Self { + MessageValue::create_element(to_cassandra_type(operand)) + } +} + +impl From for MessageValue { + fn from(f: RedisFrame) -> Self { + match f { + RedisFrame::SimpleString(s) => { + MessageValue::Strings(String::from_utf8_lossy(&s).to_string()) + } + RedisFrame::Error(e) => MessageValue::Strings(e.to_string()), + RedisFrame::Integer(i) => MessageValue::Integer(i, IntSize::I64), + RedisFrame::BulkString(b) => MessageValue::Bytes(b), + RedisFrame::Array(a) => { + MessageValue::List(a.iter().cloned().map(MessageValue::from).collect()) + } + RedisFrame::Null => MessageValue::Null, + } + } +} + +impl From<&RedisFrame> for MessageValue { + fn from(f: &RedisFrame) -> Self { + match f.clone() { + RedisFrame::SimpleString(s) => { + MessageValue::Strings(String::from_utf8_lossy(s.as_ref()).to_string()) + } + RedisFrame::Error(e) => MessageValue::Strings(e.to_string()), + RedisFrame::Integer(i) => MessageValue::Integer(i, IntSize::I64), + RedisFrame::BulkString(b) => MessageValue::Bytes(b), + RedisFrame::Array(a) => { + MessageValue::List(a.iter().cloned().map(MessageValue::from).collect()) + } + RedisFrame::Null => MessageValue::Null, + } + } +} + +impl From for RedisFrame { + fn from(value: MessageValue) -> RedisFrame { + match value { + MessageValue::Null => RedisFrame::Null, + MessageValue::Bytes(b) => RedisFrame::BulkString(b), + MessageValue::Strings(s) => RedisFrame::SimpleString(s.into()), + MessageValue::Integer(i, _) => RedisFrame::Integer(i), + MessageValue::Float(f) => RedisFrame::SimpleString(f.to_string().into()), + MessageValue::Boolean(b) => RedisFrame::Integer(i64::from(b)), + MessageValue::Inet(i) => RedisFrame::SimpleString(i.to_string().into()), + MessageValue::List(l) => RedisFrame::Array(l.into_iter().map(|v| v.into()).collect()), + MessageValue::Ascii(_a) => todo!(), + MessageValue::Double(_d) => todo!(), + MessageValue::Set(_s) => todo!(), + MessageValue::Map(_) => todo!(), + MessageValue::Varint(_v) => todo!(), + MessageValue::Decimal(_d) => todo!(), + MessageValue::Date(_date) => todo!(), + MessageValue::Timestamp(_timestamp) => todo!(), + MessageValue::Timeuuid(_timeuuid) => todo!(), + MessageValue::Varchar(_v) => todo!(), + MessageValue::Uuid(_uuid) => todo!(), + MessageValue::Time(_t) => todo!(), + MessageValue::Counter(_c) => todo!(), + MessageValue::Tuple(_) => todo!(), + MessageValue::Udt(_) => todo!(), + MessageValue::Duration(_) => todo!(), + } + } +} + +impl MessageValue { + pub fn value_byte_string(string: String) -> MessageValue { + MessageValue::Bytes(Bytes::from(string)) + } + + pub fn value_byte_str(str: &'static str) -> MessageValue { + MessageValue::Bytes(Bytes::from(str)) + } + + pub fn build_value_from_cstar_col_type( + version: Version, + spec: &ColSpec, + data: &CBytes, + ) -> MessageValue { + let cassandra_type = MessageValue::into_cassandra_type(version, &spec.col_type, data); + MessageValue::create_element(cassandra_type) + } + + fn into_cassandra_type( + version: Version, + col_type: &ColTypeOption, + data: &CBytes, + ) -> CassandraType { + let wrapper = wrapper_fn(&col_type.id); + wrapper(data, col_type, version).unwrap() + } + + fn create_element(element: CassandraType) -> MessageValue { + match element { + CassandraType::Ascii(a) => MessageValue::Ascii(a), + CassandraType::Bigint(b) => MessageValue::Integer(b, IntSize::I64), + CassandraType::Blob(b) => MessageValue::Bytes(b.into_vec().into()), + CassandraType::Boolean(b) => MessageValue::Boolean(b), + CassandraType::Counter(c) => MessageValue::Counter(c), + CassandraType::Decimal(d) => { + let big_decimal = BigDecimal::new(d.unscaled, d.scale.into()); + MessageValue::Decimal(big_decimal) + } + CassandraType::Double(d) => MessageValue::Double(d.into()), + CassandraType::Float(f) => MessageValue::Float(f.into()), + CassandraType::Int(c) => MessageValue::Integer(c.into(), IntSize::I32), + CassandraType::Timestamp(t) => MessageValue::Timestamp(t), + CassandraType::Uuid(u) => MessageValue::Uuid(u), + CassandraType::Varchar(v) => MessageValue::Varchar(v), + CassandraType::Varint(v) => MessageValue::Varint(v), + CassandraType::Timeuuid(t) => MessageValue::Timeuuid(t), + CassandraType::Inet(i) => MessageValue::Inet(i), + CassandraType::Date(d) => MessageValue::Date(d), + CassandraType::Time(d) => MessageValue::Time(d), + CassandraType::Duration(d) => MessageValue::Duration(Duration { + months: d.months(), + days: d.days(), + nanoseconds: d.nanoseconds(), + }), + CassandraType::Smallint(d) => MessageValue::Integer(d.into(), IntSize::I16), + CassandraType::Tinyint(d) => MessageValue::Integer(d.into(), IntSize::I8), + CassandraType::List(list) => { + let value_list = list.into_iter().map(MessageValue::create_element).collect(); + MessageValue::List(value_list) + } + CassandraType::Map(map) => MessageValue::Map( + map.into_iter() + .map(|(key, value)| { + ( + MessageValue::create_element(key), + MessageValue::create_element(value), + ) + }) + .collect(), + ), + CassandraType::Set(set) => { + MessageValue::Set(set.into_iter().map(MessageValue::create_element).collect()) + } + CassandraType::Udt(udt) => { + let values = udt + .into_iter() + .map(|(key, element)| (key, MessageValue::create_element(element))) + .collect(); + MessageValue::Udt(values) + } + CassandraType::Tuple(tuple) => { + let value_list = tuple + .into_iter() + .map(MessageValue::create_element) + .collect(); + MessageValue::Tuple(value_list) + } + CassandraType::Null => MessageValue::Null, + _ => unreachable!(), + } + } + + pub fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec>) { + match self { + MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(), + MessageValue::Bytes(b) => serialize_bytes(cursor, b), + MessageValue::Strings(s) => serialize_bytes(cursor, s.as_bytes()), + MessageValue::Integer(x, size) => match size { + IntSize::I64 => serialize_bytes(cursor, &(*x).to_be_bytes()), + IntSize::I32 => serialize_bytes(cursor, &(*x as i32).to_be_bytes()), + IntSize::I16 => serialize_bytes(cursor, &(*x as i16).to_be_bytes()), + IntSize::I8 => serialize_bytes(cursor, &(*x as i8).to_be_bytes()), + }, + MessageValue::Float(f) => serialize_bytes(cursor, &f.into_inner().to_be_bytes()), + MessageValue::Boolean(b) => serialize_bytes(cursor, &[*b as u8]), + MessageValue::List(l) => serialize_list(cursor, l), + MessageValue::Inet(i) => match i { + IpAddr::V4(ip) => serialize_bytes(cursor, &ip.octets()), + IpAddr::V6(ip) => serialize_bytes(cursor, &ip.octets()), + }, + MessageValue::Ascii(a) => serialize_bytes(cursor, a.as_bytes()), + MessageValue::Double(d) => serialize_bytes(cursor, &d.into_inner().to_be_bytes()), + MessageValue::Set(s) => serialize_set(cursor, s), + MessageValue::Map(m) => serialize_map(cursor, m), + MessageValue::Varint(v) => serialize_bytes(cursor, &v.to_signed_bytes_be()), + MessageValue::Decimal(d) => { + let (unscaled, scale) = d.as_bigint_and_exponent(); + serialize_bytes( + cursor, + &cassandra_protocol::types::decimal::Decimal { + unscaled, + scale: scale as i32, + } + .serialize_to_vec(Version::V4), + ); + } + MessageValue::Date(d) => serialize_bytes(cursor, &d.to_be_bytes()), + MessageValue::Timestamp(t) => serialize_bytes(cursor, &t.to_be_bytes()), + MessageValue::Duration(d) => { + // TODO: Either this function should be made fallible or Duration should have validated setters + serialize_bytes( + cursor, + &cassandra_protocol::types::duration::Duration::new( + d.months, + d.days, + d.nanoseconds, + ) + .unwrap() + .serialize_to_vec(Version::V4), + ); + } + MessageValue::Timeuuid(t) => serialize_bytes(cursor, t.as_bytes()), + MessageValue::Varchar(v) => serialize_bytes(cursor, v.as_bytes()), + MessageValue::Uuid(u) => serialize_bytes(cursor, u.as_bytes()), + MessageValue::Time(t) => serialize_bytes(cursor, &t.to_be_bytes()), + MessageValue::Counter(c) => serialize_bytes(cursor, &c.to_be_bytes()), + MessageValue::Tuple(t) => serialize_list(cursor, t), + MessageValue::Udt(u) => serialize_stringmap(cursor, u), + } + } +} + +pub(crate) fn serialize_with_length_prefix( + cursor: &mut Cursor<&mut Vec>, + serializer: impl FnOnce(&mut Cursor<&mut Vec>), +) { + // write dummy length + let length_start = cursor.position(); + let bytes_start = length_start + 4; + serialize_len(cursor, 0); + + // perform serialization + serializer(cursor); + + // overwrite dummy length with actual length of serialized bytes + let bytes_len = cursor.position() - bytes_start; + cursor.get_mut()[length_start as usize..bytes_start as usize] + .copy_from_slice(&(bytes_len as CInt).to_be_bytes()); +} + +pub fn serialize_len(cursor: &mut Cursor<&mut Vec>, len: usize) { + let len = len as CInt; + cursor.write_all(&len.to_be_bytes()).unwrap(); +} + +fn serialize_bytes(cursor: &mut Cursor<&mut Vec>, bytes: &[u8]) { + serialize_len(cursor, bytes.len()); + cursor.write_all(bytes).unwrap(); +} + +fn serialize_list(cursor: &mut Cursor<&mut Vec>, values: &[MessageValue]) { + serialize_with_length_prefix(cursor, |cursor| { + serialize_len(cursor, values.len()); + + for value in values { + value.cassandra_serialize(cursor); + } + }); +} + +#[allow(clippy::mutable_key_type)] +fn serialize_set(cursor: &mut Cursor<&mut Vec>, values: &BTreeSet) { + serialize_with_length_prefix(cursor, |cursor| { + serialize_len(cursor, values.len()); + + for value in values { + value.cassandra_serialize(cursor); + } + }); +} + +fn serialize_stringmap(cursor: &mut Cursor<&mut Vec>, values: &BTreeMap) { + serialize_with_length_prefix(cursor, |cursor| { + serialize_len(cursor, values.len()); + + for (key, value) in values.iter() { + serialize_bytes(cursor, key.as_bytes()); + value.cassandra_serialize(cursor); + } + }); +} + +#[allow(clippy::mutable_key_type)] +fn serialize_map(cursor: &mut Cursor<&mut Vec>, values: &BTreeMap) { + serialize_with_length_prefix(cursor, |cursor| { + serialize_len(cursor, values.len()); + + for (key, value) in values.iter() { + key.cassandra_serialize(cursor); + value.cassandra_serialize(cursor); + } + }); +} + +mod my_bytes { + use bytes::Bytes; + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize(val: &Bytes, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(val) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let val: Vec = Deserialize::deserialize(deserializer)?; + Ok(Bytes::from(val)) + } +} diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index ccf2d3b3f..6c66c03ab 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -1,5 +1,6 @@ use crate::frame::{CassandraOperation, CassandraResult, Frame}; -use crate::message::{IntSize, Message, MessageValue}; +use crate::message::Message; +use crate::message_value::{IntSize, MessageValue}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; use crate::{ error::ChainResponse, diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index 092f726bd..bfc2e467e 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -2,7 +2,8 @@ use self::node_pool::{NodePoolBuilder, PreparedMetadata}; use crate::error::ChainResponse; use crate::frame::cassandra::{parse_statement_single, CassandraMetadata, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use crate::message::{IntSize, Message, MessageValue, Messages, Metadata}; +use crate::message::{Message, Messages, Metadata}; +use crate::message_value::{IntSize, MessageValue}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response}; use crate::transforms::{Transform, TransformBuilder, Wrapper}; diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs index c4b1fcc12..152f2c39c 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs @@ -3,7 +3,8 @@ use super::node_pool::KeyspaceMetadata; use super::KeyspaceChanTx; use crate::frame::cassandra::{parse_statement_single, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use crate::message::{Message, MessageValue}; +use crate::message::Message; +use crate::message_value::MessageValue; use crate::transforms::cassandra::connection::CassandraConnection; use anyhow::{anyhow, Result}; use cassandra_protocol::events::{ServerEvent, SimpleServerEvent}; diff --git a/shotover-proxy/src/transforms/protect/crypto.rs b/shotover-proxy/src/transforms/protect/crypto.rs index a010c4862..70262ae16 100644 --- a/shotover-proxy/src/transforms/protect/crypto.rs +++ b/shotover-proxy/src/transforms/protect/crypto.rs @@ -1,4 +1,4 @@ -use crate::message::MessageValue; +use crate::message_value::MessageValue; use crate::transforms::protect::key_management::KeyManager; use anyhow::{anyhow, bail, Result}; use chacha20poly1305::aead::rand_core::RngCore; diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 718bbe94c..f579d12b6 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,6 +1,6 @@ use crate::error::ChainResponse; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; -use crate::message::MessageValue; +use crate::message_value::MessageValue; use crate::transforms::protect::key_management::KeyManager; pub use crate::transforms::protect::key_management::KeyManagerConfig; use crate::transforms::{Transform, TransformBuilder, Wrapper};