Skip to content

Commit

Permalink
and more
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 9, 2022
1 parent 386d93d commit fbfd265
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 60 deletions.
56 changes: 26 additions & 30 deletions shotover-proxy/src/frame/cassandra.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::message::{MessageValue, QueryType};
use crate::message::{serialize_len, MessageValue, QueryType};
use anyhow::{anyhow, Result};
use bytes::Bytes;
use cassandra_protocol::compression::Compression;
Expand All @@ -14,21 +14,22 @@ use cassandra_protocol::frame::message_query::BodyReqQuery;
use cassandra_protocol::frame::message_request::RequestBody;
use cassandra_protocol::frame::message_response::ResponseBody;
use cassandra_protocol::frame::message_result::{
BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ResResultBody,
RowsMetadata, RowsMetadataFlags,
BodyResResultPrepared, BodyResResultSetKeyspace, ResResultBody, ResultKind, RowsMetadata,
RowsMetadataFlags,
};
use cassandra_protocol::frame::{
Direction, Envelope as RawCassandraFrame, Flags, Opcode, Serialize, StreamId, Version,
};
use cassandra_protocol::query::{QueryParams, QueryValues};
use cassandra_protocol::types::blob::Blob;
use cassandra_protocol::types::cassandra_type::CassandraType;
use cassandra_protocol::types::{CBytesShort, CInt, CLong};
use cassandra_protocol::types::{CBytesShort, CLong};
use cql3_parser::begin_batch::{BatchType as ParserBatchType, BeginBatch};
use cql3_parser::cassandra_ast::CassandraAST;
use cql3_parser::cassandra_statement::CassandraStatement;
use cql3_parser::common::Operand;
use nonzero_ext::nonzero;
use std::io::Cursor;
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::str::FromStr;
Expand Down Expand Up @@ -423,18 +424,32 @@ impl CassandraOperation {
.serialize_to_vec(version),
CassandraOperation::Result(result) => match result {
CassandraResult::Rows { rows, metadata } => {
Self::build_cassandra_result_body(version, rows, *metadata)
let mut buf = vec![];
let mut cursor = Cursor::new(&mut buf);

ResultKind::Rows.serialize(&mut cursor, version);

metadata.serialize(&mut cursor, version);
serialize_len(&mut cursor, rows.len());
for row in rows {
for col in row {
col.cassandra_serialize(&mut cursor);
}
}

buf
}
CassandraResult::SetKeyspace(set_keyspace) => {
ResResultBody::SetKeyspace(*set_keyspace)
ResResultBody::SetKeyspace(*set_keyspace).serialize_to_vec(version)
}
CassandraResult::Prepared(prepared) => {
ResResultBody::Prepared(*prepared).serialize_to_vec(version)
}
CassandraResult::Prepared(prepared) => ResResultBody::Prepared(*prepared),
CassandraResult::SchemaChange(schema_change) => {
ResResultBody::SchemaChange(*schema_change)
ResResultBody::SchemaChange(*schema_change).serialize_to_vec(version)
}
CassandraResult::Void => ResResultBody::Void,
}
.serialize_to_vec(version),
CassandraResult::Void => ResResultBody::Void.serialize_to_vec(version),
},
CassandraOperation::Error(error) => error.serialize_to_vec(version),
CassandraOperation::Startup(bytes) => bytes.to_vec(),
CassandraOperation::Ready(bytes) => bytes.to_vec(),
Expand Down Expand Up @@ -472,25 +487,6 @@ impl CassandraOperation {
CassandraOperation::AuthSuccess(bytes) => bytes.to_vec(),
}
}

fn build_cassandra_result_body(
protocol_version: Version,
rows: Vec<Vec<MessageValue>>,
metadata: RowsMetadata,
) -> ResResultBody {
let rows_count = rows.len() as CInt;
let rows_content = rows
.into_iter()
.map(|row| row.into_iter().map(MessageValue::into_cbytes).collect())
.collect();

ResResultBody::Rows(BodyResResultRows {
protocol_version,
metadata,
rows_count,
rows_content,
})
}
}

fn get_query_type(statement: &CassandraStatement) -> QueryType {
Expand Down
38 changes: 8 additions & 30 deletions shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,30 +621,7 @@ impl MessageValue {
}
}

pub fn into_cbytes(value: MessageValue) -> CBytes {
// cassandra-protocol handles null values incredibly poorly.
// so we need to rewrite their logic to operate at the CBytes level in order to have the null value expressible
//
// Additionally reimplementing this logic allows us to allocate a lot less
// and its also way easier to understand the whole stack than the `.into()` based API.
//
// TODO: This should be upstreamable but will require rewriting their entire CBytes/Bytes/Value API
// and so will take a long time to both write and review

// CBytes API expects the length to be implied and the null value encoded
let mut bytes = vec![];
value.cassandra_serialize(&mut Cursor::new(&mut bytes));
if i32::from_be_bytes(bytes[0..4].try_into().unwrap()) < 0 {
// Despite the name of the function this actually creates a cassandra NULL value instead of a cassandra empty value
CBytes::new_empty()
} else {
// strip the length
bytes.drain(0..4);
CBytes::new(bytes)
}
}

fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>) {
pub fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>) {
match self {
MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(),
MessageValue::Bytes(b) => serialize_bytes(cursor, b),
Expand Down Expand Up @@ -709,26 +686,27 @@ fn serialize_with_length_prefix(
serializer: impl FnOnce(&mut Cursor<&mut Vec<u8>>),
) {
// write dummy length
let start_pos = cursor.position();
let length_start = cursor.position();
let bytes_start = length_start + 4;
serialize_len(cursor, 0);

// perform serialization
serializer(cursor);

// overwrite dummy length with actual length of serialized bytes
let bytes_len = cursor.position() - start_pos;
cursor.get_mut()[start_pos as usize..start_pos as usize + 4]
let bytes_len = cursor.position() - bytes_start;
cursor.get_mut()[length_start as usize..bytes_start as usize]
.copy_from_slice(&(bytes_len as CInt).to_be_bytes());
}

fn serialize_len(cursor: &mut Cursor<&mut Vec<u8>>, len: usize) {
pub fn serialize_len(cursor: &mut Cursor<&mut Vec<u8>>, len: usize) {
let len = len as CInt;
let _ = cursor.write_all(&len.to_be_bytes());
cursor.write_all(&len.to_be_bytes()).unwrap();
}

fn serialize_bytes(cursor: &mut Cursor<&mut Vec<u8>>, bytes: &[u8]) {
serialize_len(cursor, bytes.len());
let _ = cursor.write_all(bytes);
cursor.write_all(bytes).unwrap();
}

fn serialize_list(cursor: &mut Cursor<&mut Vec<u8>>, values: &[MessageValue]) {
Expand Down

0 comments on commit fbfd265

Please sign in to comment.