Skip to content

Commit

Permalink
serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 2, 2022
1 parent c0d8f58 commit 96e0eba
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 169 deletions.
4 changes: 2 additions & 2 deletions shotover-proxy/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl Decoder for CassandraCodec {
Ok(frame_len) => {
// Clear the read bytes from the FramedReader
let bytes = src.split_to(frame_len);
tracing::error!(
tracing::debug!(
"incoming cassandra message:\n{}",
pretty_hex::pretty_hex(&bytes)
);
Expand Down Expand Up @@ -127,7 +127,7 @@ impl Encoder<Messages> for CassandraCodec {
Encodable::Bytes(bytes) => dst.extend_from_slice(&bytes),
Encodable::Frame(frame) => self.encode_raw(frame.into_cassandra().unwrap(), dst),
}
tracing::error!(
tracing::debug!(
"outgoing cassandra message:\n{}",
pretty_hex::pretty_hex(&&dst[start..])
);
Expand Down
211 changes: 150 additions & 61 deletions shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use cassandra_protocol::{
},
};
use cql3_parser::common::Operand;
use itertools::Itertools;
use nonzero_ext::nonzero;
use num::BigInt;
use ordered_float::OrderedFloat;
Expand Down Expand Up @@ -636,72 +635,165 @@ impl MessageValue {
}
}

fn into_cbytes(self) -> CBytes {
fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>) {
match self {
MessageValue::Null => {
// Despite the name of the function this actually creates a cassandra NULL value instead of a cassandra empty value
CBytes::new_empty()
}
MessageValue::Bytes(b) => CBytes::new(b.to_vec()),
MessageValue::Strings(s) => CBytes::new(s.into_bytes()),
MessageValue::Integer(x, size) => CBytes::new(match size {
IntSize::I64 => (x as i64).to_be_bytes().to_vec(),
IntSize::I32 => (x as i32).to_be_bytes().to_vec(),
IntSize::I16 => (x as i16).to_be_bytes().to_vec(),
IntSize::I8 => (x as i8).to_be_bytes().to_vec(),
}),
MessageValue::Float(f) => CBytes::new(f.into_inner().to_be_bytes().to_vec()),
MessageValue::Boolean(b) => CBytes::new(vec![if b { 1 } else { 0 }]),
MessageValue::List(l) => vec_into_cbytes(l),
//MessageValue::Rows(r) => cassandra_protocol::types::value::Bytes::from(r),
//MessageValue::NamedRows(n) => cassandra_protocol::types::value::Bytes::from(n),
//MessageValue::Document(d) => cassandra_protocol::types::value::Bytes::from(d),
MessageValue::Inet(i) => CBytes::new(match i {
IpAddr::V4(ip) => ip.octets().to_vec(),
IpAddr::V6(ip) => ip.octets().to_vec(),
}),
MessageValue::FragmentedResponse(l) => vec_into_cbytes(l),
MessageValue::Ascii(a) => CBytes::new(a.into_bytes()),
MessageValue::Double(d) => CBytes::new(d.into_inner().to_be_bytes().into()),
MessageValue::Set(s) => vec_into_cbytes(s.into_iter().collect_vec()),
//MessageValue::Map(m) => m.into(),
//MessageValue::Varint(v) => v.into(),
MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(),
MessageValue::Bytes(b) => serialize_bytes(cursor, b),
MessageValue::Strings(s) => serialize_bytes(cursor, s.as_bytes()),
MessageValue::Integer(x, size) => match size {
IntSize::I64 => serialize_bytes(cursor, &(*x as i64).to_be_bytes()),
IntSize::I32 => serialize_bytes(cursor, &(*x as i32).to_be_bytes()),
IntSize::I16 => serialize_bytes(cursor, &(*x as i16).to_be_bytes()),
IntSize::I8 => serialize_bytes(cursor, &(*x as i8).to_be_bytes()),
},
MessageValue::Float(f) => serialize_bytes(cursor, &f.into_inner().to_be_bytes()),
MessageValue::Boolean(b) => serialize_bytes(cursor, &[if *b { 1 } else { 0 }]),
MessageValue::List(l) => serialize_list(cursor, l),
MessageValue::Rows(rows) => serialize_list_list(cursor, rows),
MessageValue::NamedRows(n) => serialize_list_stringmap(cursor, n),
MessageValue::Document(d) => serialize_stringmap(cursor, d),
MessageValue::Inet(i) => match i {
IpAddr::V4(ip) => serialize_bytes(cursor, &ip.octets()),
IpAddr::V6(ip) => serialize_bytes(cursor, &ip.octets()),
},
MessageValue::FragmentedResponse(l) => serialize_list(cursor, l),
MessageValue::Ascii(a) => serialize_bytes(cursor, a.as_bytes()),
MessageValue::Double(d) => serialize_bytes(cursor, &d.into_inner().to_be_bytes()),
MessageValue::Set(s) => serialize_set(cursor, s),
MessageValue::Map(m) => serialize_map(cursor, m),
MessageValue::Varint(v) => serialize_bytes(cursor, &v.to_signed_bytes_be()),
MessageValue::Decimal(d) => {
let (unscaled, scale) = d.into_bigint_and_exponent();
CBytes::new(
cassandra_protocol::types::decimal::Decimal {
let (unscaled, scale) = d.clone().into_bigint_and_exponent();
serialize_bytes(
cursor,
&cassandra_protocol::types::decimal::Decimal {
unscaled,
scale: scale as i32,
}
.serialize_to_vec(Version::V4),
)
);
}
MessageValue::Date(d) => CBytes::new(d.to_be_bytes().to_vec()),
MessageValue::Timestamp(t) => CBytes::new(t.to_be_bytes().to_vec()),
MessageValue::Date(d) => serialize_bytes(cursor, &d.to_be_bytes()),
MessageValue::Timestamp(t) => serialize_bytes(cursor, &t.to_be_bytes()),
MessageValue::Duration(d) => {
// TODO: Either this function should be made fallible or Duration should have validated setters
CBytes::new(
cassandra_protocol::types::duration::Duration::new(
serialize_bytes(
cursor,
&cassandra_protocol::types::duration::Duration::new(
d.months,
d.days,
d.nanoseconds,
)
.unwrap()
.serialize_to_vec(Version::V4),
)
);
}
//MessageValue::Timeuuid(t) => t.into(),
MessageValue::Varchar(v) => CBytes::new(v.into_bytes()),
//MessageValue::Uuid(u) => u.into(),
MessageValue::Time(t) => CBytes::new(t.to_be_bytes().to_vec()),
MessageValue::Counter(c) => CBytes::new(c.to_be_bytes().to_vec()),
//MessageValue::Tuple(t) => t.into(),
//MessageValue::Udt(u) => u.into(),
_ => todo!(),
MessageValue::Timeuuid(t) => serialize_bytes(cursor, t.as_bytes()),
MessageValue::Varchar(v) => serialize_bytes(cursor, v.as_bytes()),
MessageValue::Uuid(u) => serialize_bytes(cursor, u.as_bytes()),
MessageValue::Time(t) => serialize_bytes(cursor, &t.to_be_bytes()),
MessageValue::Counter(c) => serialize_bytes(cursor, &c.to_be_bytes()),
MessageValue::Tuple(t) => serialize_list(cursor, t),
MessageValue::Udt(u) => serialize_stringmap(cursor, u),
}
}
}

/// Writes `len` as a Cassandra wire-format length prefix: a big-endian
/// signed 32-bit `[int]`.
fn serialize_len(cursor: &mut Cursor<&mut Vec<u8>>, len: usize) {
    // NOTE(review): the cast silently truncates for len > i32::MAX —
    // assumed unreachable for realistic payload sizes, but worth confirming.
    cursor.write_all(&(len as CInt).to_be_bytes()).unwrap();
}

/// Writes a Cassandra `[bytes]` value: a 4-byte big-endian length prefix
/// followed by the raw payload bytes.
fn serialize_bytes(cursor: &mut Cursor<&mut Vec<u8>>, bytes: &[u8]) {
    // Length prefix goes on the wire first, then the payload itself.
    serialize_len(cursor, bytes.len());
    cursor.write_all(bytes).unwrap();
}

/// Serializes a slice of values as a single Cassandra `[bytes]` value whose
/// payload is the element count followed by each element's own serialization.
fn serialize_list(cursor: &mut Cursor<&mut Vec<u8>>, values: &[MessageValue]) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut buf = vec![];
    {
        // Buffer the count + elements first so the outer length prefix
        // can cover the whole payload.
        let mut inner = Cursor::new(&mut buf);
        serialize_len(&mut inner, values.len());
        for element in values.iter() {
            element.cassandra_serialize(&mut inner);
        }
    }
    serialize_bytes(cursor, &buf);
}

/// Serializes a list of rows (each a list of values) as a single Cassandra
/// `[bytes]` value: the row count followed by each row serialized via
/// [`serialize_list`], all covered by one outer length prefix.
fn serialize_list_list(cursor: &mut Cursor<&mut Vec<u8>>, values: &[Vec<MessageValue>]) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut bytes = vec![];
    let mut inner_cursor = Cursor::new(&mut bytes);
    serialize_len(&mut inner_cursor, values.len());

    for value in values {
        // BUG FIX: the inner lists must be written into the buffered
        // `inner_cursor` so they sit inside the outer length prefix.
        // Previously this wrote to the outer `cursor`, emitting the rows
        // before the prefix and leaving the buffered payload length-only,
        // which corrupts the serialized stream.
        serialize_list(&mut inner_cursor, value);
    }

    serialize_bytes(cursor, &bytes);
}

/// Serializes a list of string-keyed maps as a single Cassandra `[bytes]`
/// value: the map count followed by each map serialized via
/// [`serialize_stringmap`], all covered by one outer length prefix.
fn serialize_list_stringmap(
    cursor: &mut Cursor<&mut Vec<u8>>,
    values: &[BTreeMap<String, MessageValue>],
) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut bytes = vec![];
    let mut inner_cursor = Cursor::new(&mut bytes);
    serialize_len(&mut inner_cursor, values.len());

    for value in values {
        // BUG FIX: each map must be written into the buffered `inner_cursor`
        // so it sits inside the outer length prefix. Previously this wrote to
        // the outer `cursor`, emitting the maps before the prefix and leaving
        // the buffered payload length-only, corrupting the serialized stream.
        serialize_stringmap(&mut inner_cursor, value);
    }

    serialize_bytes(cursor, &bytes);
}

