From 329707ddf3f360941adb3354d823ec29c43ac16a Mon Sep 17 00:00:00 2001 From: Conor Brosnan Date: Tue, 8 Nov 2022 03:20:32 +1000 Subject: [PATCH 1/2] scylla driver tests --- .../tests/cassandra_int_tests/mod.rs | 25 +++---- shotover-proxy/tests/helpers/cassandra.rs | 67 ++++++++++++++++--- 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/shotover-proxy/tests/cassandra_int_tests/mod.rs b/shotover-proxy/tests/cassandra_int_tests/mod.rs index 28ee81b85..93ca82d9e 100644 --- a/shotover-proxy/tests/cassandra_int_tests/mod.rs +++ b/shotover-proxy/tests/cassandra_int_tests/mod.rs @@ -1,8 +1,9 @@ +use crate::helpers::cassandra::{assert_query_result, ResultValue}; #[cfg(feature = "cassandra-cpp-driver-tests")] +use crate::helpers::cassandra::{run_query, CassandraDriver::Datastax}; use crate::helpers::cassandra::{ - assert_query_result, run_query, CassandraDriver::Datastax, ResultValue, + CassandraConnection, CassandraDriver, CassandraDriver::CdrsTokio, CassandraDriver::Scylla, }; -use crate::helpers::cassandra::{CassandraConnection, CassandraDriver, CassandraDriver::CdrsTokio}; use crate::helpers::ShotoverManager; #[cfg(feature = "cassandra-cpp-driver-tests")] use cassandra_cpp::{Error, ErrorKind}; @@ -22,17 +23,14 @@ use tokio::time::{sleep, timeout, Duration}; mod batch_statements; mod cache; -#[cfg(feature = "cassandra-cpp-driver-tests")] mod cluster; mod collections; mod functions; mod keyspace; mod native_types; mod prepared_statements; -#[cfg(feature = "cassandra-cpp-driver-tests")] #[cfg(feature = "alpha-transforms")] mod protect; -#[cfg(feature = "cassandra-cpp-driver-tests")] mod routing; mod table; mod udt; @@ -57,6 +55,7 @@ where #[rstest] #[case::cdrs(CdrsTokio)] #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn passthrough_standard(#[case] driver: CassandraDriver) { @@ -74,6 +73,7 @@ async fn passthrough_standard(#[case] driver: CassandraDriver) { #[rstest] #[case::cdrs(CdrsTokio)] #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn passthrough_encode(#[case] driver: CassandraDriver) { @@ -120,10 +120,10 @@ async fn source_tls_and_single_tls(#[case] driver: CassandraDriver) { standard_test_suite(&connection, driver).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] //#[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn cluster_single_rack_v3(#[case] driver: CassandraDriver) { @@ -153,10 +153,10 @@ async fn cluster_single_rack_v3(#[case] driver: CassandraDriver) { cluster::single_rack_v3::test_topology_task(None).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] //#[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn cluster_single_rack_v4(#[case] driver: CassandraDriver) { @@ -205,7 +205,7 @@ async fn cluster_single_rack_v4(#[case] driver: CassandraDriver) { #[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] //#[case::cdrs(CdrsTokio)] -#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] async fn cluster_single_rack_node_lost(#[case] driver: CassandraDriver) { @@ -217,10 +217,10 @@ async fn cluster_single_rack_node_lost(#[case] driver: CassandraDriver) { cluster::single_rack_v4::test_node_going_down(compose, shotover_manager, driver, true).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] //#[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn cluster_multi_rack(#[case] driver: CassandraDriver) { @@ -302,6 +302,7 @@ async fn source_tls_and_cluster_tls(#[case] driver: CassandraDriver) { #[rstest] #[case::cdrs(CdrsTokio)] #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn cassandra_redis_cache(#[case] driver: CassandraDriver) { @@ -328,11 +329,11 @@ async fn cassandra_redis_cache(#[case] driver: CassandraDriver) { cache::test(&connection, &mut redis_connection, &snapshotter).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[cfg(feature = "alpha-transforms")] #[rstest] // #[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn protect_transform_local(#[case] driver: CassandraDriver) { @@ -350,11 +351,11 @@ async fn protect_transform_local(#[case] driver: CassandraDriver) { protect::test(&shotover_connection().await, &direct_connection).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[cfg(feature = "alpha-transforms")] #[rstest] //#[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn protect_transform_aws(#[case] driver: CassandraDriver) { @@ -371,10 +372,10 @@ async fn protect_transform_aws(#[case] driver: CassandraDriver) { protect::test(&shotover_connection().await, &direct_connection).await; } -#[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] //#[case::cdrs(CdrsTokio)] // TODO #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] +#[case::scylla(Scylla)] #[tokio::test(flavor = "multi_thread")] #[serial] async fn peers_rewrite_v4(#[case] driver: CassandraDriver) { diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 0b15bc255..9e1aea106 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,3 +1,4 @@ +use bytes::BufMut; #[cfg(feature = "cassandra-cpp-driver-tests")] use cassandra_cpp::{ stmt, Batch, BatchType, CassErrorCode, CassResult, Cluster, Error, ErrorKind, @@ -27,7 +28,9 @@ use cdrs_tokio::{ }; use openssl::ssl::{SslContext, SslMethod}; use ordered_float::OrderedFloat; +use scylla::batch::Batch as ScyllaBatch; use scylla::frame::response::result::CqlValue; +use scylla::frame::types::Consistency; use scylla::prepared_statement::PreparedStatement as PreparedStatementScylla; use scylla::{Session as SessionScylla, SessionBuilder as SessionBuilderScylla}; #[cfg(feature = "cassandra-cpp-driver-tests")] @@ -173,6 +176,7 @@ impl CassandraConnection { .collect::>(), ) .user("cassandra", "cassandra") + .default_consistency(Consistency::One) .build() .await .unwrap(); @@ -303,7 +307,13 @@ impl CassandraConnection { .map(|x| { x.columns .into_iter() - .map(|col| ResultValue::new_from_scylla(col.unwrap())) + .map(|col| { + if let Some(col) = col { + ResultValue::new_from_scylla(col) + } else { + ResultValue::Null + } + }) .collect() }) .collect(), @@ -498,7 +508,23 @@ impl CassandraConnection { Self::process_cdrs_response(response) } - Self::Scylla { .. } => todo!(), + Self::Scylla { session, .. } => { + let statement = prepared_query.as_scylla(); + let response = session.execute(statement, (value,)).await.unwrap(); + + if let Ok(rows) = response.rows() { + rows.into_iter() + .map(|row| { + row.columns + .into_iter() + .map(|col| ResultValue::new_from_scylla(col.unwrap())) + .collect() + }) + .collect() + } else { + vec![] + } + } } } @@ -533,7 +559,16 @@ impl CassandraConnection { session.batch(batch).await.unwrap(); } - Self::Scylla { .. } => todo!(), + Self::Scylla { session, .. } => { + let mut values = vec![]; + let mut batch: ScyllaBatch = Default::default(); + for query in queries { + batch.append_statement(query.as_str()); + values.push(()); + } + + session.batch(&batch, values).await.unwrap(); + } } } @@ -801,16 +836,32 @@ impl ResultValue { CqlValue::Blob(blob) => Self::Blob(blob), CqlValue::Boolean(b) => Self::Boolean(b), CqlValue::Counter(_counter) => todo!(), - CqlValue::Decimal(_decimal) => todo!(), + CqlValue::Decimal(d) => { + let (value, scale) = d.as_bigint_and_exponent(); + let mut buf = vec![]; + let serialized = value.to_signed_bytes_be(); + buf.put_i32(scale.try_into().unwrap()); + buf.extend_from_slice(&serialized); + Self::Decimal(buf) + } CqlValue::Float(float) => Self::Float(float.into()), CqlValue::Int(int) => Self::Int(int), - CqlValue::Timestamp(_timestamp) => todo!(), + CqlValue::Timestamp(timestamp) => Self::Timestamp(timestamp.num_milliseconds()), CqlValue::Uuid(uuid) => Self::Uuid(uuid), - CqlValue::Varint(_var_int) => todo!(), + CqlValue::Varint(var_int) => { + let mut buf = vec![]; + let serialized = var_int.to_signed_bytes_be(); + buf.extend_from_slice(&serialized); + Self::VarInt(buf) + } CqlValue::Timeuuid(timeuuid) => Self::TimeUuid(timeuuid), CqlValue::Inet(ip) => Self::Inet(ip.to_string()), - CqlValue::Date(_date) => todo!(), - CqlValue::Time(_time) => todo!(), + CqlValue::Date(date) => Self::Date(date.to_be_bytes().to_vec()), + CqlValue::Time(time) => { + let mut buf = vec![]; + buf.put_i64(time.num_nanoseconds().unwrap()); + Self::Time(buf) + } CqlValue::SmallInt(small_int) => Self::SmallInt(small_int), CqlValue::TinyInt(tiny_int) => Self::TinyInt(tiny_int), CqlValue::Duration(_duration) => todo!(), From 3f2fb2e7ef7b3e72c0f1e3e245a0dc56bf1e12bb Mon Sep 17 00:00:00 2001 From: Conor Brosnan Date: Tue, 8 Nov 2022 12:43:24 +1000 Subject: [PATCH 2/2] review feedback --- shotover-proxy/tests/helpers/cassandra.rs | 146 +++++++++++----------- 1 file changed, 72 insertions(+), 74 deletions(-) diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 9e1aea106..61be0a114 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -307,13 +307,7 @@ impl CassandraConnection { .map(|x| { x.columns .into_iter() - .map(|col| { - if let Some(col) = col { - ResultValue::new_from_scylla(col) - } else { - ResultValue::Null - } - }) + .map(ResultValue::new_from_scylla) .collect() }) .collect(), @@ -512,17 +506,17 @@ impl CassandraConnection { let statement = prepared_query.as_scylla(); let response = session.execute(statement, (value,)).await.unwrap(); - if let Ok(rows) = response.rows() { - rows.into_iter() + match response.rows { + Some(rows) => rows + .into_iter() .map(|row| { row.columns .into_iter() - .map(|col| ResultValue::new_from_scylla(col.unwrap())) + .map(ResultValue::new_from_scylla) .collect() }) - .collect() - } else { - vec![] + .collect(), + None => vec![], } } } @@ -829,68 +823,72 @@ impl ResultValue { } } - pub fn new_from_scylla(value: CqlValue) -> Self { + pub fn new_from_scylla(value: Option) -> Self { match value { - CqlValue::Ascii(ascii) => Self::Ascii(ascii), - CqlValue::BigInt(big_int) => Self::BigInt(big_int), - CqlValue::Blob(blob) => Self::Blob(blob), - CqlValue::Boolean(b) => Self::Boolean(b), - CqlValue::Counter(_counter) => todo!(), - CqlValue::Decimal(d) => { - let (value, scale) = d.as_bigint_and_exponent(); - let mut buf = vec![]; - let serialized = value.to_signed_bytes_be(); - buf.put_i32(scale.try_into().unwrap()); - buf.extend_from_slice(&serialized); - Self::Decimal(buf) - } - CqlValue::Float(float) => Self::Float(float.into()), - CqlValue::Int(int) => Self::Int(int), - CqlValue::Timestamp(timestamp) => Self::Timestamp(timestamp.num_milliseconds()), - CqlValue::Uuid(uuid) => Self::Uuid(uuid), - CqlValue::Varint(var_int) => { - let mut buf = vec![]; - let serialized = var_int.to_signed_bytes_be(); - buf.extend_from_slice(&serialized); - Self::VarInt(buf) - } - CqlValue::Timeuuid(timeuuid) => Self::TimeUuid(timeuuid), - CqlValue::Inet(ip) => Self::Inet(ip.to_string()), - CqlValue::Date(date) => Self::Date(date.to_be_bytes().to_vec()), - CqlValue::Time(time) => { - let mut buf = vec![]; - buf.put_i64(time.num_nanoseconds().unwrap()); - Self::Time(buf) - } - CqlValue::SmallInt(small_int) => Self::SmallInt(small_int), - CqlValue::TinyInt(tiny_int) => Self::TinyInt(tiny_int), - CqlValue::Duration(_duration) => todo!(), - CqlValue::Double(double) => Self::Double(double.into()), - CqlValue::Text(text) => Self::Varchar(text), - CqlValue::Empty => Self::Null, - CqlValue::List(mut list) => { - Self::List(list.drain(..).map(ResultValue::new_from_scylla).collect()) - } - CqlValue::Set(mut set) => { - Self::Set(set.drain(..).map(ResultValue::new_from_scylla).collect()) - } - CqlValue::Map(mut map) => Self::Map( - map.drain(..) - .map(|(k, v)| { - ( - ResultValue::new_from_scylla(k), - ResultValue::new_from_scylla(v), - ) - }) - .collect(), - ), - CqlValue::Tuple(mut tuple) => Self::Tuple( - tuple - .drain(..) - .map(|element| ResultValue::new_from_scylla(element.unwrap())) - .collect(), - ), - CqlValue::UserDefinedType { .. } => todo!(), + Some(value) => match value { + CqlValue::Ascii(ascii) => Self::Ascii(ascii), + CqlValue::BigInt(big_int) => Self::BigInt(big_int), + CqlValue::Blob(blob) => Self::Blob(blob), + CqlValue::Boolean(b) => Self::Boolean(b), + CqlValue::Counter(_counter) => todo!(), + CqlValue::Decimal(d) => { + let (value, scale) = d.as_bigint_and_exponent(); + let mut buf = vec![]; + let serialized = value.to_signed_bytes_be(); + buf.put_i32(scale.try_into().unwrap()); + buf.extend_from_slice(&serialized); + Self::Decimal(buf) + } + CqlValue::Float(float) => Self::Float(float.into()), + CqlValue::Int(int) => Self::Int(int), + CqlValue::Timestamp(timestamp) => Self::Timestamp(timestamp.num_milliseconds()), + CqlValue::Uuid(uuid) => Self::Uuid(uuid), + CqlValue::Varint(var_int) => { + let mut buf = vec![]; + let serialized = var_int.to_signed_bytes_be(); + buf.extend_from_slice(&serialized); + Self::VarInt(buf) + } + CqlValue::Timeuuid(timeuuid) => Self::TimeUuid(timeuuid), + CqlValue::Inet(ip) => Self::Inet(ip.to_string()), + CqlValue::Date(date) => Self::Date(date.to_be_bytes().to_vec()), + CqlValue::Time(time) => { + let mut buf = vec![]; + buf.put_i64(time.num_nanoseconds().unwrap()); + Self::Time(buf) + } + CqlValue::SmallInt(small_int) => Self::SmallInt(small_int), + CqlValue::TinyInt(tiny_int) => Self::TinyInt(tiny_int), + CqlValue::Duration(_duration) => todo!(), + CqlValue::Double(double) => Self::Double(double.into()), + CqlValue::Text(text) => Self::Varchar(text), + CqlValue::Empty => Self::Null, + CqlValue::List(mut list) => Self::List( + list.drain(..) + .map(|v| ResultValue::new_from_scylla(Some(v))) + .collect(), + ), + CqlValue::Set(mut set) => Self::Set( + set.drain(..) + .map(|v| ResultValue::new_from_scylla(Some(v))) + .collect(), + ), + CqlValue::Map(mut map) => Self::Map( + map.drain(..) + .map(|(k, v)| { + ( + ResultValue::new_from_scylla(Some(k)), + ResultValue::new_from_scylla(Some(v)), + ) + }) + .collect(), + ), + CqlValue::Tuple(mut tuple) => { + Self::Tuple(tuple.drain(..).map(ResultValue::new_from_scylla).collect()) + } + CqlValue::UserDefinedType { .. } => todo!(), + }, + None => Self::Null, } } }