diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index e17055b48..d9f58cb8d 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,4 +1,4 @@ -use crate::message::{MessageValue, QueryType}; +use crate::message::{serialize_len, MessageValue, QueryType}; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; @@ -14,8 +14,8 @@ use cassandra_protocol::frame::message_query::BodyReqQuery; use cassandra_protocol::frame::message_request::RequestBody; use cassandra_protocol::frame::message_response::ResponseBody; use cassandra_protocol::frame::message_result::{ - BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ResResultBody, - RowsMetadata, RowsMetadataFlags, + BodyResResultPrepared, BodyResResultSetKeyspace, ResResultBody, ResultKind, RowsMetadata, + RowsMetadataFlags, }; use cassandra_protocol::frame::{ Direction, Envelope as RawCassandraFrame, Flags, Opcode, Serialize, StreamId, Version, @@ -23,12 +23,13 @@ use cassandra_protocol::frame::{ use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; -use cassandra_protocol::types::{CBytesShort, CInt, CLong}; +use cassandra_protocol::types::{CBytesShort, CLong}; use cql3_parser::begin_batch::{BatchType as ParserBatchType, BeginBatch}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use nonzero_ext::nonzero; +use std::io::Cursor; use std::net::IpAddr; use std::num::NonZeroU32; use std::str::FromStr; @@ -423,18 +424,32 @@ impl CassandraOperation { .serialize_to_vec(version), CassandraOperation::Result(result) => match result { CassandraResult::Rows { rows, metadata } => { - Self::build_cassandra_result_body(version, rows, *metadata) + let mut buf = vec![]; + let mut cursor = Cursor::new(&mut buf); + + ResultKind::Rows.serialize(&mut cursor, version); + + metadata.serialize(&mut cursor, version); + serialize_len(&mut cursor, rows.len()); + for row in rows { + for col in row { + col.cassandra_serialize(&mut cursor); + } + } + + buf } CassandraResult::SetKeyspace(set_keyspace) => { - ResResultBody::SetKeyspace(*set_keyspace) + ResResultBody::SetKeyspace(*set_keyspace).serialize_to_vec(version) + } + CassandraResult::Prepared(prepared) => { + ResResultBody::Prepared(*prepared).serialize_to_vec(version) } - CassandraResult::Prepared(prepared) => ResResultBody::Prepared(*prepared), CassandraResult::SchemaChange(schema_change) => { - ResResultBody::SchemaChange(*schema_change) + ResResultBody::SchemaChange(*schema_change).serialize_to_vec(version) } - CassandraResult::Void => ResResultBody::Void, - } - .serialize_to_vec(version), + CassandraResult::Void => ResResultBody::Void.serialize_to_vec(version), + }, CassandraOperation::Error(error) => error.serialize_to_vec(version), CassandraOperation::Startup(bytes) => bytes.to_vec(), CassandraOperation::Ready(bytes) => bytes.to_vec(), @@ -472,25 +487,6 @@ impl CassandraOperation { CassandraOperation::AuthSuccess(bytes) => bytes.to_vec(), } } - - fn build_cassandra_result_body( - protocol_version: Version, - rows: Vec>, - metadata: RowsMetadata, - ) -> ResResultBody { - let rows_count = rows.len() as CInt; - let rows_content = rows - .into_iter() - .map(|row| row.into_iter().map(MessageValue::into_cbytes).collect()) - .collect(); - - ResResultBody::Rows(BodyResResultRows { - protocol_version, - metadata, - rows_count, - rows_content, - }) - } } fn get_query_type(statement: &CassandraStatement) -> QueryType { diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 28bf222c4..f73e21c8a 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -621,30 +621,7 @@ impl MessageValue { } } - pub fn into_cbytes(value: MessageValue) -> CBytes { - // cassandra-protocol handles null values incredibly poorly. - // so we need to rewrite their logic to operate at the CBytes level in order to have the null value expressible - // - // Additionally reimplementing this logic allows us to allocate a lot less - // and its also way easier to understand the whole stack than the `.into()` based API. - // - // TODO: This should be upstreamable but will require rewriting their entire CBytes/Bytes/Value API - // and so will take a long time to both write and review - - // CBytes API expects the length to be implied and the null value encoded - let mut bytes = vec![]; - value.cassandra_serialize(&mut Cursor::new(&mut bytes)); - if i32::from_be_bytes(bytes[0..4].try_into().unwrap()) < 0 { - // Despite the name of the function this actually creates a cassandra NULL value instead of a cassandra empty value - CBytes::new_empty() - } else { - // strip the length - bytes.drain(0..4); - CBytes::new(bytes) - } - } - - fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec>) { + pub fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec>) { match self { MessageValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(), MessageValue::Bytes(b) => serialize_bytes(cursor, b), @@ -721,14 +698,14 @@ fn serialize_with_length_prefix( .copy_from_slice(&(bytes_len as CInt).to_be_bytes()); } -fn serialize_len(cursor: &mut Cursor<&mut Vec>, len: usize) { +pub fn serialize_len(cursor: &mut Cursor<&mut Vec>, len: usize) { let len = len as CInt; - let _ = cursor.write_all(&len.to_be_bytes()); + cursor.write_all(&len.to_be_bytes()).unwrap(); } fn serialize_bytes(cursor: &mut Cursor<&mut Vec>, bytes: &[u8]) { serialize_len(cursor, bytes.len()); - let _ = cursor.write_all(bytes); + cursor.write_all(bytes).unwrap(); } fn serialize_list(cursor: &mut Cursor<&mut Vec>, values: &[MessageValue]) {