#[allow(clippy::mutable_key_type)]
/// Serializes a set as a single Cassandra `[bytes]` value whose payload is
/// the element count followed by each element's own serialization, iterated
/// in the `BTreeSet`'s sorted order.
fn serialize_set(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeSet<MessageValue>) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut buf = vec![];
    let mut inner = Cursor::new(&mut buf);
    serialize_len(&mut inner, values.len());

    values
        .iter()
        .for_each(|element| element.cassandra_serialize(&mut inner));

    serialize_bytes(cursor, &buf);
}

/// Serializes a string-keyed map as a single Cassandra `[bytes]` value:
/// the entry count, then each key (as `[bytes]`) followed by its value,
/// iterated in the `BTreeMap`'s sorted key order.
fn serialize_stringmap(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeMap<String, MessageValue>) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut buf = vec![];
    let mut inner = Cursor::new(&mut buf);
    serialize_len(&mut inner, values.len());

    for (name, element) in values {
        serialize_bytes(&mut inner, name.as_bytes());
        element.cassandra_serialize(&mut inner);
    }

    serialize_bytes(cursor, &buf);
}

#[allow(clippy::mutable_key_type)]
/// Serializes a value-keyed map as a single Cassandra `[bytes]` value:
/// the entry count, then each key/value pair serialized in turn, iterated
/// in the `BTreeMap`'s sorted key order.
fn serialize_map(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeMap<MessageValue, MessageValue>) {
    // TODO: avoid allocating here, will need some length hint logic
    let mut buf = vec![];
    let mut inner = Cursor::new(&mut buf);
    serialize_len(&mut inner, values.len());

    for (key, element) in values {
        key.cassandra_serialize(&mut inner);
        element.cassandra_serialize(&mut inner);
    }

    serialize_bytes(cursor, &buf);
}

