Skip to content

Commit

Permalink
Fix cassandra protocol null handling
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 8, 2022
1 parent af2fa3a commit 2700672
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 141 deletions.
12 changes: 2 additions & 10 deletions shotover-proxy/src/frame/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use cassandra_protocol::frame::{
use cassandra_protocol::query::{QueryParams, QueryValues};
use cassandra_protocol::types::blob::Blob;
use cassandra_protocol::types::cassandra_type::CassandraType;
use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong};
use cassandra_protocol::types::{CBytesShort, CInt, CLong};
use cql3_parser::begin_batch::{BatchType as ParserBatchType, BeginBatch};
use cql3_parser::cassandra_ast::CassandraAST;
use cql3_parser::cassandra_statement::CassandraStatement;
Expand Down Expand Up @@ -481,15 +481,7 @@ impl CassandraOperation {
let rows_count = rows.len() as CInt;
let rows_content = rows
.into_iter()
.map(|row| {
row.into_iter()
.map(|value| {
CBytes::new(
cassandra_protocol::types::value::Bytes::from(value).into_inner(),
)
})
.collect()
})
.map(|row| row.into_iter().map(MessageValue::into_cbytes).collect())
.collect();

ResResultBody::Rows(BodyResResultRows {
Expand Down
191 changes: 148 additions & 43 deletions shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use anyhow::{anyhow, Result};
use bigdecimal::BigDecimal;
use bytes::{Buf, Bytes};
use bytes_utils::Str;
use cassandra_protocol::frame::Serialize as FrameSerialize;
use cassandra_protocol::types::CInt;
use cassandra_protocol::{
frame::{
message_error::{AdditionalErrorInfo, ErrorBody},
Expand All @@ -20,12 +22,12 @@ use cassandra_protocol::{
},
};
use cql3_parser::common::Operand;
use itertools::Itertools;
use nonzero_ext::nonzero;
use num::BigInt;
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
use std::io::{Cursor, Write};
use std::net::IpAddr;
use std::num::NonZeroU32;
use uuid::Uuid;
Expand Down Expand Up @@ -618,58 +620,161 @@ impl MessageValue {
CassandraType::Null => MessageValue::Null,
}
}
}

impl From<MessageValue> for cassandra_protocol::types::value::Bytes {
fn from(value: MessageValue) -> cassandra_protocol::types::value::Bytes {
match value {
MessageValue::Null => (-1_i32).into(),
MessageValue::Bytes(b) => cassandra_protocol::types::value::Bytes::new(b.to_vec()),
MessageValue::Strings(s) => s.into(),
MessageValue::Integer(x, size) => {
cassandra_protocol::types::value::Bytes::new(match size {
IntSize::I64 => (x as i64).to_be_bytes().to_vec(),
IntSize::I32 => (x as i32).to_be_bytes().to_vec(),
IntSize::I16 => (x as i16).to_be_bytes().to_vec(),
IntSize::I8 => (x as i8).to_be_bytes().to_vec(),
})
}
MessageValue::Float(f) => f.into_inner().into(),
MessageValue::Boolean(b) => b.into(),
MessageValue::List(l) => l.into(),
MessageValue::Inet(i) => i.into(),
MessageValue::Ascii(a) => a.into(),
MessageValue::Double(d) => d.into_inner().into(),
MessageValue::Set(s) => s.into_iter().collect_vec().into(),
MessageValue::Map(m) => m.into(),
MessageValue::Varint(v) => v.into(),
/// Converts a `MessageValue` into a `CBytes`, keeping a null value expressible.
///
/// The value is serialized with `cassandra_serialize`, then its leading 4-byte
/// big-endian length is inspected: a negative length becomes `CBytes::new_empty()`,
/// otherwise the length is removed and the remaining bytes are wrapped in `CBytes::new`.
///
/// # Panics
/// Panics if the serialized output is shorter than 4 bytes.
pub fn into_cbytes(value: MessageValue) -> CBytes {
    // cassandra-protocol handles null values very poorly, so its conversion logic is
    // reimplemented here at the CBytes level where a null value can be expressed.
    //
    // Reimplementing it also allocates far less and keeps the whole stack easier to
    // follow than the `.into()` based API.
    //
    // TODO: This should be upstreamable but will require rewriting their entire CBytes/Bytes/Value API
    // and so will take a long time to both write and review

    // The CBytes API expects the length to be implied and the null value encoded.
    let mut serialized = Vec::new();
    value.cassandra_serialize(&mut Cursor::new(&mut serialized));

    let length_prefix: [u8; 4] = serialized[..4].try_into().unwrap();
    if i32::from_be_bytes(length_prefix) >= 0 {
        // drop the length prefix, leaving only the payload
        serialized.drain(..4);
        return CBytes::new(serialized);
    }

    // Despite its name, `new_empty` produces a cassandra NULL value rather than a
    // cassandra empty value.
    CBytes::new_empty()
}

/// Appends the serialized form of this value to `cursor`.
///
/// `Null` is written as the four bytes `255, 255, 255, 255` with no payload.
/// The other arms delegate to `serialize_bytes` (a 4-byte length followed by the
/// raw bytes) or to the collection helpers `serialize_list`, `serialize_set`,
/// `serialize_map` and `serialize_stringmap`.
///
/// Panics if the `Null` write fails or if the `.unwrap()` on `Duration::new` fails.
///
/// NOTE(review): this listing appears to interleave removed and added lines of a
/// diff (the arms ending in `.into()` look like the removed side) — confirm
/// against the repository before relying on it.
fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>) {
match self {
MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(),
MessageValue::Bytes(b) => serialize_bytes(cursor, b),
MessageValue::Strings(s) => serialize_bytes(cursor, s.as_bytes()),
MessageValue::Integer(x, size) => match size {
IntSize::I64 => serialize_bytes(cursor, &(*x as i64).to_be_bytes()),
IntSize::I32 => serialize_bytes(cursor, &(*x as i32).to_be_bytes()),
IntSize::I16 => serialize_bytes(cursor, &(*x as i16).to_be_bytes()),
IntSize::I8 => serialize_bytes(cursor, &(*x as i8).to_be_bytes()),
},
MessageValue::Float(f) => serialize_bytes(cursor, &f.into_inner().to_be_bytes()),
MessageValue::Boolean(b) => serialize_bytes(cursor, &[if *b { 1 } else { 0 }]),
MessageValue::List(l) => serialize_list(cursor, l),
MessageValue::Inet(i) => match i {
IpAddr::V4(ip) => serialize_bytes(cursor, &ip.octets()),
IpAddr::V6(ip) => serialize_bytes(cursor, &ip.octets()),
},
MessageValue::Ascii(a) => serialize_bytes(cursor, a.as_bytes()),
MessageValue::Double(d) => serialize_bytes(cursor, &d.into_inner().to_be_bytes()),
MessageValue::Set(s) => serialize_set(cursor, s),
MessageValue::Map(m) => serialize_map(cursor, m),
MessageValue::Varint(v) => serialize_bytes(cursor, &v.to_signed_bytes_be()),
MessageValue::Decimal(d) => {
let (unscaled, scale) = d.into_bigint_and_exponent();
cassandra_protocol::types::decimal::Decimal {
unscaled,
scale: scale as i32,
}
.into()
let (unscaled, scale) = d.as_bigint_and_exponent();
serialize_bytes(
cursor,
&cassandra_protocol::types::decimal::Decimal {
unscaled,
scale: scale as i32,
}
.serialize_to_vec(Version::V4),
);
}
MessageValue::Date(d) => d.into(),
MessageValue::Timestamp(t) => t.into(),
MessageValue::Date(d) => serialize_bytes(cursor, &d.to_be_bytes()),
MessageValue::Timestamp(t) => serialize_bytes(cursor, &t.to_be_bytes()),
MessageValue::Duration(d) => {
// TODO: Either this function should be made fallible or we Duration should have validated setters
cassandra_protocol::types::duration::Duration::new(d.months, d.days, d.nanoseconds)
// TODO: Either this function should be made fallible or Duration should have validated setters
serialize_bytes(
cursor,
&cassandra_protocol::types::duration::Duration::new(
d.months,
d.days,
d.nanoseconds,
)
.unwrap()
.into()
.serialize_to_vec(Version::V4),
);
}
MessageValue::Timeuuid(t) => t.into(),
MessageValue::Varchar(v) => v.into(),
MessageValue::Uuid(u) => u.into(),
MessageValue::Time(t) => t.into(),
MessageValue::Counter(c) => c.into(),
MessageValue::Tuple(t) => t.into(),
MessageValue::Udt(u) => u.into(),
MessageValue::Timeuuid(t) => serialize_bytes(cursor, t.as_bytes()),
MessageValue::Varchar(v) => serialize_bytes(cursor, v.as_bytes()),
MessageValue::Uuid(u) => serialize_bytes(cursor, u.as_bytes()),
MessageValue::Time(t) => serialize_bytes(cursor, &t.to_be_bytes()),
MessageValue::Counter(c) => serialize_bytes(cursor, &c.to_be_bytes()),
MessageValue::Tuple(t) => serialize_list(cursor, t),
MessageValue::Udt(u) => serialize_stringmap(cursor, u),
}
}
}

/// Runs `serializer` and frames the bytes it writes with a 4-byte big-endian length.
///
/// A placeholder length is written first because the real length is only known
/// once `serializer` has finished; the placeholder is then patched in place.
///
/// The stored length counts only the bytes produced by `serializer`. It does not
/// include the 4 bytes of the prefix itself, which matches the framing written by
/// `serialize_bytes` (length of the payload, then the payload).
///
/// * `cursor` - output buffer; writing starts at its current position.
/// * `serializer` - writes the payload to the cursor it is handed.
fn serialize_with_length_prefix(
    cursor: &mut Cursor<&mut Vec<u8>>,
    serializer: impl FnOnce(&mut Cursor<&mut Vec<u8>>),
) {
    // write dummy length, remembering where it lives and where the payload begins
    let length_start = cursor.position() as usize;
    let content_start = length_start + 4;
    serialize_len(cursor, 0);

    // perform serialization
    serializer(cursor);

    // overwrite dummy length with the length of the payload only;
    // measuring from `length_start` would wrongly count the prefix too
    let content_len = cursor.position() as usize - content_start;
    cursor.get_mut()[length_start..content_start]
        .copy_from_slice(&(content_len as CInt).to_be_bytes());
}

/// Writes `len` to `cursor` as a 4-byte big-endian `CInt`.
///
/// NOTE(review): the `as` cast wraps for lengths that do not fit in `CInt`;
/// callers are presumably bounded well below that — confirm.
///
/// # Panics
/// Panics if the write fails, instead of silently leaving a truncated buffer.
fn serialize_len(cursor: &mut Cursor<&mut Vec<u8>>, len: usize) {
    let len = len as CInt;
    cursor
        .write_all(&len.to_be_bytes())
        .expect("writing a length to an in-memory buffer should not fail");
}

/// Writes `bytes` to `cursor` preceded by their length (via `serialize_len`).
///
/// # Panics
/// Panics if the payload write fails, instead of silently emitting a length
/// prefix with a missing or partial payload.
fn serialize_bytes(cursor: &mut Cursor<&mut Vec<u8>>, bytes: &[u8]) {
    serialize_len(cursor, bytes.len());
    cursor
        .write_all(bytes)
        .expect("writing a payload to an in-memory buffer should not fail");
}

/// Serializes a slice of values as one length-prefixed unit: the element count
/// (via `serialize_len`) followed by each element's `cassandra_serialize` output,
/// in slice order.
fn serialize_list(cursor: &mut Cursor<&mut Vec<u8>>, values: &[MessageValue]) {
    serialize_with_length_prefix(cursor, |out| {
        let element_count = values.len();
        serialize_len(out, element_count);

        values
            .iter()
            .for_each(|element| element.cassandra_serialize(out));
    });
}

/// Serializes a set of values as one length-prefixed unit: the element count
/// (via `serialize_len`) followed by each element's `cassandra_serialize` output,
/// in the set's iteration order.
#[allow(clippy::mutable_key_type)]
fn serialize_set(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeSet<MessageValue>) {
    serialize_with_length_prefix(cursor, |out| {
        let element_count = values.len();
        serialize_len(out, element_count);

        values
            .iter()
            .for_each(|element| element.cassandra_serialize(out));
    });
}

/// Serializes a string-keyed map as one length-prefixed unit: the entry count
/// (via `serialize_len`), then for every entry the key's bytes (via
/// `serialize_bytes`) followed by the value's `cassandra_serialize` output.
fn serialize_stringmap(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeMap<String, MessageValue>) {
    serialize_with_length_prefix(cursor, |out| {
        let entry_count = values.len();
        serialize_len(out, entry_count);

        values.iter().for_each(|(field_name, field_value)| {
            serialize_bytes(out, field_name.as_bytes());
            field_value.cassandra_serialize(out);
        });
    });
}

/// Serializes a map as one length-prefixed unit: the entry count (via
/// `serialize_len`), then for every entry the key's `cassandra_serialize` output
/// followed by the value's, in the map's iteration order.
#[allow(clippy::mutable_key_type)]
fn serialize_map(cursor: &mut Cursor<&mut Vec<u8>>, values: &BTreeMap<MessageValue, MessageValue>) {
    serialize_with_length_prefix(cursor, |out| {
        let entry_count = values.len();
        serialize_len(out, entry_count);

        values.iter().for_each(|(entry_key, entry_value)| {
            entry_key.cassandra_serialize(out);
            entry_value.cassandra_serialize(out);
        });
    });
}

mod my_bytes {
use bytes::Bytes;
use serde::{Deserialize, Deserializer, Serializer};
Expand Down
32 changes: 9 additions & 23 deletions shotover-proxy/tests/cassandra_int_tests/cluster_multi_rack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,17 @@ use crate::helpers::cassandra::{assert_query_result, CassandraConnection, Result
use std::net::SocketAddr;

async fn test_rewrite_system_peers(connection: &CassandraConnection) {
let star_results1 = [
let star_results = [
// peer is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Varchar("dc1".into()),
// host_id is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
// rack is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Varchar("4.0.6".into()),
ResultValue::Inet("255.255.255.255".into()),
// schema_version is non deterministic so we cant assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
// there are 128 tokens per node
ResultValue::Set(std::iter::repeat(ResultValue::Any).take(128).collect()),
];
let star_results2 = [
ResultValue::Any,
ResultValue::Varchar("dc1".into()),
ResultValue::Any,
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Any,
ResultValue::Varchar("4.0.6".into()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
// schema_version is non deterministic so we cant assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
Expand All @@ -39,21 +25,21 @@ async fn test_rewrite_system_peers(connection: &CassandraConnection) {
assert_query_result(
connection,
"SELECT * FROM system.peers;",
&[&star_results1, &star_results2],
&[&star_results, &star_results],
)
.await;
assert_query_result(
connection,
&format!("SELECT {all_columns} FROM system.peers;"),
&[&star_results1, &star_results2],
&[&star_results, &star_results],
)
.await;
assert_query_result(
connection,
&format!("SELECT {all_columns}, {all_columns} FROM system.peers;"),
&[
&[star_results1.as_slice(), star_results1.as_slice()].concat(),
&[star_results2.as_slice(), star_results2.as_slice()].concat(),
&[star_results.as_slice(), star_results.as_slice()].concat(),
&[star_results.as_slice(), star_results.as_slice()].concat(),
],
)
.await;
Expand All @@ -70,8 +56,8 @@ async fn test_rewrite_system_peers_v2(connection: &CassandraConnection) {
// native_address is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Int(9042),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Int(-1),
ResultValue::Null,
ResultValue::Null,
// rack is non-deterministic because we don't know which node this will be
ResultValue::Any,
ResultValue::Varchar("4.0.6".into()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ async fn test_rewrite_system_peers_dummy_peers(connection: &CassandraConnection)
ResultValue::Inet("127.0.0.1".parse().unwrap()),
ResultValue::Varchar("dc1".into()),
ResultValue::Uuid("3c3c4e2d-ba74-4f76-b52e-fb5bcee6a9f4".parse().unwrap()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
ResultValue::Varchar("rack1".into()),
ResultValue::Varchar("3.11.13".into()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
// schema_version is non deterministic so we cant assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
Expand All @@ -19,10 +19,10 @@ async fn test_rewrite_system_peers_dummy_peers(connection: &CassandraConnection)
ResultValue::Inet("127.0.0.1".parse().unwrap()),
ResultValue::Varchar("dc1".into()),
ResultValue::Uuid("fa74d7ec-1223-472b-97de-04a32ccdb70b".parse().unwrap()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
ResultValue::Varchar("rack1".into()),
ResultValue::Varchar("3.11.13".into()),
ResultValue::Inet("255.255.255.255".into()),
ResultValue::Null,
// schema_version is non deterministic so we cant assert on it.
ResultValue::Any,
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
Expand Down Expand Up @@ -78,7 +78,7 @@ async fn test_rewrite_system_local(connection: &CassandraConnection) {
// Unfortunately token generation appears to be non-deterministic but we can at least assert that
// there are 128 tokens per node
ResultValue::Set(std::iter::repeat(ResultValue::Any).take(3 * 128).collect()),
ResultValue::Map(vec![]),
ResultValue::Null,
];

let all_columns =
Expand Down
Loading

0 comments on commit 2700672

Please sign in to comment.