From 4c013e87b01c02fb5cd6f2bccd467fb70e674900 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Fri, 6 Oct 2023 22:09:21 +0100 Subject: [PATCH 1/7] changed the type of the maximum number of statements in a batch query from an i16 to a u16 according to the CQL protocol spec --- scylla-cql/src/frame/request/batch.rs | 4 ++-- scylla-cql/src/frame/types.rs | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 3c0bad3931..92b8b61ec4 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -81,7 +81,7 @@ where buf.put_u8(self.batch_type as u8); // Serializing queries - types::write_short(self.statements.len().try_into()?, buf); + types::write_u16(self.statements.len().try_into()?, buf); let counts_mismatch_err = |n_values: usize, n_statements: usize| { ParseError::BadDataToSerialize(format!( @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Result { let batch_type = buf.get_u8().try_into()?; - let statements_count: usize = types::read_short(buf)?.try_into()?; + let statements_count: usize = types::read_u16(buf)?.try_into()?; let statements_with_values = (0..statements_count) .map(|_| { let batch_statement = BatchStatement::deserialize(buf)?; diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index fd2254c8b0..6af08e217f 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -5,7 +5,7 @@ use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; use num_enum::TryFromPrimitive; use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{Infallible, TryFrom}; use std::convert::TryInto; use std::net::IpAddr; use std::net::SocketAddr; @@ -98,6 +98,12 @@ impl From for ParseError { } } +impl From for ParseError { + fn from(_: Infallible) -> Self { + ParseError::BadIncomingData("Unexpected Infallible Error".to_string()) + } +} + impl From for ParseError { fn from(_err: std::array::TryFromSliceError) -> Self { ParseError::BadIncomingData("array try from slice failed".to_string()) @@ -174,10 +180,19 @@ pub fn read_short(buf: &mut &[u8]) -> Result { Ok(v) } +pub fn read_u16(buf: &mut &[u8]) -> Result { + let v = buf.read_u16::()?; + Ok(v) +} + pub fn write_short(v: i16, buf: &mut impl BufMut) { buf.put_i16(v); } +pub fn write_u16(v: u16, buf: &mut impl BufMut) { + buf.put_u16(v); +} + pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { let v = read_short(buf)?; let v: usize = v.try_into()?; @@ -200,6 +215,15 @@ fn type_short() { } } +#[test] +fn type_u16() { + let vals = [0, 1, u16::MAX]; + for val in vals.iter() { + let mut buf = Vec::new(); + write_u16(*val, &mut buf); + assert_eq!(read_u16(&mut &buf[..]).unwrap(), *val); + } +} // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { let len = read_int(buf)?; From 772c46f1714b97eb75f79c1ac479d3d5cac015b4 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Fri, 6 Oct 2023 23:09:37 +0100 Subject: [PATCH 2/7] add a guard when doing batch statements to prevent making calls to the server when the number of batch queries is greater than u16::MAX, as well as adding some tests --- scylla-cql/src/errors.rs | 4 + scylla-cql/src/frame/types.rs | 2 +- scylla/Cargo.toml | 1 + .../transport/large_batch_statements_test.rs | 106 ++++++++++++++++++ scylla/src/transport/mod.rs | 2 + scylla/src/transport/session.rs | 7 ++ 6 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 scylla/src/transport/large_batch_statements_test.rs diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index 40587cfef6..c7cc85d233 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -348,6 +348,10 @@ pub enum BadQuery { #[error("Passed invalid keyspace name to use: {0}")] BadKeyspaceName(#[from] BadKeyspaceName), + /// Too many queries in the batch statement + #[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")] + TooManyQueriesInBatchStatement, + /// Other reasons of bad query #[error("{0}")] Other(String), diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 6af08e217f..fa964e9478 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -5,8 +5,8 @@ use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; use num_enum::TryFromPrimitive; use std::collections::HashMap; -use std::convert::{Infallible, TryFrom}; use std::convert::TryInto; +use std::convert::{Infallible, TryFrom}; use std::net::IpAddr; use std::net::SocketAddr; use std::str; diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index 2460e020b9..3408d10330 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -60,6 +60,7 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] } assert_matches = "1.5.0" rand_chacha = "0.3.1" +bcs = "0.1.5" [[bench]] name = "benchmark" diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs new file mode 100644 index 0000000000..6195de30df --- /dev/null +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -0,0 +1,106 @@ +use bcs::serialize_into; +use scylla_cql::errors::{BadQuery, QueryError}; + +use crate::batch::BatchType; +use crate::query::Query; +use crate::{ + batch::Batch, + prepared_statement::PreparedStatement, + test_utils::{create_new_session_builder, unique_keyspace_name}, + IntoTypedRows, QueryResult, Session, +}; + +#[tokio::test] +async fn test_large_batch_statements() { + let mut session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + session = create_test_session(session, &ks).await; + + let max_number_of_queries = u16::MAX as usize; + write_batch(&session, max_number_of_queries).await; + + let key_prefix = vec![0]; + let keys = find_keys_by_prefix(&session, key_prefix.clone()).await; + assert_eq!(keys.len(), max_number_of_queries); + + let too_many_queries = u16::MAX as usize + 1; + + let err = write_batch(&session, too_many_queries).await; + + assert!(err.is_err()); +} + +async fn create_test_session(session: Session, ks: &String) -> Session { + session + .query( + format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks), + &[], + ) + .await.unwrap(); + session + .query("DROP TABLE IF EXISTS kv.pairs;", &[]) + .await + .unwrap(); + session + .query( + "CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))", + &[], + ) + .await.unwrap(); + session +} + +async fn write_batch(session: &Session, n: usize) -> Result { + let mut batch_query = Batch::new(BatchType::Logged); + let mut batch_values = Vec::new(); + for i in 0..n { + let mut key = vec![0]; + serialize_into(&mut key, &(i as usize)).unwrap(); + let value = key.clone(); + let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)"; + let values = vec![key, value]; + batch_values.push(values); + let query = Query::new(query); + batch_query.append_statement(query); + } + session.batch(&batch_query, batch_values).await +} + +async fn find_keys_by_prefix(session: &Session, key_prefix: Vec) -> Vec> { + let len = key_prefix.len(); + let rows = match get_upper_bound_option(&key_prefix) { + None => { + let values = (key_prefix,); + let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING"; + session.query(query, values).await.unwrap() + } + Some(upper_bound) => { + let values = (key_prefix, upper_bound); + let query = + "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING"; + session.query(query, values).await.unwrap() + } + }; + let mut keys = Vec::new(); + if let Some(rows) = rows.rows { + for row in rows.into_typed::<(Vec,)>() { + let key = row.unwrap(); + let short_key = key.0[len..].to_vec(); + keys.push(short_key); + } + } + keys +} + +fn get_upper_bound_option(key_prefix: &[u8]) -> Option> { + let len = key_prefix.len(); + for i in (0..len).rev() { + let val = key_prefix[i]; + if val < u8::MAX { + let mut upper_bound = key_prefix[0..i + 1].to_vec(); + upper_bound[i] += 1; + return Some(upper_bound); + } + } + None +} diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 939983cfc4..a33943645d 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -35,6 +35,8 @@ mod silent_prepare_batch_test; mod cql_types_test; #[cfg(test)] mod cql_value_test; +#[cfg(test)] +mod large_batch_statements_test; pub use cluster::ClusterData; pub use node::{KnownNode, Node, NodeAddr, NodeRef}; diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 35ff25475f..f92067363d 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize; use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; +use scylla_cql::errors::BadQuery; /// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses. /// @@ -1143,6 +1144,12 @@ impl Session { // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard // If users batch statements by shard, they will be rewarded with full shard awareness + // check to ensure that we don't send a batch statement with more than u16::MAX queries + if batch.statements.len() > u16::MAX as usize { + return Err(QueryError::BadQuery( + BadQuery::TooManyQueriesInBatchStatement, + )); + } // Extract first serialized_value let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?; let first_serialized_value = first_serialized_value.as_deref(); From efb96fcb26d00f4cc3de7a8a0e001a28c1d23669 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Sat, 7 Oct 2023 12:34:59 +0100 Subject: [PATCH 3/7] review feedback: Change short to u16 as well as modify tests to be simpler --- scylla-cql/src/errors.rs | 4 +- scylla-cql/src/frame/frame_errors.rs | 2 +- scylla-cql/src/frame/request/batch.rs | 4 +- scylla-cql/src/frame/types.rs | 29 ++----- scylla-cql/src/frame/value.rs | 10 +-- scylla/Cargo.toml | 1 - scylla/src/statement/prepared_statement.rs | 2 +- .../transport/large_batch_statements_test.rs | 81 ++++++------------- scylla/src/transport/session.rs | 5 +- 9 files changed, 45 insertions(+), 93 deletions(-) diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index c7cc85d233..9e80247e20 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -349,8 +349,8 @@ pub enum BadQuery { BadKeyspaceName(#[from] BadKeyspaceName), /// Too many queries in the batch statement - #[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")] - TooManyQueriesInBatchStatement, + #[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")] + TooManyQueriesInBatchStatement(usize), /// Other reasons of bad query #[error("{0}")] diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 403b6ab5fd..3da4e26d01 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -40,7 +40,7 @@ pub enum ParseError { #[error(transparent)] IoError(#[from] std::io::Error), #[error("type not yet implemented, id: {0}")] - TypeNotImplemented(i16), + TypeNotImplemented(u16), #[error(transparent)] SerializeValuesError(#[from] SerializeValuesError), #[error(transparent)] diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 92b8b61ec4..3c0bad3931 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -81,7 +81,7 @@ where buf.put_u8(self.batch_type as u8); // Serializing queries - types::write_u16(self.statements.len().try_into()?, buf); + types::write_short(self.statements.len().try_into()?, buf); let counts_mismatch_err = |n_values: usize, n_statements: usize| { ParseError::BadDataToSerialize(format!( @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Result { let batch_type = buf.get_u8().try_into()?; - let statements_count: usize = types::read_u16(buf)?.try_into()?; + let statements_count: usize = types::read_short(buf)?.try_into()?; let statements_with_values = (0..statements_count) .map(|_| { let batch_statement = BatchStatement::deserialize(buf)?; diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index fa964e9478..1c004e07cf 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -16,7 +16,7 @@ use uuid::Uuid; #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] #[cfg_attr(feature = "serde", derive(serde::Deserialize))] #[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))] -#[repr(i16)] +#[repr(u16)] pub enum Consistency { Any = 0x0000, One = 0x0001, @@ -175,8 +175,8 @@ fn type_long() { } } -pub fn read_short(buf: &mut &[u8]) -> Result { - let v = buf.read_i16::()?; +pub fn read_short(buf: &mut &[u8]) -> Result { + let v = buf.read_u16::()?; Ok(v) } @@ -185,11 +185,7 @@ pub fn read_u16(buf: &mut &[u8]) -> Result { Ok(v) } -pub fn write_short(v: i16, buf: &mut impl BufMut) { - buf.put_i16(v); -} - -pub fn write_u16(v: u16, buf: &mut impl BufMut) { +pub fn write_short(v: u16, buf: &mut impl BufMut) { buf.put_u16(v); } @@ -200,14 +196,14 @@ pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { } fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> { - let v: i16 = v.try_into()?; + let v: u16 = v.try_into()?; write_short(v, buf); Ok(()) } #[test] fn type_short() { - let vals = [i16::MIN, -1, 0, 1, i16::MAX]; + let vals: [u16; 3] = [0, 1, u16::MAX]; for val in vals.iter() { let mut buf = Vec::new(); write_short(*val, &mut buf); @@ -215,15 +211,6 @@ fn type_short() { } } -#[test] -fn type_u16() { - let vals = [0, 1, u16::MAX]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_u16(*val, &mut buf); - assert_eq!(read_u16(&mut &buf[..]).unwrap(), *val); - } -} // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { let len = read_int(buf)?; @@ -488,11 +475,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result { } pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); + write_short(c as u16, buf); } pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); + write_short(c as u16, buf); } #[test] diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index e9164f2531..617dce4820 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -63,7 +63,7 @@ pub struct Time(pub Duration); #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct SerializedValues { serialized_values: Vec, - values_num: i16, + values_num: u16, contains_names: bool, } @@ -134,7 +134,7 @@ impl SerializedValues { if self.contains_names { return Err(SerializeValuesError::MixingNamedAndNotNamedValues); } - if self.values_num == i16::MAX { + if self.values_num == u16::MAX { return Err(SerializeValuesError::TooManyValues); } @@ -158,7 +158,7 @@ impl SerializedValues { return Err(SerializeValuesError::MixingNamedAndNotNamedValues); } self.contains_names = true; - if self.values_num == i16::MAX { + if self.values_num == u16::MAX { return Err(SerializeValuesError::TooManyValues); } @@ -184,7 +184,7 @@ impl SerializedValues { } pub fn write_to_request(&self, buf: &mut impl BufMut) { - buf.put_i16(self.values_num); + buf.put_u16(self.values_num); buf.put(&self.serialized_values[..]); } @@ -192,7 +192,7 @@ impl SerializedValues { self.values_num == 0 } - pub fn len(&self) -> i16 { + pub fn len(&self) -> u16 { self.values_num } diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index 3408d10330..2460e020b9 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -60,7 +60,6 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] } assert_matches = "1.5.0" rand_chacha = "0.3.1" -bcs = "0.1.5" [[bench]] name = "benchmark" diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index b57d5d4b23..9814e7350d 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -339,7 +339,7 @@ impl PreparedStatement { #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)] pub enum PartitionKeyExtractionError { #[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")] - NoPkIndexValue(u16, i16), + NoPkIndexValue(u16, u16), } #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)] diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs index 6195de30df..b2a5e8eaf0 100644 --- a/scylla/src/transport/large_batch_statements_test.rs +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -1,14 +1,12 @@ -use bcs::serialize_into; -use scylla_cql::errors::{BadQuery, QueryError}; - use crate::batch::BatchType; use crate::query::Query; use crate::{ batch::Batch, - prepared_statement::PreparedStatement, test_utils::{create_new_session_builder, unique_keyspace_name}, - IntoTypedRows, QueryResult, Session, + QueryResult, Session, }; +use assert_matches::assert_matches; +use scylla_cql::errors::{BadQuery, DbError, QueryError}; #[tokio::test] async fn test_large_batch_statements() { @@ -16,48 +14,54 @@ async fn test_large_batch_statements() { let ks = unique_keyspace_name(); session = create_test_session(session, &ks).await; + // Add batch let max_number_of_queries = u16::MAX as usize; - write_batch(&session, max_number_of_queries).await; + let batch_result = write_batch(&session, max_number_of_queries, &ks).await; - let key_prefix = vec![0]; - let keys = find_keys_by_prefix(&session, key_prefix.clone()).await; - assert_eq!(keys.len(), max_number_of_queries); + if batch_result.is_err() { + assert_matches!( + batch_result.unwrap_err(), + QueryError::DbError(DbError::WriteTimeout { .. }, _) + ) + } + // Now try with too many queries let too_many_queries = u16::MAX as usize + 1; - - let err = write_batch(&session, too_many_queries).await; - - assert!(err.is_err()); + let batch_insert_result = write_batch(&session, too_many_queries, &ks).await; + assert_matches!( + batch_insert_result.unwrap_err(), + QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(_too_many_queries)) if _too_many_queries == too_many_queries + ) } async fn create_test_session(session: Session, ks: &String) -> Session { session .query( - format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks), + format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks), &[], ) .await.unwrap(); session - .query("DROP TABLE IF EXISTS kv.pairs;", &[]) + .query(format!("DROP TABLE IF EXISTS {}.pairs;", ks), &[]) .await .unwrap(); session .query( - "CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))", + format!("CREATE TABLE IF NOT EXISTS {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", ks), &[], ) .await.unwrap(); session } -async fn write_batch(session: &Session, n: usize) -> Result { +async fn write_batch(session: &Session, n: usize, ks: &String) -> Result { let mut batch_query = Batch::new(BatchType::Logged); let mut batch_values = Vec::new(); for i in 0..n { let mut key = vec![0]; - serialize_into(&mut key, &(i as usize)).unwrap(); + key.extend(i.to_be_bytes().as_slice()); let value = key.clone(); - let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)"; + let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks); let values = vec![key, value]; batch_values.push(values); let query = Query::new(query); @@ -65,42 +69,3 @@ async fn write_batch(session: &Session, n: usize) -> Result) -> Vec> { - let len = key_prefix.len(); - let rows = match get_upper_bound_option(&key_prefix) { - None => { - let values = (key_prefix,); - let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING"; - session.query(query, values).await.unwrap() - } - Some(upper_bound) => { - let values = (key_prefix, upper_bound); - let query = - "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING"; - session.query(query, values).await.unwrap() - } - }; - let mut keys = Vec::new(); - if let Some(rows) = rows.rows { - for row in rows.into_typed::<(Vec,)>() { - let key = row.unwrap(); - let short_key = key.0[len..].to_vec(); - keys.push(short_key); - } - } - keys -} - -fn get_upper_bound_option(key_prefix: &[u8]) -> Option> { - let len = key_prefix.len(); - for i in (0..len).rev() { - let val = key_prefix[i]; - if val < u8::MAX { - let mut upper_bound = key_prefix[0..i + 1].to_vec(); - upper_bound[i] += 1; - return Some(upper_bound); - } - } - None -} diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index f92067363d..2f67874f8c 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -1145,9 +1145,10 @@ impl Session { // If users batch statements by shard, they will be rewarded with full shard awareness // check to ensure that we don't send a batch statement with more than u16::MAX queries - if batch.statements.len() > u16::MAX as usize { + let batch_statements_length = batch.statements.len(); + if batch_statements_length > u16::MAX as usize { return Err(QueryError::BadQuery( - BadQuery::TooManyQueriesInBatchStatement, + BadQuery::TooManyQueriesInBatchStatement(batch_statements_length), )); } // Extract first serialized_value From ec5cf68e74b5d43e7d11d629b6fcfe53767ea703 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Tue, 10 Oct 2023 19:24:20 +0100 Subject: [PATCH 4/7] review feedback: add request timeout to execution profile to enable that writes complete during cassandra tests --- .../transport/large_batch_statements_test.rs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs index b2a5e8eaf0..7d08ea78f2 100644 --- a/scylla/src/transport/large_batch_statements_test.rs +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -1,3 +1,7 @@ +use assert_matches::assert_matches; + +use scylla_cql::errors::{BadQuery, QueryError}; + use crate::batch::BatchType; use crate::query::Query; use crate::{ @@ -5,27 +9,14 @@ use crate::{ test_utils::{create_new_session_builder, unique_keyspace_name}, QueryResult, Session, }; -use assert_matches::assert_matches; -use scylla_cql::errors::{BadQuery, DbError, QueryError}; #[tokio::test] async fn test_large_batch_statements() { let mut session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); session = create_test_session(session, &ks).await; - // Add batch - let max_number_of_queries = u16::MAX as usize; - let batch_result = write_batch(&session, max_number_of_queries, &ks).await; - - if batch_result.is_err() { - assert_matches!( - batch_result.unwrap_err(), - QueryError::DbError(DbError::WriteTimeout { .. }, _) - ) - } - - // Now try with too many queries let too_many_queries = u16::MAX as usize + 1; let batch_insert_result = write_batch(&session, too_many_queries, &ks).await; assert_matches!( From a23e6ebb093d8363c6feb7f1c23ca9ae1f28d8dd Mon Sep 17 00:00:00 2001 From: samuel orji Date: Fri, 13 Oct 2023 13:47:42 +0100 Subject: [PATCH 5/7] review feedback: disable expensive batch write test for cassandra --- .github/workflows/cassandra.yml | 2 +- scylla/src/transport/large_batch_statements_test.rs | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cassandra.yml b/.github/workflows/cassandra.yml index 5cc118f1f4..e22b915d46 100644 --- a/.github/workflows/cassandra.yml +++ b/.github/workflows/cassandra.yml @@ -29,7 +29,7 @@ jobs: run: cargo build --verbose --tests - name: Run tests on cassandra run: | - CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info + CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info --skip test_large_batch_statements - name: Stop the cluster if: ${{ always() }} run: docker compose -f test/cluster/cassandra/docker-compose.yml stop diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs index 7d08ea78f2..252f1d8f4a 100644 --- a/scylla/src/transport/large_batch_statements_test.rs +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -17,6 +17,11 @@ async fn test_large_batch_statements() { let ks = unique_keyspace_name(); session = create_test_session(session, &ks).await; + let max_queries = u16::MAX as usize; + let batch_insert_result = write_batch(&session, max_queries, &ks).await; + + assert!(batch_insert_result.is_ok()); + let too_many_queries = u16::MAX as usize + 1; let batch_insert_result = write_batch(&session, too_many_queries, &ks).await; assert_matches!( From b35346bec58cdc780a5f992d0150ffb829ffa291 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Fri, 13 Oct 2023 18:13:16 +0100 Subject: [PATCH 6/7] review feedback --- scylla-cql/src/frame/request/batch.rs | 2 +- scylla-cql/src/frame/response/result.rs | 4 ++-- scylla-cql/src/frame/types.rs | 15 ++------------- scylla-cql/src/frame/value.rs | 2 +- .../src/transport/large_batch_statements_test.rs | 4 ++-- 5 files changed, 8 insertions(+), 19 deletions(-) diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 3c0bad3931..35dd8c3c3b 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Result { let batch_type = buf.get_u8().try_into()?; - let statements_count: usize = types::read_short(buf)?.try_into()?; + let statements_count: usize = types::read_short(buf)?.into(); let statements_with_values = (0..statements_count) .map(|_| { let batch_statement = BatchStatement::deserialize(buf)?; diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 288baf91eb..5ade677343 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -437,7 +437,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { 0x0030 => { let keyspace_name: String = types::read_string(buf)?.to_string(); let type_name: String = types::read_string(buf)?.to_string(); - let fields_size: usize = types::read_short(buf)?.try_into()?; + let fields_size: usize = types::read_short(buf)?.into(); let mut field_types: Vec<(String, ColumnType)> = Vec::with_capacity(fields_size); @@ -455,7 +455,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { } } 0x0031 => { - let len: usize = types::read_short(buf)?.try_into()?; + let len: usize = types::read_short(buf)?.into(); let mut types = Vec::with_capacity(len); for _ in 0..len { types.push(deser_type(buf)?); diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 1c004e07cf..672fe2f97e 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -5,8 +5,8 @@ use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; use num_enum::TryFromPrimitive; use std::collections::HashMap; +use std::convert::TryFrom; use std::convert::TryInto; -use std::convert::{Infallible, TryFrom}; use std::net::IpAddr; use std::net::SocketAddr; use std::str; @@ -98,12 +98,6 @@ impl From for ParseError { } } -impl From for ParseError { - fn from(_: Infallible) -> Self { - ParseError::BadIncomingData("Unexpected Infallible Error".to_string()) - } -} - impl From for ParseError { fn from(_err: std::array::TryFromSliceError) -> Self { ParseError::BadIncomingData("array try from slice failed".to_string()) @@ -180,18 +174,13 @@ pub fn read_short(buf: &mut &[u8]) -> Result { Ok(v) } -pub fn read_u16(buf: &mut &[u8]) -> Result { - let v = buf.read_u16::()?; - Ok(v) -} - pub fn write_short(v: u16, buf: &mut impl BufMut) { buf.put_u16(v); } pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { let v = read_short(buf)?; - let v: usize = v.try_into()?; + let v: usize = v.into(); Ok(v) } diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index 617dce4820..17b75ea855 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -77,7 +77,7 @@ pub struct CqlDuration { #[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum SerializeValuesError { - #[error("Too many values to add, max 32 767 values can be sent in a request")] + #[error("Too many values to add, max 65,535 values can be sent in a request")] TooManyValues, #[error("Mixing named and not named values is not allowed")] MixingNamedAndNotNamedValues, diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs index 252f1d8f4a..0bbd06dfc9 100644 --- a/scylla/src/transport/large_batch_statements_test.rs +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -20,7 +20,7 @@ async fn test_large_batch_statements() { let max_queries = u16::MAX as usize; let batch_insert_result = write_batch(&session, max_queries, &ks).await; - assert!(batch_insert_result.is_ok()); + batch_insert_result.unwrap(); let too_many_queries = u16::MAX as usize + 1; let batch_insert_result = write_batch(&session, too_many_queries, &ks).await; @@ -51,7 +51,7 @@ async fn create_test_session(session: Session, ks: &String) -> Session { } async fn write_batch(session: &Session, n: usize, ks: &String) -> Result { - let mut batch_query = Batch::new(BatchType::Logged); + let mut batch_query = Batch::new(BatchType::Unlogged); let mut batch_values = Vec::new(); for i in 0..n { let mut key = vec![0]; From 6ecaa03d5aa486c07d55fff900fcc12992c82cd0 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Tue, 17 Oct 2023 13:52:27 +0100 Subject: [PATCH 7/7] review feedback, use prepared statements --- .../transport/large_batch_statements_test.rs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs index 0bbd06dfc9..29482e31ce 100644 --- a/scylla/src/transport/large_batch_statements_test.rs +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -33,35 +33,36 @@ async fn test_large_batch_statements() { async fn create_test_session(session: Session, ks: &String) -> Session { session .query( - format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks), + format!("CREATE KEYSPACE {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks), &[], ) .await.unwrap(); - session - .query(format!("DROP TABLE IF EXISTS {}.pairs;", ks), &[]) - .await - .unwrap(); session .query( - format!("CREATE TABLE IF NOT EXISTS {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", ks), + format!( + "CREATE TABLE {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", + ks + ), &[], ) - .await.unwrap(); + .await + .unwrap(); session } async fn write_batch(session: &Session, n: usize, ks: &String) -> Result { let mut batch_query = Batch::new(BatchType::Unlogged); let mut batch_values = Vec::new(); + let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks); + let query = Query::new(query); + let prepared_statement = session.prepare(query).await.unwrap(); for i in 0..n { let mut key = vec![0]; key.extend(i.to_be_bytes().as_slice()); let value = key.clone(); - let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks); let values = vec![key, value]; batch_values.push(values); - let query = Query::new(query); - batch_query.append_statement(query); + batch_query.append_statement(prepared_statement.clone()); } session.batch(&batch_query, batch_values).await }