Skip to content

Commit

Permalink
Merge pull request #1054 from muzarski/serializable_request_error
Browse files Browse the repository at this point in the history
errors: remove ParseError
  • Loading branch information
Lorak-mmk authored Sep 26, 2024
2 parents ab07be6 + 62c661b commit bcf3a6c
Show file tree
Hide file tree
Showing 22 changed files with 574 additions and 305 deletions.
148 changes: 108 additions & 40 deletions scylla-cql/src/frame/frame_errors.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,129 @@
use std::error::Error;
use std::sync::Arc;

pub use super::request::{
auth_response::AuthResponseSerializationError,
batch::{BatchSerializationError, BatchStatementSerializationError},
execute::ExecuteSerializationError,
prepare::PrepareSerializationError,
query::{QueryParametersSerializationError, QuerySerializationError},
register::RegisterSerializationError,
startup::StartupSerializationError,
};

use super::response::CqlResponseKind;
use super::TryFromPrimitiveError;
use crate::cql_to_rust::CqlTypeError;
use crate::frame::value::SerializeValuesError;
use crate::types::deserialize::{DeserializationError, TypeCheckError};
use crate::types::serialize::SerializationError;
use crate::types::deserialize::DeserializationError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum FrameError {
#[error(transparent)]
Parse(#[from] ParseError),
/// An error returned by `parse_response_body_extensions`.
///
/// It represents an error that occurred during deserialization of
/// frame body extensions. These extensions include tracing id,
/// warnings and custom payload.
///
/// Possible error kinds:
/// - failed to decompress frame body (decompression is required for further deserialization)
/// - failed to deserialize tracing id (body ext.)
/// - failed to deserialize warnings list (body ext.)
/// - failed to deserialize custom payload map (body ext.)
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum FrameBodyExtensionsParseError {
/// Frame is compressed, but no compression was negotiated for the connection.
#[error("Frame is compressed, but no compression negotiated for connection.")]
NoCompressionNegotiated,

/// Failed to deserialize frame trace id.
#[error("Malformed trace id: {0}")]
TraceIdParse(LowLevelDeserializationError),

/// Failed to deserialize warnings attached to frame.
#[error("Malformed warnings list: {0}")]
WarningsListParse(LowLevelDeserializationError),

/// Failed to deserialize frame's custom payload.
#[error("Malformed custom payload map: {0}")]
CustomPayloadMapParse(LowLevelDeserializationError),

/// Failed to decompress frame body (snap).
#[error("Snap decompression error: {0}")]
SnapDecompressError(Arc<dyn Error + Sync + Send>),

/// Failed to decompress frame body (lz4).
#[error("Error decompressing lz4 data {0}")]
Lz4DecompressError(Arc<dyn Error + Sync + Send>),
}

/// An error that occurred during frame header deserialization.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FrameHeaderParseError {
/// Failed to read the frame header from the socket.
#[error("Failed to read the frame header: {0}")]
HeaderIoError(std::io::Error),

/// Received a frame marked as coming from a client.
#[error("Received frame marked as coming from a client")]
FrameFromClient,

// FIXME: this should not belong here. User always expects a frame from server.
// This variant is only used in scylla-proxy - need to investigate it later.
#[error("Received frame marked as coming from the server")]
FrameFromServer,

/// Received a frame with unsupported version.
#[error("Received a frame from version {0}, but only 4 is supported")]
VersionNotSupported(u8),

/// Received unknown response opcode.
#[error("Unrecognized response opcode {0}")]
UnknownResponseOpcode(#[from] TryFromPrimitiveError<u8>),

/// Failed to read frame body from the socket.
#[error("Failed to read a chunk of response body. Expected {0} more bytes, error: {1}")]
BodyChunkIoError(usize, std::io::Error),

/// Connection was closed before whole frame was read.
#[error("Connection was closed before body was read: missing {0} out of {1}")]
ConnectionClosed(usize, usize),
#[error("Frame decompression failed.")]
FrameDecompression,
#[error("Frame compression failed.")]
FrameCompression,
#[error(transparent)]
StdIoError(#[from] std::io::Error),
#[error("Unrecognized opcode{0}")]
TryFromPrimitiveError(#[from] TryFromPrimitiveError<u8>),
#[error("Error compressing lz4 data {0}")]
Lz4CompressError(#[from] lz4_flex::block::CompressError),
#[error("Error decompressing lz4 data {0}")]
Lz4DecompressError(#[from] lz4_flex::block::DecompressError),
}

#[derive(Error, Debug)]
pub enum ParseError {
#[error("Low-level deserialization failed: {0}")]
LowLevelDeserializationError(#[from] LowLevelDeserializationError),
#[error("Could not serialize frame: {0}")]
BadDataToSerialize(String),
#[error("Could not deserialize frame: {0}")]
BadIncomingData(String),
#[error(transparent)]
DeserializationError(#[from] DeserializationError),
#[error(transparent)]
DeserializationTypeCheckError(#[from] TypeCheckError),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
#[error(transparent)]
SerializationError(#[from] SerializationError),
#[error(transparent)]
CqlTypeError(#[from] CqlTypeError),
/// An error that occurred during CQL request serialization.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum CqlRequestSerializationError {
/// Failed to serialize STARTUP request.
#[error("Failed to serialize STARTUP request: {0}")]
StartupSerialization(#[from] StartupSerializationError),

/// Failed to serialize REGISTER request.
#[error("Failed to serialize REGISTER request: {0}")]
RegisterSerialization(#[from] RegisterSerializationError),

/// Failed to serialize AUTH_RESPONSE request.
#[error("Failed to serialize AUTH_RESPONSE request: {0}")]
AuthResponseSerialization(#[from] AuthResponseSerializationError),

/// Failed to serialize BATCH request.
#[error("Failed to serialize BATCH request: {0}")]
BatchSerialization(#[from] BatchSerializationError),

/// Failed to serialize PREPARE request.
#[error("Failed to serialize PREPARE request: {0}")]
PrepareSerialization(#[from] PrepareSerializationError),

/// Failed to serialize EXECUTE request.
#[error("Failed to serialize EXECUTE request: {0}")]
ExecuteSerialization(#[from] ExecuteSerializationError),

/// Failed to serialize QUERY request.
#[error("Failed to serialize QUERY request: {0}")]
QuerySerialization(#[from] QuerySerializationError),

/// Request body compression failed.
#[error("Snap compression error: {0}")]
SnapCompressError(Arc<dyn Error + Sync + Send>),
}

/// An error type returned when deserialization of CQL
Expand Down
51 changes: 33 additions & 18 deletions scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ pub mod value;
#[cfg(test)]
mod value_tests;

use crate::frame::frame_errors::FrameError;
use bytes::{Buf, BufMut, Bytes};
use frame_errors::{
CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use uuid::Uuid;

use std::fmt::Display;
use std::sync::Arc;
use std::{collections::HashMap, convert::TryFrom};

use request::SerializableRequest;
Expand Down Expand Up @@ -72,7 +75,7 @@ impl SerializedRequest {
req: &R,
compression: Option<Compression>,
tracing: bool,
) -> Result<SerializedRequest, FrameError> {
) -> Result<SerializedRequest, CqlRequestSerializationError> {
let mut flags = 0;
let mut data = vec![0; HEADER_SIZE];

Expand Down Expand Up @@ -128,19 +131,22 @@ impl Default for FrameParams {

pub async fn read_response_frame(
reader: &mut (impl AsyncRead + Unpin),
) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameError> {
) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> {
let mut raw_header = [0u8; HEADER_SIZE];
reader.read_exact(&mut raw_header[..]).await?;
reader
.read_exact(&mut raw_header[..])
.await
.map_err(FrameHeaderParseError::HeaderIoError)?;

let mut buf = &raw_header[..];

// TODO: Validate version
let version = buf.get_u8();
if version & 0x80 != 0x80 {
return Err(FrameError::FrameFromClient);
return Err(FrameHeaderParseError::FrameFromClient);
}
if version & 0x7F != 0x04 {
return Err(FrameError::VersionNotSupported(version & 0x7f));
return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f));
}

let flags = buf.get_u8();
Expand All @@ -159,10 +165,12 @@ pub async fn read_response_frame(

let mut raw_body = Vec::with_capacity(length).limit(length);
while raw_body.has_remaining_mut() {
let n = reader.read_buf(&mut raw_body).await?;
let n = reader.read_buf(&mut raw_body).await.map_err(|err| {
FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err)
})?;
if n == 0 {
// EOF, too early
return Err(FrameError::ConnectionClosed(
return Err(FrameHeaderParseError::ConnectionClosed(
raw_body.remaining_mut(),
length,
));
Expand All @@ -183,18 +191,19 @@ pub fn parse_response_body_extensions(
flags: u8,
compression: Option<Compression>,
mut body: Bytes,
) -> Result<ResponseBodyWithExtensions, FrameError> {
) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
if flags & FLAG_COMPRESSION != 0 {
if let Some(compression) = compression {
body = decompress(&body, compression)?.into();
} else {
return Err(FrameError::NoCompressionNegotiated);
return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
}
}

let trace_id = if flags & FLAG_TRACING != 0 {
let buf = &mut &*body;
let trace_id = types::read_uuid(buf).map_err(frame_errors::ParseError::from)?;
let trace_id =
types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
body.advance(16);
Some(trace_id)
} else {
Expand All @@ -204,7 +213,8 @@ pub fn parse_response_body_extensions(
let warnings = if flags & FLAG_WARNING != 0 {
let body_len = body.len();
let buf = &mut &*body;
let warnings = types::read_string_list(buf).map_err(frame_errors::ParseError::from)?;
let warnings = types::read_string_list(buf)
.map_err(FrameBodyExtensionsParseError::WarningsListParse)?;
let buf_len = buf.len();
body.advance(body_len - buf_len);
warnings
Expand All @@ -215,7 +225,8 @@ pub fn parse_response_body_extensions(
let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 {
let body_len = body.len();
let buf = &mut &*body;
let payload_map = types::read_bytes_map(buf).map_err(frame_errors::ParseError::from)?;
let payload_map = types::read_bytes_map(buf)
.map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?;
let buf_len = buf.len();
body.advance(body_len - buf_len);
Some(payload_map)
Expand All @@ -235,7 +246,7 @@ fn compress_append(
uncomp_body: &[u8],
compression: Compression,
out: &mut Vec<u8>,
) -> Result<(), FrameError> {
) -> Result<(), CqlRequestSerializationError> {
match compression {
Compression::Lz4 => {
let uncomp_len = uncomp_body.len() as u32;
Expand All @@ -250,23 +261,27 @@ fn compress_append(
out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0);
let compressed_size = snap::raw::Encoder::new()
.compress(uncomp_body, &mut out[old_size..])
.map_err(|_| FrameError::FrameCompression)?;
.map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?;
out.truncate(old_size + compressed_size);
Ok(())
}
}
}

fn decompress(mut comp_body: &[u8], compression: Compression) -> Result<Vec<u8>, FrameError> {
fn decompress(
mut comp_body: &[u8],
compression: Compression,
) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
match compression {
Compression::Lz4 => {
let uncomp_len = comp_body.get_u32() as usize;
let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)?;
let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)
.map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?;
Ok(uncomp_body)
}
Compression::Snappy => snap::raw::Decoder::new()
.decompress_vec(comp_body)
.map_err(|_| FrameError::FrameDecompression),
.map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))),
}
}

Expand Down
20 changes: 17 additions & 3 deletions scylla-cql/src/frame/request/auth_response.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::frame::frame_errors::ParseError;
use std::num::TryFromIntError;

use thiserror::Error;

use crate::frame::frame_errors::CqlRequestSerializationError;

use crate::frame::request::{RequestOpcode, SerializableRequest};
use crate::frame::types::write_bytes_opt;
Expand All @@ -11,7 +15,17 @@ pub struct AuthResponse {
impl SerializableRequest for AuthResponse {
const OPCODE: RequestOpcode = RequestOpcode::AuthResponse;

fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), ParseError> {
Ok(write_bytes_opt(self.response.as_ref(), buf)?)
fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
Ok(write_bytes_opt(self.response.as_ref(), buf)
.map_err(AuthResponseSerializationError::ResponseSerialization)?)
}
}

/// An error type returned when serialization of AUTH_RESPONSE request fails.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum AuthResponseSerializationError {
/// Maximum response's body length exceeded.
#[error("AUTH_RESPONSE body bytes length too big: {0}")]
ResponseSerialization(TryFromIntError),
}
Loading

0 comments on commit bcf3a6c

Please sign in to comment.