// TODO: just call into_cbytes directly
impl From<MessageValue> for CBytes {
fn from(value: MessageValue) -> CBytes {
Expand All @@ -713,22 +805,19 @@ impl From<MessageValue> for CBytes {
//
// TODO: This should be upstreamable but will require rewriting their entire CBytes/Bytes/Value API
// and so will take a long time to both write and review
value.into_cbytes()
}
}

fn vec_into_cbytes(values: Vec<MessageValue>) -> CBytes {
let mut bytes = vec![];
let mut cursor = Cursor::new(&mut bytes);
let len = values.len() as CInt;
cursor.write_all(&len.to_be_bytes()).unwrap();

for value in values {
let value_bytes = value.into_cbytes();
value_bytes.serialize(&mut cursor, Version::V4);
// CBytes API expects the length to be implied and the null value encoded
let mut bytes = vec![];
value.cassandra_serialize(&mut Cursor::new(&mut bytes));
if i32::from_be_bytes(bytes[0..4].try_into().unwrap()) < 0 {
// Despite the name of the function this actually creates a cassandra NULL value instead of a cassandra empty value
CBytes::new_empty()
} else {
// strip the length
bytes.drain(0..4);
CBytes::new(bytes)
}
}

CBytes::new(bytes)
}

mod my_bytes {
Expand Down
14 changes: 2 additions & 12 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,6 @@ impl CassandraSinkCluster {
let mut data_center_alias = "data_center";
let mut rack_alias = "rack";
let mut host_id_alias = "host_id";
let mut native_address_alias = "native_address";
let mut native_port_alias = "native_port";
let mut preferred_ip_alias = "preferred_ip";
let mut preferred_port_alias = "preferred_port";
let mut rpc_address_alias = "rpc_address";
Expand All @@ -485,10 +483,6 @@ impl CassandraSinkCluster {
rack_alias = alias;
} else if column.name == Identifier::Unquoted("host_id".to_string()) {
host_id_alias = alias;
} else if column.name == Identifier::Unquoted("native_address".to_string()) {
native_address_alias = alias;
} else if column.name == Identifier::Unquoted("native_port".to_string()) {
native_port_alias = alias;
} else if column.name == Identifier::Unquoted("preferred_ip".to_string()) {
preferred_ip_alias = alias;
} else if column.name == Identifier::Unquoted("preferred_port".to_string()) {
Expand Down Expand Up @@ -565,17 +559,13 @@ impl CassandraSinkCluster {
|| colspec.name == rpc_address_alias
{
MessageValue::Null
} else if colspec.name == native_address_alias {
} else if colspec.name == peer_alias {
MessageValue::Inet(shotover_peer.address.ip())
} else if colspec.name == native_port_alias {
} else if colspec.name == peer_port_alias {
MessageValue::Integer(
shotover_peer.address.port() as i64,
IntSize::I32,
)
} else if colspec.name == peer_alias {
MessageValue::Inet(shotover_peer.address.ip())
} else if colspec.name == peer_port_alias {
MessageValue::Integer(7000, IntSize::I32)
} else if colspec.name == release_version_alias {
MessageValue::Varchar(release_version.clone())
} else if colspec.name == tokens_alias {
Expand Down
33 changes: 8 additions & 25 deletions shotover-proxy/tests/cassandra_int_tests/cluster_multi_rack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,14 @@ use crate::helpers::cassandra::{assert_query_result, CassandraConnection, Result
use std::net::IpAddr;

async fn test_rewrite_system_peers(connection: &CassandraConnection) {
let star_results1 = [
// peer is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Varchar("dc1".into()),
// host_id is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Inet("255.255.255.255".into()),
// rack is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Varchar("3.11.13".into()),
ResultValue::Inet("255.255.255.255".into()),
// schema_version is non-deterministic so we can't assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
// there are 128 tokens per node
ResultValue::Set(std::iter::repeat(ResultValue::Any).take(128).collect()),
];
let star_results2 = [
let star_results = [
ResultValue::Any,
ResultValue::Varchar("dc1".into()),
ResultValue::Any,
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
ResultValue::Any,
ResultValue::Varchar("3.11.13".into()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
// schema_version is non-deterministic so we can't assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
Expand All @@ -39,21 +22,21 @@ async fn test_rewrite_system_peers(connection: &CassandraConnection) {
assert_query_result(
connection,
"SELECT * FROM system.peers;",
&[&star_results1, &star_results2],
&[&star_results, &star_results],
)
.await;
assert_query_result(
connection,
&format!("SELECT {all_columns} FROM system.peers;"),
&[&star_results1, &star_results2],
&[&star_results, &star_results],
)
.await;
assert_query_result(
connection,
&format!("SELECT {all_columns}, {all_columns} FROM system.peers;"),
&[
&[star_results1.as_slice(), star_results1.as_slice()].concat(),
&[star_results2.as_slice(), star_results2.as_slice()].concat(),
&[star_results.as_slice(), star_results.as_slice()].concat(),
&[star_results.as_slice(), star_results.as_slice()].concat(),
],
)
.await;
Expand Down Expand Up @@ -86,7 +69,7 @@ async fn test_rewrite_system_local(connection: &CassandraConnection) {
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
// there are 128 tokens per node
ResultValue::Set(std::iter::repeat(ResultValue::Any).take(128).collect()),
ResultValue::Map(vec![]),
ResultValue::Null,
];

let all_columns =
Expand Down
Loading

0 comments on commit 96e0eba

Please sign in to comment.