From a46fefab0585c9e59dc5916beefe35fb94ca3999 Mon Sep 17 00:00:00 2001 From: Conor Date: Thu, 15 Sep 2022 15:31:40 +1000 Subject: [PATCH] cdrs-tokio test suite integration (#741) --- Cargo.lock | 26 + shotover-proxy/Cargo.toml | 3 + shotover-proxy/benches/benches/cassandra.rs | 76 +-- shotover-proxy/src/message/mod.rs | 2 +- .../cassandra_int_tests/batch_statements.rs | 31 +- .../tests/cassandra_int_tests/collections.rs | 74 ++- .../tests/cassandra_int_tests/functions.rs | 6 +- .../tests/cassandra_int_tests/mod.rs | 224 +++++--- .../tests/cassandra_int_tests/native_types.rs | 2 +- .../prepared_statements.rs | 52 +- .../tests/cassandra_int_tests/protect.rs | 42 +- shotover-proxy/tests/examples/mod.rs | 7 +- shotover-proxy/tests/helpers/cassandra.rs | 538 ++++++++++++++---- 13 files changed, 757 insertions(+), 326 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 65f786282..9b61d150f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2372,6 +2372,31 @@ dependencies = [ "winapi", ] +[[package]] +name = "rstest" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9c9dc66cc29792b663ffb5269be669f1613664e69ad56441fdb895c2347b930" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5015e68a0685a95ade3eee617ff7101ab6a3fc689203101ca16ebc16f2b89c66" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + [[package]] name = "rusoto_core" version = "0.48.0" @@ -2791,6 +2816,7 @@ dependencies = [ "redis", "redis-protocol", "reqwest", + "rstest", "rusoto_kms", "rusoto_signature", "scylla", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 99b9f4604..478f166b6 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -11,6 +11,7 @@ license = "Apache-2.0" [features] # Include WIP alpha transforms in the public API alpha-transforms = [] +cassandra-cpp-driver-tests = [] [dependencies] pretty-hex = "0.3.0" @@ -92,7 +93,9 @@ reqwest = "0.11.6" metrics-util = "0.14.0" cdrs-tokio = { git = "https://github.com/krojew/cdrs-tokio" } scylla = { version = "0.5.0", features = ["ssl"] } +rstest = "0.15.0" [[bench]] name = "benches" harness = false +required-features = ["cassandra-cpp-driver-tests"] diff --git a/shotover-proxy/benches/benches/cassandra.rs b/shotover-proxy/benches/benches/cassandra.rs index 80d229d18..4ad987029 100644 --- a/shotover-proxy/benches/benches/cassandra.rs +++ b/shotover-proxy/benches/benches/cassandra.rs @@ -1,7 +1,7 @@ -use crate::helpers::cassandra::CassandraConnection; +use crate::helpers::cassandra::{CassandraConnection, CassandraDriver}; use crate::helpers::ShotoverManager; use cassandra_cpp::{stmt, Session, Statement}; -use criterion::{criterion_group, Criterion}; +use criterion::{criterion_group, criterion_main, Criterion}; use test_helpers::cert::generate_cassandra_test_certs; use test_helpers::docker_compose::DockerCompose; use test_helpers::lazy::new_lazy_shared; @@ -11,6 +11,8 @@ struct Query { statement: Statement, } +const DRIVER: CassandraDriver = CassandraDriver::Datastax; + fn cassandra(c: &mut Criterion) { let mut group = c.benchmark_group("cassandra"); group.throughput(criterion::Throughput::Elements(1)); @@ -45,7 +47,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -68,7 +70,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -90,7 +92,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -113,7 +115,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -136,7 +138,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -155,7 +157,7 @@ fn cassandra(c: &mut Criterion) { group.bench_with_input(format!("tls_{}", query.name), &resources, |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }); @@ -181,21 +183,21 @@ fn cassandra(c: &mut Criterion) { ); resources - .connection + .get_connection() .execute(&stmt!( "CREATE KEYSPACE test_protect_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };" )) .wait() .unwrap(); resources - .connection + .get_connection() .execute(&stmt!( "CREATE TABLE test_protect_keyspace.test_table (pk varchar PRIMARY KEY, cluster varchar, col1 blob, col2 int, col3 boolean);" )) .wait() .unwrap(); resources - .connection + .get_connection() .execute(&stmt!( "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk1', 'cluster', 'Initial value', 42, true);" )) @@ -212,7 +214,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -234,7 +236,7 @@ fn cassandra(c: &mut Criterion) { |b, resources| { b.iter(|| { let mut resources = resources.borrow_mut(); - let connection = &mut resources.as_mut().unwrap().connection; + let connection = &mut resources.as_mut().unwrap().get_connection(); connection.execute(&query.statement).wait().unwrap(); }) }, @@ -244,11 +246,12 @@ fn cassandra(c: &mut Criterion) { } criterion_group!(benches, cassandra); +criterion_main!(benches); pub struct BenchResources { _compose: DockerCompose, _shotover_manager: ShotoverManager, - connection: Session, + connection: CassandraConnection, } impl BenchResources { @@ -256,10 +259,8 @@ impl BenchResources { let compose = DockerCompose::new(compose_file); let shotover_manager = ShotoverManager::from_topology_file(shotover_topology); - let CassandraConnection::Datastax { - session: connection, - .. - } = futures::executor::block_on(CassandraConnection::new("127.0.0.1", 9042)); + let connection = + futures::executor::block_on(CassandraConnection::new("127.0.0.1", 9042, DRIVER)); let bench_resources = Self { _compose: compose, @@ -270,6 +271,10 @@ impl BenchResources { bench_resources } + pub fn get_connection(&self) -> &Session { + self.connection.as_datastax() + } + fn new_tls(shotover_topology: &str, compose_file: &str) -> Self { generate_cassandra_test_certs(); let compose = DockerCompose::new(compose_file); @@ -277,10 +282,7 @@ impl BenchResources { let ca_cert = "example-configs/cassandra-tls/certs/localhost_CA.crt"; - let CassandraConnection::Datastax { - session: connection, - .. - } = futures::executor::block_on(CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert)); + let connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert, DRIVER); let bench_resources = Self { _compose: compose, @@ -292,23 +294,33 @@ impl BenchResources { } fn setup(&self) { + let create_keyspace = stmt!( + "CREATE KEYSPACE benchmark_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };" + ); + + let create_table = stmt!( + "CREATE TABLE benchmark_keyspace.table_1 (id int PRIMARY KEY, x int, name varchar);" + ); + + let insert = stmt!( + "INSERT INTO benchmark_keyspace.table_1 (id, x, name) VALUES (0, 10, 'initial value');" + ); + self.connection - .execute(&stmt!( - "CREATE KEYSPACE benchmark_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };" - )) - .wait().unwrap(); + .as_datastax() + .execute(&create_keyspace) + .wait() + .unwrap(); self.connection - .execute(&stmt!( - "CREATE TABLE benchmark_keyspace.table_1 (id int PRIMARY KEY, x int, name varchar);" - )) + .as_datastax() + .execute(&create_table) .wait() .unwrap(); self.connection - .execute(&stmt!( - "INSERT INTO benchmark_keyspace.table_1 (id, x, name) VALUES (0, 10, 'initial value');" - )) + .as_datastax() + .execute(&insert) .wait() .unwrap(); } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 4232a0d55..907ebee53 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -298,7 +298,7 @@ impl Message { Metadata::Cassandra(metadata) => { let body = CassandraOperation::Error(ErrorBody { error_code: 0x1001, - message: "".into(), + message: "Server overloaded".into(), additional_info: AdditionalErrorInfo::Overloaded, }); diff --git a/shotover-proxy/tests/cassandra_int_tests/batch_statements.rs b/shotover-proxy/tests/cassandra_int_tests/batch_statements.rs index 111d1740e..cb42143e3 100644 --- a/shotover-proxy/tests/cassandra_int_tests/batch_statements.rs +++ b/shotover-proxy/tests/cassandra_int_tests/batch_statements.rs @@ -1,5 +1,4 @@ use crate::helpers::cassandra::{assert_query_result, run_query, CassandraConnection, ResultValue}; -use cassandra_cpp::{stmt, Batch, BatchType}; pub async fn test(connection: &CassandraConnection) { // setup keyspace and table for the batch statement tests @@ -9,12 +8,11 @@ pub async fn test(connection: &CassandraConnection) { } { - let mut batch = Batch::new(BatchType::LOGGED); + let mut batch = vec![]; for i in 0..2 { - let statement = format!("INSERT INTO batch_keyspace.batch_table (id, lastname, firstname) VALUES ({}, 'text1', 'text2')", i); - batch.add_statement(&stmt!(statement.as_str())).unwrap(); + batch.push(format!("INSERT INTO batch_keyspace.batch_table (id, lastname, firstname) VALUES ({}, 'text1', 'text2')", i)); } - connection.execute_batch(&batch); + connection.execute_batch(batch); assert_query_result( connection, @@ -36,15 +34,14 @@ pub async fn test(connection: &CassandraConnection) { } { - let mut batch = Batch::new(BatchType::LOGGED); + let mut batch = vec![]; for i in 0..2 { - let statement = format!( + batch.push(format!( "UPDATE batch_keyspace.batch_table SET lastname = 'text3' WHERE id = {};", i - ); - batch.add_statement(&stmt!(statement.as_str())).unwrap(); + )); } - connection.execute_batch(&batch); + connection.execute_batch(batch); assert_query_result( connection, @@ -66,18 +63,20 @@ pub async fn test(connection: &CassandraConnection) { } { - let mut batch = Batch::new(BatchType::LOGGED); + let mut batch = vec![]; for i in 0..2 { - let statement = format!("DELETE FROM batch_keyspace.batch_table WHERE id = {};", i); - batch.add_statement(&stmt!(statement.as_str())).unwrap(); + batch.push(format!( + "DELETE FROM batch_keyspace.batch_table WHERE id = {};", + i + )); } - connection.execute_batch(&batch); + connection.execute_batch(batch); assert_query_result(connection, "SELECT * FROM batch_keyspace.batch_table;", &[]).await; } { - let batch = Batch::new(BatchType::LOGGED); - connection.execute_batch(&batch); + let batch = vec![]; + connection.execute_batch(batch); } // test batch statements over QUERY PROTOCOL diff --git a/shotover-proxy/tests/cassandra_int_tests/collections.rs b/shotover-proxy/tests/cassandra_int_tests/collections.rs index f894c49de..d10b97d26 100644 --- a/shotover-proxy/tests/cassandra_int_tests/collections.rs +++ b/shotover-proxy/tests/cassandra_int_tests/collections.rs @@ -1,4 +1,6 @@ -use crate::helpers::cassandra::{assert_query_result, run_query, CassandraConnection, ResultValue}; +use crate::helpers::cassandra::{ + assert_query_result, run_query, CassandraConnection, CassandraDriver, ResultValue, +}; use cassandra_protocol::frame::message_result::ColType; fn get_map_example(value: &str) -> String { @@ -217,7 +219,7 @@ mod list { } } - async fn select(session: &CassandraConnection) { + async fn select(session: &CassandraConnection, driver: CassandraDriver) { // select lists of native types for (i, col_type) in NATIVE_COL_TYPES.iter().enumerate() { let query = format!( @@ -235,6 +237,12 @@ mod list { .await; } + let new_set = match driver { + CassandraDriver::CdrsTokio => ResultValue::List, + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => ResultValue::Set, + }; + // test selecting list of sets for (i, native_col_type) in NATIVE_COL_TYPES.iter().enumerate() { assert_query_result( @@ -244,7 +252,7 @@ mod list { i ) .as_str(), - &[&[ResultValue::List(vec![ResultValue::Set(vec![ + &[&[ResultValue::List(vec![new_set(vec![ get_type_example_result_value(*native_col_type), ])])]], ) @@ -285,10 +293,10 @@ mod list { } } - pub async fn test(session: &CassandraConnection) { + pub async fn test(session: &CassandraConnection, driver: CassandraDriver) { create(session).await; insert(session).await; - select(session).await; + select(session, driver).await; } } @@ -387,7 +395,13 @@ mod set { } } - async fn select(session: &CassandraConnection) { + async fn select(session: &CassandraConnection, driver: CassandraDriver) { + let new_set = match driver { + CassandraDriver::CdrsTokio => ResultValue::List, + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => ResultValue::Set, + }; + // select sets of native types for (i, col_type) in NATIVE_COL_TYPES.iter().enumerate() { let query = format!( @@ -398,15 +412,19 @@ mod set { assert_query_result( session, query.as_str(), - &[&[ResultValue::Set(vec![get_type_example_result_value( - *col_type, - )])]], + &[&[new_set(vec![get_type_example_result_value(*col_type)])]], ) .await; } // test selecting set of sets for (i, native_col_type) in NATIVE_COL_TYPES.iter().enumerate() { + let new_set = match driver { + CassandraDriver::CdrsTokio => ResultValue::List, + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => ResultValue::Set, + }; + assert_query_result( session, format!( @@ -414,9 +432,9 @@ mod set { i ) .as_str(), - &[&[ResultValue::Set(vec![ResultValue::Set(vec![ - get_type_example_result_value(*native_col_type), - ])])]], + &[&[new_set(vec![new_set(vec![get_type_example_result_value( + *native_col_type, + )])])]], ) .await; } @@ -430,7 +448,7 @@ mod set { i ) .as_str(), - &[&[ResultValue::Set(vec![ResultValue::List(vec![ + &[&[new_set(vec![ResultValue::List(vec![ get_type_example_result_value(*native_col_type), ])])]], ) @@ -446,7 +464,7 @@ mod set { i, ) .as_str(), - &[&[ResultValue::Set(vec![ResultValue::Map(vec![( + &[&[new_set(vec![ResultValue::Map(vec![( ResultValue::Int(0), get_type_example_result_value(*native_col_type), )])])]], @@ -455,10 +473,10 @@ mod set { } } - pub async fn test(session: &CassandraConnection) { + pub async fn test(session: &CassandraConnection, driver: CassandraDriver) { create(session).await; insert(session).await; - select(session).await; + select(session, driver).await; } } @@ -565,8 +583,8 @@ mod map { } } - async fn select(session: &CassandraConnection) { - // select sets of native types + async fn select(session: &CassandraConnection, driver: CassandraDriver) { + // select map of native types for (i, col_type) in NATIVE_COL_TYPES.iter().enumerate() { let query = format!( "SELECT my_map FROM test_collections_keyspace.test_map_table_{};", @@ -584,6 +602,12 @@ mod map { .await; } + let new_set = match driver { + CassandraDriver::CdrsTokio => ResultValue::List, + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => ResultValue::Set, + }; + // test selecting map of sets for (i, native_col_type) in NATIVE_COL_TYPES.iter().enumerate() { assert_query_result( @@ -595,7 +619,7 @@ mod map { .as_str(), &[&[ResultValue::Map(vec![( ResultValue::Int(0), - ResultValue::Set(vec![get_type_example_result_value(*native_col_type)]), + new_set(vec![get_type_example_result_value(*native_col_type)]), )])]], ) .await; @@ -639,17 +663,17 @@ mod map { } } - pub async fn test(session: &CassandraConnection) { + pub async fn test(session: &CassandraConnection, driver: CassandraDriver) { create(session).await; insert(session).await; - select(session).await; + select(session, driver).await; } } -pub async fn test(session: &CassandraConnection) { +pub async fn test(session: &CassandraConnection, driver: CassandraDriver) { run_query(session, "CREATE KEYSPACE test_collections_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };").await; - list::test(session).await; - set::test(session).await; - map::test(session).await; + list::test(session, driver).await; + set::test(session, driver).await; + map::test(session, driver).await; } diff --git a/shotover-proxy/tests/cassandra_int_tests/functions.rs b/shotover-proxy/tests/cassandra_int_tests/functions.rs index 29f915668..d913c6552 100644 --- a/shotover-proxy/tests/cassandra_int_tests/functions.rs +++ b/shotover-proxy/tests/cassandra_int_tests/functions.rs @@ -6,14 +6,16 @@ async fn drop_function(session: &CassandraConnection) { "SELECT test_function_keyspace.my_function(x, y) FROM test_function_keyspace.test_function_table WHERE id=1;", &[&[ResultValue::Int(4)]] ).await; + run_query(session, "DROP FUNCTION test_function_keyspace.my_function").await; } async fn create_function(session: &CassandraConnection) { run_query( session, - "CREATE FUNCTION test_function_keyspace.my_function (a int, b int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE javascript AS 'a * b';", + "CREATE FUNCTION test_function_keyspace.my_function (a int, b int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE javascript AS 'a * b';" ).await; + assert_query_result( session, "SELECT test_function_keyspace.my_function(x, y) FROM test_function_keyspace.test_function_table;", @@ -26,10 +28,12 @@ pub async fn test(session: &CassandraConnection) { session, "CREATE KEYSPACE test_function_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };" ).await; + run_query( session, "CREATE TABLE test_function_keyspace.test_function_table (id int PRIMARY KEY, x int, y int);", ).await; + run_query( session, r#"BEGIN BATCH diff --git a/shotover-proxy/tests/cassandra_int_tests/mod.rs b/shotover-proxy/tests/cassandra_int_tests/mod.rs index e37bde6e0..30ec4eadc 100644 --- a/shotover-proxy/tests/cassandra_int_tests/mod.rs +++ b/shotover-proxy/tests/cassandra_int_tests/mod.rs @@ -1,59 +1,74 @@ -use crate::helpers::cassandra::{assert_query_result, run_query, CassandraConnection, ResultValue}; +#[cfg(feature = "cassandra-cpp-driver-tests")] +use crate::helpers::cassandra::{ + assert_query_result, run_query, CassandraDriver::Datastax, CassandraError, CassandraErrorCode, + ResultValue, +}; +use crate::helpers::cassandra::{CassandraConnection, CassandraDriver, CassandraDriver::CdrsTokio}; use crate::helpers::ShotoverManager; -use cassandra_cpp::{stmt, Batch, BatchType, Error, ErrorKind}; -use cdrs_tokio::authenticators::StaticPasswordAuthenticatorProvider; -use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder}; -use cdrs_tokio::cluster::NodeTcpConfigBuilder; +#[cfg(feature = "cassandra-cpp-driver-tests")] +use cassandra_cpp::{Error, ErrorKind}; use cdrs_tokio::frame::events::{ SchemaChange, SchemaChangeOptions, SchemaChangeTarget, SchemaChangeType, ServerEvent, }; -use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; +#[cfg(feature = "cassandra-cpp-driver-tests")] use futures::future::{join_all, try_join_all}; use metrics_util::debugging::DebuggingRecorder; +use rstest::rstest; use serial_test::serial; -use std::sync::Arc; use test_helpers::docker_compose::DockerCompose; use tokio::time::{sleep, timeout, Duration}; mod batch_statements; mod cache; +#[cfg(feature = "cassandra-cpp-driver-tests")] mod cluster; +#[cfg(feature = "cassandra-cpp-driver-tests")] mod cluster_multi_rack; +#[cfg(feature = "cassandra-cpp-driver-tests")] mod cluster_single_rack_v3; +#[cfg(feature = "cassandra-cpp-driver-tests")] mod cluster_single_rack_v4; mod collections; mod functions; mod keyspace; mod native_types; mod prepared_statements; +#[cfg(feature = "cassandra-cpp-driver-tests")] #[cfg(feature = "alpha-transforms")] mod protect; mod table; mod udt; +#[rstest] +#[case(CdrsTokio)] +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_passthrough() { +async fn test_passthrough(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-passthrough/docker-compose.yml"); let _shotover_manager = ShotoverManager::from_topology_file("example-configs/cassandra-passthrough/topology.yaml"); - let connection = CassandraConnection::new("127.0.0.1", 9042).await; + let connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; keyspace::test(&connection).await; table::test(&connection).await; udt::test(&connection).await; native_types::test(&connection).await; - collections::test(&connection).await; + collections::test(&connection, driver).await; functions::test(&connection).await; prepared_statements::test(&connection).await; batch_statements::test(&connection).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_source_tls_and_single_tls() { +async fn test_source_tls_and_single_tls(#[case] driver: CassandraDriver) { test_helpers::cert::generate_cassandra_test_certs(); let _compose = DockerCompose::new("example-configs/cassandra-tls/docker-compose.yml"); @@ -64,7 +79,7 @@ async fn test_source_tls_and_single_tls() { { // Run a quick test straight to Cassandra to check our assumptions that Shotover and Cassandra TLS are behaving exactly the same - let direct_connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert).await; + let direct_connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert, driver); assert_query_result( &direct_connection, "SELECT bootstrapped FROM system.local", @@ -73,21 +88,25 @@ async fn test_source_tls_and_single_tls() { .await; } - let connection = CassandraConnection::new_tls("127.0.0.1", 9043, ca_cert).await; + let connection = CassandraConnection::new_tls("127.0.0.1", 9043, ca_cert, driver); keyspace::test(&connection).await; table::test(&connection).await; udt::test(&connection).await; native_types::test(&connection).await; - collections::test(&connection).await; + collections::test(&connection, driver).await; functions::test(&connection).await; prepared_statements::test(&connection).await; batch_statements::test(&connection).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cluster_single_rack_v3() { +async fn test_cluster_single_rack_v3(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-cluster/docker-compose-cassandra-v3.yml"); @@ -96,7 +115,7 @@ async fn test_cluster_single_rack_v3() { "example-configs/cassandra-cluster/topology-dummy-peers-v3.yaml", ); - let mut connection1 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection1 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection1 .enable_schema_awaiter("172.16.1.2:9042", None) .await; @@ -104,14 +123,14 @@ async fn test_cluster_single_rack_v3() { table::test(&connection1).await; udt::test(&connection1).await; native_types::test(&connection1).await; - collections::test(&connection1).await; + collections::test(&connection1, driver).await; functions::test(&connection1).await; prepared_statements::test(&connection1).await; batch_statements::test(&connection1).await; cluster_single_rack_v3::test_dummy_peers(&connection1).await; //Check for bugs in cross connection state - let mut connection2 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection2 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection2 .enable_schema_awaiter("172.16.1.2:9042", None) .await; @@ -121,9 +140,13 @@ async fn test_cluster_single_rack_v3() { cluster_single_rack_v3::test_topology_task(None).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cluster_single_rack_v4() { +async fn test_cluster_single_rack_v4(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-cluster/docker-compose-cassandra-v4.yml"); @@ -132,22 +155,23 @@ async fn test_cluster_single_rack_v4() { "example-configs/cassandra-cluster/topology-v4.yaml", ); - let mut connection1 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection1 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection1 .enable_schema_awaiter("172.16.1.2:9044", None) .await; + keyspace::test(&connection1).await; table::test(&connection1).await; udt::test(&connection1).await; native_types::test(&connection1).await; - collections::test(&connection1).await; + collections::test(&connection1, driver).await; functions::test(&connection1).await; prepared_statements::test(&connection1).await; batch_statements::test(&connection1).await; cluster_single_rack_v4::test(&connection1).await; //Check for bugs in cross connection state - let mut connection2 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection2 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection2 .enable_schema_awaiter("172.16.1.2:9044", None) .await; @@ -159,7 +183,7 @@ async fn test_cluster_single_rack_v4() { "example-configs/cassandra-cluster/topology-dummy-peers-v4.yaml", ); - let mut connection = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection .enable_schema_awaiter("172.16.1.2:9044", None) .await; @@ -169,9 +193,13 @@ async fn test_cluster_single_rack_v4() { cluster_single_rack_v4::test_topology_task(None, Some(9044)).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cluster_multi_rack() { +async fn test_cluster_multi_rack(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-cluster-multi-rack/docker-compose.yml"); @@ -186,7 +214,7 @@ async fn test_cluster_multi_rack() { "example-configs/cassandra-cluster-multi-rack/topology_rack3.yaml", ); - let mut connection1 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection1 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection1 .enable_schema_awaiter("172.16.1.2:9042", None) .await; @@ -194,14 +222,14 @@ async fn test_cluster_multi_rack() { table::test(&connection1).await; udt::test(&connection1).await; native_types::test(&connection1).await; - collections::test(&connection1).await; + collections::test(&connection1, driver).await; functions::test(&connection1).await; prepared_statements::test(&connection1).await; batch_statements::test(&connection1).await; cluster_multi_rack::test(&connection1).await; //Check for bugs in cross connection state - let mut connection2 = CassandraConnection::new("127.0.0.1", 9042).await; + let mut connection2 = CassandraConnection::new("127.0.0.1", 9042, driver).await; connection2 .enable_schema_awaiter("172.16.1.2:9042", None) .await; @@ -211,11 +239,16 @@ async fn test_cluster_multi_rack() { cluster_multi_rack::test_topology_task(None).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_source_tls_and_cluster_tls() { +async fn test_source_tls_and_cluster_tls(#[case] driver: CassandraDriver) { test_helpers::cert::generate_cassandra_test_certs(); let ca_cert = "example-configs/cassandra-tls/certs/localhost_CA.crt"; + let _compose = DockerCompose::new("example-configs/cassandra-cluster-tls/docker-compose.yml"); { let _shotover_manager = ShotoverManager::from_topology_file( @@ -224,7 +257,8 @@ async fn test_source_tls_and_cluster_tls() { { // Run a quick test straight to Cassandra to check our assumptions that Shotover and Cassandra TLS are behaving exactly the same - let direct_connection = CassandraConnection::new_tls("172.16.1.2", 9042, ca_cert).await; + let direct_connection = + CassandraConnection::new_tls("172.16.1.2", 9042, ca_cert, driver); assert_query_result( &direct_connection, "SELECT bootstrapped FROM system.local", @@ -233,7 +267,12 @@ async fn test_source_tls_and_cluster_tls() { .await; } - let mut connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert).await; + let mut connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert, driver); + connection + .enable_schema_awaiter("172.16.1.2:9042", Some(ca_cert)) + .await; + + let mut connection = CassandraConnection::new_tls("127.0.0.1", 9042, ca_cert, driver); connection .enable_schema_awaiter("172.16.1.2:9042", Some(ca_cert)) .await; @@ -243,7 +282,7 @@ async fn test_source_tls_and_cluster_tls() { udt::test(&connection).await; native_types::test(&connection).await; functions::test(&connection).await; - collections::test(&connection).await; + collections::test(&connection, driver).await; prepared_statements::test(&connection).await; batch_statements::test(&connection).await; cluster_single_rack_v4::test(&connection).await; @@ -252,9 +291,12 @@ async fn test_source_tls_and_cluster_tls() { cluster_single_rack_v4::test_topology_task(Some(ca_cert), None).await; } +#[rstest] +#[case(CdrsTokio)] +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cassandra_redis_cache() { +async fn test_cassandra_redis_cache(#[case] driver: CassandraDriver) { let recorder = DebuggingRecorder::new(); let snapshotter = recorder.snapshotter(); recorder.install().unwrap(); @@ -265,7 +307,7 @@ async fn test_cassandra_redis_cache() { ); let mut redis_connection = shotover_manager.redis_connection(6379); - let connection = CassandraConnection::new("127.0.0.1", 9042).await; + let connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; keyspace::test(&connection).await; table::test(&connection).await; @@ -276,55 +318,67 @@ async fn test_cassandra_redis_cache() { cache::test(&connection, &mut redis_connection, &snapshotter).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[cfg(feature = "alpha-transforms")] +#[rstest] +// #[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -#[cfg(feature = "alpha-transforms")] -async fn test_cassandra_protect_transform_local() { +async fn test_cassandra_protect_transform_local(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-protect-local/docker-compose.yml"); let _shotover_manager = ShotoverManager::from_topology_file( "example-configs/cassandra-protect-local/topology.yaml", ); - let shotover_connection = CassandraConnection::new("127.0.0.1", 9042).await; - let direct_connection = CassandraConnection::new("127.0.0.1", 9043).await; + let shotover_connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; + let direct_connection = CassandraConnection::new("127.0.0.1", 9043, driver).await; keyspace::test(&shotover_connection).await; table::test(&shotover_connection).await; udt::test(&shotover_connection).await; native_types::test(&shotover_connection).await; - collections::test(&shotover_connection).await; + collections::test(&shotover_connection, driver).await; functions::test(&shotover_connection).await; batch_statements::test(&shotover_connection).await; protect::test(&shotover_connection, &direct_connection).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[cfg(feature = "alpha-transforms")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -#[cfg(feature = "alpha-transforms")] -async fn test_cassandra_protect_transform_aws() { +async fn test_cassandra_protect_transform_aws(#[case] driver: CassandraDriver) { let _compose = DockerCompose::new("example-configs/cassandra-protect-aws/docker-compose.yml"); let _compose_aws = DockerCompose::new_moto(); let _shotover_manager = ShotoverManager::from_topology_file("example-configs/cassandra-protect-aws/topology.yaml"); - let shotover_connection = CassandraConnection::new("127.0.0.1", 9042).await; - let direct_connection = CassandraConnection::new("127.0.0.1", 9043).await; + let shotover_connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; + let direct_connection = CassandraConnection::new("127.0.0.1", 9043, driver).await; keyspace::test(&shotover_connection).await; table::test(&shotover_connection).await; udt::test(&shotover_connection).await; native_types::test(&shotover_connection).await; - collections::test(&shotover_connection).await; + collections::test(&shotover_connection, driver).await; functions::test(&shotover_connection).await; batch_statements::test(&shotover_connection).await; protect::test(&shotover_connection, &direct_connection).await; } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cassandra_peers_rewrite_cassandra4() { +async fn test_cassandra_peers_rewrite_cassandra4(#[case] driver: CassandraDriver) { let _docker_compose = DockerCompose::new( "tests/test-configs/cassandra-peers-rewrite/docker-compose-4.0-cassandra.yaml", ); @@ -333,8 +387,8 @@ async fn test_cassandra_peers_rewrite_cassandra4() { "tests/test-configs/cassandra-peers-rewrite/topology.yaml", ); - let normal_connection = CassandraConnection::new("127.0.0.1", 9043).await; - let rewrite_port_connection = CassandraConnection::new("127.0.0.1", 9044).await; + let normal_connection = CassandraConnection::new("127.0.0.1", 9043, driver).await; + let rewrite_port_connection = CassandraConnection::new("127.0.0.1", 9044, driver).await; // run some basic tests to confirm it works as normal table::test(&normal_connection).await; @@ -412,9 +466,13 @@ async fn test_cassandra_peers_rewrite_cassandra4() { } } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cassandra_peers_rewrite_cassandra3() { +async fn test_cassandra_peers_rewrite_cassandra3(#[case] driver: CassandraDriver) { let _docker_compose = DockerCompose::new( "tests/test-configs/cassandra-peers-rewrite/docker-compose-3.11-cassandra.yaml", ); @@ -423,7 +481,7 @@ async fn test_cassandra_peers_rewrite_cassandra3() { "tests/test-configs/cassandra-peers-rewrite/topology.yaml", ); - let connection = CassandraConnection::new("127.0.0.1", 9044).await; + let connection = CassandraConnection::new("127.0.0.1", 9044, driver).await; // run some basic tests to confirm it works as normal table::test(&connection).await; @@ -431,27 +489,31 @@ async fn test_cassandra_peers_rewrite_cassandra3() { // is passed through shotover unchanged. let statement = "SELECT data_center, native_port, rack FROM system.peers_v2;"; let result = connection.execute_expect_err(statement); - assert!(matches!( + assert_eq!( result, - Error( - ErrorKind::CassErrorResult(cassandra_cpp::CassErrorCode::SERVER_INVALID_QUERY, ..), - _ - ) - )); + CassandraError { + code: CassandraErrorCode::InvalidQuery, + message: "unconfigured table peers_v2".into() + } + ); } +#[cfg(feature = "cassandra-cpp-driver-tests")] +#[rstest] +//#[case(CdrsTokio)] // TODO +#[cfg_attr(feature = "cassandra-cpp-driver-tests", case(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_cassandra_request_throttling() { +async fn test_cassandra_request_throttling(#[case] driver: CassandraDriver) { let _docker_compose = DockerCompose::new("example-configs/cassandra-passthrough/docker-compose.yml"); let _shotover_manager = ShotoverManager::from_topology_file("tests/test-configs/cassandra-request-throttling.yaml"); - let connection = CassandraConnection::new("127.0.0.1", 9042).await; + let connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; std::thread::sleep(std::time::Duration::from_secs(1)); // sleep to reset the window and not trigger the rate limiter with client's startup reqeusts - let connection_2 = CassandraConnection::new("127.0.0.1", 9042).await; + let connection_2 = CassandraConnection::new("127.0.0.1", 9042, driver).await; std::thread::sleep(std::time::Duration::from_secs(1)); // sleep to reset the window again let statement = "SELECT * FROM system.peers"; @@ -503,31 +565,29 @@ async fn test_cassandra_request_throttling() { // this batch set should be allowed through { - let mut batch = Batch::new(BatchType::LOGGED); + let mut queries: Vec = vec![]; for i in 0..25 { - let statement = format!("INSERT INTO test_keyspace.my_table (id, lastname, firstname) VALUES ({}, 'text', 'text')", i); - batch.add_statement(&stmt!(statement.as_str())).unwrap(); + queries.push(format!("INSERT INTO test_keyspace.my_table (id, lastname, firstname) VALUES ({}, 'text', 'text')", i)); } - connection.execute_batch(&batch); + connection.execute_batch(queries); } std::thread::sleep(std::time::Duration::from_secs(1)); // sleep to reset the window // this batch set should not be allowed through { - let mut batch = Batch::new(BatchType::LOGGED); + let mut queries: Vec = vec![]; for i in 0..60 { - let statement = format!("INSERT INTO test_keyspace.my_table (id, lastname, firstname) VALUES ({}, 'text', 'text')", i); - batch.add_statement(&stmt!(statement.as_str())).unwrap(); + queries.push(format!("INSERT INTO test_keyspace.my_table (id, lastname, firstname) VALUES ({}, 'text', 'text')", i)); } - let result = connection.execute_batch_expect_err(&batch); - assert!(matches!( + let result = connection.execute_batch_expect_err(queries); + assert_eq!( result, - Error( - ErrorKind::CassErrorResult(cassandra_cpp::CassErrorCode::SERVER_OVERLOADED, ..), - .. - ) - )); + CassandraError { + code: CassandraErrorCode::ServerOverloaded, + message: "Server overloaded".into() + } + ); } std::thread::sleep(std::time::Duration::from_secs(1)); // sleep to reset the window @@ -535,35 +595,25 @@ async fn test_cassandra_request_throttling() { batch_statements::test(&connection).await; } +#[rstest] +#[case(CdrsTokio)] #[tokio::test(flavor = "multi_thread")] #[serial] -async fn test_events_keyspace() { +async fn test_events_keyspace(#[case] driver: CassandraDriver) { let _docker_compose = DockerCompose::new("example-configs/cassandra-passthrough/docker-compose.yml"); let _shotover_manager = ShotoverManager::from_topology_file("example-configs/cassandra-passthrough/topology.yaml"); - let user = "cassandra"; - let password = "cassandra"; - let auth = StaticPasswordAuthenticatorProvider::new(&user, &password); - let config = NodeTcpConfigBuilder::new() - .with_contact_point("127.0.0.1:9042".into()) - .with_authenticator_provider(Arc::new(auth)) - .build() - .await - .unwrap(); - - let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), config) - .build() - .unwrap(); + let connection = CassandraConnection::new("127.0.0.1", 9042, driver).await; - let mut event_recv = session.create_event_receiver(); + let mut event_recv = connection.as_cdrs().create_event_receiver(); sleep(Duration::from_secs(10)).await; // let the driver finish connecting to the cluster and registering for the events let create_ks = "CREATE KEYSPACE IF NOT EXISTS test_events_ks WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; - session.query(create_ks).await.unwrap(); + connection.execute(create_ks).await; let event = timeout(Duration::from_secs(10), event_recv.recv()) .await diff --git a/shotover-proxy/tests/cassandra_int_tests/native_types.rs b/shotover-proxy/tests/cassandra_int_tests/native_types.rs index b16e28234..aecbffc7f 100644 --- a/shotover-proxy/tests/cassandra_int_tests/native_types.rs +++ b/shotover-proxy/tests/cassandra_int_tests/native_types.rs @@ -30,7 +30,7 @@ async fn select(session: &CassandraConnection) { ResultValue::VarInt(vec![3, 5, 233]), ]], ) - .await + .await; } async fn insert(session: &CassandraConnection) { diff --git a/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs b/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs index d1ab4c920..44e300138 100644 --- a/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs +++ b/shotover-proxy/tests/cassandra_int_tests/prepared_statements.rs @@ -5,10 +5,8 @@ use crate::helpers::cassandra::{ async fn delete(session: &CassandraConnection) { let prepared = session.prepare("DELETE FROM test_prepare_statements.table_1 WHERE id = ?;"); - let mut statement = prepared.bind(); - statement.bind_int32(0, 1).unwrap(); assert_eq!( - session.execute_prepared(&statement), + session.execute_prepared(&prepared, 1), Vec::>::new() ); @@ -20,66 +18,42 @@ async fn delete(session: &CassandraConnection) { .await; } -async fn insert(session: &CassandraConnection) { - let prepared = session - .prepare("INSERT INTO test_prepare_statements.table_1 (id, x, name) VALUES (?, ?, ?);"); +fn insert(session: &CassandraConnection) { + let prepared = session.prepare("INSERT INTO test_prepare_statements.table_1 (id) VALUES (?);"); - let mut statement = prepared.bind(); - statement.bind_int32(0, 1).unwrap(); - statement.bind_int32(1, 11).unwrap(); - statement.bind_string(2, "foo").unwrap(); assert_eq!( - session.execute_prepared(&statement), + session.execute_prepared(&prepared, 1), Vec::>::new() ); - statement = prepared.bind(); - statement.bind_int32(0, 2).unwrap(); - statement.bind_int32(1, 12).unwrap(); - statement.bind_string(2, "bar").unwrap(); assert_eq!( - session.execute_prepared(&statement), + session.execute_prepared(&prepared, 2), Vec::>::new() ); - statement = prepared.bind(); - statement.bind_int32(0, 2).unwrap(); - statement.bind_int32(1, 13).unwrap(); - statement.bind_string(2, "baz").unwrap(); assert_eq!( - session.execute_prepared(&statement), + session.execute_prepared(&prepared, 2), Vec::>::new() ); } -async fn select(session: &CassandraConnection) { - let prepared = - session.prepare("SELECT id, x, name FROM test_prepare_statements.table_1 WHERE id = ?"); +fn select(session: &CassandraConnection) { + let prepared = session.prepare("SELECT id FROM test_prepare_statements.table_1 WHERE id = ?"); - let mut statement = prepared.bind(); - statement.bind_int32(0, 1).unwrap(); + let result_rows = session.execute_prepared(&prepared, 1); - let result_rows = session.execute_prepared(&statement); - - assert_rows( - result_rows, - &[&[ - ResultValue::Int(1), - ResultValue::Int(11), - ResultValue::Varchar("foo".into()), - ]], - ); + assert_rows(result_rows, &[&[ResultValue::Int(1)]]); } pub async fn test(session: &CassandraConnection) { run_query(session, "CREATE KEYSPACE test_prepare_statements WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };").await; run_query( session, - "CREATE TABLE test_prepare_statements.table_1 (id int PRIMARY KEY, x int, name varchar);", + "CREATE TABLE test_prepare_statements.table_1 (id int PRIMARY KEY);", ) .await; - insert(session).await; - select(session).await; + insert(session); + select(session); delete(session).await; } diff --git a/shotover-proxy/tests/cassandra_int_tests/protect.rs b/shotover-proxy/tests/cassandra_int_tests/protect.rs index d6dcf4a8d..549d739bc 100644 --- a/shotover-proxy/tests/cassandra_int_tests/protect.rs +++ b/shotover-proxy/tests/cassandra_int_tests/protect.rs @@ -1,5 +1,4 @@ use crate::helpers::cassandra::{assert_query_result, run_query, CassandraConnection, ResultValue}; -use cassandra_cpp::{stmt, Batch, BatchType}; use chacha20poly1305::Nonce; use serde::Deserialize; @@ -12,30 +11,30 @@ pub struct Protected { } pub async fn test(shotover_session: &CassandraConnection, direct_session: &CassandraConnection) { - run_query(shotover_session, "CREATE KEYSPACE test_protect_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };").await; run_query( shotover_session, - "CREATE TABLE test_protect_keyspace.test_table (pk varchar PRIMARY KEY, cluster varchar, col1 blob, col2 int, col3 boolean);", + "CREATE KEYSPACE test_protect_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };" ).await; run_query( shotover_session, - "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk1', 'cluster', 'I am gonna get encrypted!!', 42, true);" + "CREATE TABLE test_protect_keyspace.test_table (pk varchar PRIMARY KEY, cluster varchar, col1 blob, col2 int, col3 boolean);" ).await; - let mut batch = Batch::new(BatchType::LOGGED); - batch.add_statement(&stmt!( - "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk2', 'cluster', 'encrypted2', 422, true)" - )).unwrap(); - batch.add_statement(&stmt!( - "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk3', 'cluster', 'encrypted3', 423, false)" - )).unwrap(); - shotover_session.execute_batch(&batch); + run_query( + shotover_session, + "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk1', 'cluster', 'I am gonna get encrypted!!', 0, true);" + ).await; + + shotover_session.execute_batch(vec![ + "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk2', 'cluster', 'encrypted2', 1, true)".into(), + "INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk3', 'cluster', 'encrypted3', 2, false)".into() + ]); let insert_statement = "BEGIN BATCH -INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk4', 'cluster', 'encrypted4', 424, true); -INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk5', 'cluster', 'encrypted5', 425, false); -APPLY BATCH;"; + INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk4', 'cluster', 'encrypted4', 3, true); + INSERT INTO test_protect_keyspace.test_table (pk, cluster, col1, col2, col3) VALUES ('pk5', 'cluster', 'encrypted5', 4, false); + APPLY BATCH;"; run_query(shotover_session, insert_statement).await; // assert that data is decrypted by shotover @@ -47,35 +46,35 @@ APPLY BATCH;"; ResultValue::Varchar("pk1".into()), ResultValue::Varchar("cluster".into()), ResultValue::Blob("I am gonna get encrypted!!".into()), - ResultValue::Int(42), + ResultValue::Int(0), ResultValue::Boolean(true), ], &[ ResultValue::Varchar("pk2".into()), ResultValue::Varchar("cluster".into()), ResultValue::Blob("encrypted2".into()), - ResultValue::Int(422), + ResultValue::Int(1), ResultValue::Boolean(true), ], &[ ResultValue::Varchar("pk3".into()), ResultValue::Varchar("cluster".into()), ResultValue::Blob("encrypted3".into()), - ResultValue::Int(423), + ResultValue::Int(2), ResultValue::Boolean(false), ], &[ ResultValue::Varchar("pk4".into()), ResultValue::Varchar("cluster".into()), ResultValue::Blob("encrypted4".into()), - ResultValue::Int(424), + ResultValue::Int(3), ResultValue::Boolean(true), ], &[ ResultValue::Varchar("pk5".into()), ResultValue::Varchar("cluster".into()), ResultValue::Blob("encrypted5".into()), - ResultValue::Int(425), + ResultValue::Int(4), ResultValue::Boolean(false), ], ], @@ -87,6 +86,7 @@ APPLY BATCH;"; .execute("SELECT pk, cluster, col1, col2, col3 FROM test_protect_keyspace.test_table") .await; assert_eq!(result.len(), 5); + for row in result { assert_eq!(row.len(), 5); @@ -97,7 +97,7 @@ APPLY BATCH;"; if let ResultValue::Blob(value) = &row[2] { let _: Protected = bincode::deserialize(value).unwrap(); } else { - panic!("expected 3rd column to be ResultValue::Varchar in {row:?}"); + panic!("expected 3rd column to be ResultValue::Blob in {row:?}"); } } } diff --git a/shotover-proxy/tests/examples/mod.rs b/shotover-proxy/tests/examples/mod.rs index eca2dc0a6..c3e735312 100644 --- a/shotover-proxy/tests/examples/mod.rs +++ b/shotover-proxy/tests/examples/mod.rs @@ -1,4 +1,7 @@ -use crate::helpers::cassandra::{assert_query_result, CassandraConnection, ResultValue}; +#![cfg(feature = "cassandra-cpp-driver-tests")] +use crate::helpers::cassandra::{ + assert_query_result, CassandraConnection, CassandraDriver, ResultValue, +}; use serial_test::serial; use test_helpers::docker_compose::DockerCompose; @@ -8,7 +11,7 @@ async fn test_cassandra_rewrite_peers_example() { let _docker_compose = DockerCompose::new("example-configs-docker/cassandra-peers-rewrite/docker-compose.yml"); - let connection = CassandraConnection::new("172.16.1.2", 9043).await; + let connection = CassandraConnection::new("172.16.1.2", 9043, CassandraDriver::Datastax).await; assert_query_result( &connection, diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 04b3333f5..bc6f127ff 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,63 +1,215 @@ -use cassandra_cpp::Error as CassandraError; +#[cfg(feature = "cassandra-cpp-driver-tests")] use cassandra_cpp::{ - stmt, Batch, CassFuture, CassResult, Cluster, Error, PreparedStatement, Session, Ssl, - Statement, Value, ValueType, + stmt, Batch, BatchType, CassErrorCode, CassFuture, CassResult, Cluster, Error, ErrorKind, + PreparedStatement, Session as DatastaxSession, Ssl, Value, ValueType, +}; +use cassandra_protocol::types::cassandra_type::{wrapper_fn, CassandraType}; +use cdrs_tokio::{ + authenticators::StaticPasswordAuthenticatorProvider, + cluster::session::{Session as CdrsTokioSession, SessionBuilder, TcpSessionBuilder}, + cluster::{NodeAddress, NodeTcpConfigBuilder, TcpConnectionManager}, + frame::{ + message_response::ResponseBody, message_result::ResResultBody, Envelope, Serialize, Version, + }, + load_balancing::RoundRobinLoadBalancingStrategy, + query::{BatchQueryBuilder, PreparedQuery as CdrsTokioPreparedQuery}, + query_values, + transport::TransportTcp, + types::prelude::Error as CdrsError, }; use openssl::ssl::{SslContext, SslMethod}; use ordered_float::OrderedFloat; use scylla::{Session as SessionScylla, SessionBuilder as SessionBuilderScylla}; +#[cfg(feature = "cassandra-cpp-driver-tests")] +use std::fs::read_to_string; +use std::sync::Arc; + +#[derive(Debug)] +pub enum PreparedQuery { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Datastax(PreparedStatement), + CdrsTokio(CdrsTokioPreparedQuery), +} + +impl PreparedQuery { + #[cfg(feature = "cassandra-cpp-driver-tests")] + fn as_datastax(&self) -> &PreparedStatement { + match self { + PreparedQuery::Datastax(p) => p, + _ => panic!("Not PreparedQuery::Datastax"), + } + } + + fn as_cdrs(&self) -> &CdrsTokioPreparedQuery { + match self { + PreparedQuery::CdrsTokio(p) => p, + #[cfg(feature = "cassandra-cpp-driver-tests")] + _ => panic!("Not PreparedQuery::CdrsTokio"), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct CassandraError { + pub code: CassandraErrorCode, + pub message: String, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum CassandraErrorCode { + ServerOverloaded = 0x1001, + InvalidQuery = 0x2200, +} + +impl From for CassandraErrorCode { + fn from(i: i32) -> Self { + match i { + 0x1001 => CassandraErrorCode::ServerOverloaded, + 0x2200 => CassandraErrorCode::InvalidQuery, + _ => unimplemented!("{i} is not implemented"), + } + } +} + +impl CassandraErrorCode { + #[cfg(feature = "cassandra-cpp-driver-tests")] + fn new_from_cpp(code: CassErrorCode) -> Self { + match code { + CassErrorCode::SERVER_INVALID_QUERY => CassandraErrorCode::InvalidQuery, + CassErrorCode::SERVER_OVERLOADED => CassandraErrorCode::ServerOverloaded, + _ => unimplemented!("{code:?} is not implemented"), + } + } +} + +#[allow(dead_code)] +#[derive(Copy, Clone)] +pub enum CassandraDriver { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Datastax, + CdrsTokio, +} + +type CdrsTokioSessionInstance = CdrsTokioSession< + TransportTcp, + TcpConnectionManager, + RoundRobinLoadBalancingStrategy, +>; pub enum CassandraConnection { + #[cfg(feature = "cassandra-cpp-driver-tests")] Datastax { - session: Session, + session: DatastaxSession, + schema_awaiter: Option, + }, + CdrsTokio { + session: CdrsTokioSessionInstance, schema_awaiter: Option, }, } impl CassandraConnection { #[allow(dead_code)] - pub async fn new(contact_points: &str, port: u16) -> CassandraConnection { + pub async fn new(contact_points: &str, port: u16, driver: CassandraDriver) -> Self { for contact_point in contact_points.split(',') { test_helpers::wait_for_socket_to_open(contact_point, port); } - let mut cluster = Cluster::default(); - cluster.set_contact_points(contact_points).unwrap(); - cluster.set_credentials("cassandra", "cassandra").unwrap(); - cluster.set_port(port).unwrap(); - cluster.set_load_balance_round_robin(); - - CassandraConnection::Datastax { - // By default unwrap uses the Debug formatter `{:?}` which is extremely noisy for the error type returned by `connect()`. - // So we instead force the Display formatter `{}` on the error. - session: cluster.connect().map_err(|err| format!("{err}")).unwrap(), - schema_awaiter: None, + + match driver { + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => { + let mut cluster = Cluster::default(); + cluster.set_contact_points(contact_points).unwrap(); + cluster.set_credentials("cassandra", "cassandra").unwrap(); + cluster.set_port(port).unwrap(); + cluster.set_load_balance_round_robin(); + + CassandraConnection::Datastax { + // By default unwrap uses the Debug formatter `{:?}` which is extremely noisy for the error type returned by `connect()`. + // So we instead force the Display formatter `{}` on the error. + session: cluster.connect().map_err(|err| format!("{err}")).unwrap(), + schema_awaiter: None, + } + } + CassandraDriver::CdrsTokio => { + let user = "cassandra"; + let password = "cassandra"; + let auth = StaticPasswordAuthenticatorProvider::new(&user, &password); + + let node_addresses = contact_points + .split(',') + .map(|contact_point| NodeAddress::from(format!("{contact_point}:{port}"))) + .collect::>(); + + let config = NodeTcpConfigBuilder::new() + .with_contact_points(node_addresses) + .with_authenticator_provider(Arc::new(auth)) + .build() + .await + .unwrap(); + + let session = + TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), config) + .build() + .unwrap(); + CassandraConnection::CdrsTokio { + session, + schema_awaiter: None, + } + } + } + } + + #[allow(dead_code)] + pub fn as_cdrs(&self) -> &CdrsTokioSessionInstance { + match self { + Self::CdrsTokio { session, .. } => session, + #[cfg(feature = "cassandra-cpp-driver-tests")] + _ => panic!("Not CdrsTokio"), } } + #[cfg(feature = "cassandra-cpp-driver-tests")] #[allow(dead_code)] - pub async fn new_tls( + pub fn as_datastax(&self) -> &DatastaxSession { + match self { + Self::Datastax { session, .. } => session, + _ => panic!("Not Datastax"), + } + } + + #[allow(dead_code, unused_variables)] + pub fn new_tls( contact_points: &str, port: u16, ca_cert_path: &str, - ) -> CassandraConnection { - let ca_cert = std::fs::read_to_string(ca_cert_path).unwrap(); - let mut ssl = Ssl::default(); - Ssl::add_trusted_cert(&mut ssl, &ca_cert).unwrap(); + driver: CassandraDriver, + ) -> Self { + match driver { + #[cfg(feature = "cassandra-cpp-driver-tests")] + CassandraDriver::Datastax => { + let ca_cert = read_to_string(ca_cert_path).unwrap(); + let mut ssl = Ssl::default(); + Ssl::add_trusted_cert(&mut ssl, &ca_cert).unwrap(); - for contact_point in contact_points.split(',') { - test_helpers::wait_for_socket_to_open(contact_point, port); - } + for contact_point in contact_points.split(',') { + test_helpers::wait_for_socket_to_open(contact_point, port); + } - let mut cluster = Cluster::default(); - cluster.set_credentials("cassandra", "cassandra").unwrap(); - cluster.set_contact_points(contact_points).unwrap(); - cluster.set_port(port).ok(); - cluster.set_load_balance_round_robin(); - cluster.set_ssl(&mut ssl); + let mut cluster = Cluster::default(); + cluster.set_credentials("cassandra", "cassandra").unwrap(); + cluster.set_contact_points(contact_points).unwrap(); + cluster.set_port(port).ok(); + cluster.set_load_balance_round_robin(); + cluster.set_ssl(&mut ssl); - CassandraConnection::Datastax { - session: cluster.connect().unwrap(), - schema_awaiter: None, + CassandraConnection::Datastax { + session: cluster.connect().unwrap(), + schema_awaiter: None, + } + } + // TODO actually implement TLS for cdrs-tokio + CassandraDriver::CdrsTokio => todo!(), } } @@ -68,111 +220,176 @@ impl CassandraConnection { context.set_ca_file(ca_cert).unwrap(); context.build() }); - match self { - CassandraConnection::Datastax { schema_awaiter, .. } => { - *schema_awaiter = Some( - SessionBuilderScylla::new() - .known_node(direct_node) - .user("cassandra", "cassandra") - .ssl_context(context) - .build() - .await - .unwrap(), - ); - } + + let schema_awaiter = match self { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { schema_awaiter, .. } => schema_awaiter, + Self::CdrsTokio { schema_awaiter, .. } => schema_awaiter, + }; + + *schema_awaiter = Some( + SessionBuilderScylla::new() + .known_node(direct_node) + .user("cassandra", "cassandra") + .ssl_context(context) + .build() + .await + .unwrap(), + ); + } + + async fn await_schema_agreement(&self) { + let schema_awaiter = match self { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { schema_awaiter, .. } => schema_awaiter, + Self::CdrsTokio { schema_awaiter, .. } => schema_awaiter, + }; + if let Some(schema_awaiter) = schema_awaiter { + schema_awaiter.await_schema_agreement().await.unwrap(); } } #[allow(dead_code)] pub async fn execute(&self, query: &str) -> Vec> { let result = match self { - CassandraConnection::Datastax { session, .. } => { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { let statement = stmt!(query); match session.execute(&statement).wait() { Ok(result) => result .into_iter() - .map(|x| x.into_iter().map(ResultValue::new).collect()) + .map(|x| x.into_iter().map(ResultValue::new_from_cpp).collect()) .collect(), Err(Error(err, _)) => panic!("The CQL query: {query}\nFailed with: {err}"), } } + Self::CdrsTokio { session, .. } => { + let response = session.query(query).await.unwrap(); + Self::process_cdrs_response(response) + } }; let query = query.to_uppercase(); let query = query.trim(); - if query.starts_with("CREATE") || query.starts_with("ALTER") { - match self { - CassandraConnection::Datastax { schema_awaiter, .. } => { - if let Some(schema_awaiter) = schema_awaiter { - schema_awaiter.await_schema_agreement().await.unwrap(); - } - } - } + if query.starts_with("CREATE") || query.starts_with("ALTER") || query.starts_with("DROP") { + self.await_schema_agreement().await; } result } #[allow(dead_code)] + #[cfg(feature = "cassandra-cpp-driver-tests")] pub fn execute_async(&self, query: &str) -> CassFuture { match self { - CassandraConnection::Datastax { session, .. } => { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { let statement = stmt!(query); session.execute(&statement) } + Self::CdrsTokio { .. } => todo!(), } } #[allow(dead_code)] pub fn execute_expect_err(&self, query: &str) -> CassandraError { match self { - CassandraConnection::Datastax { session, .. } => { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { let statement = stmt!(query); - session.execute(&statement).wait().unwrap_err() + let error = session.execute(&statement).wait().unwrap_err(); + + if let ErrorKind::CassErrorResult(code, msg, ..) = error.0 { + return CassandraError { + code: CassandraErrorCode::new_from_cpp(code), + message: msg, + }; + } + + panic!("Did not get an error result for {query}"); + } + Self::CdrsTokio { session, .. } => { + let error = futures::executor::block_on(session.query(query)).unwrap_err(); + + match error { + CdrsError::Server { body, .. } => CassandraError { + code: body.error_code.into(), + message: body.message, + }, + _ => todo!(), + } } } } #[allow(dead_code)] pub fn execute_expect_err_contains(&self, query: &str, contains: &str) { - let result = self.execute_expect_err(query).to_string(); + let error_msg = self.execute_expect_err(query).message; assert!( - result.contains(contains), - "Expected the error to contain '{contains}' but it did not and was instead '{result}'" + error_msg.contains(contains), + "Expected the error to contain '{contains}' but it did not and was instead '{error_msg}'" ); } #[allow(dead_code)] - pub fn prepare(&self, query: &str) -> PreparedStatement { + pub fn prepare(&self, query: &str) -> PreparedQuery { match self { - CassandraConnection::Datastax { session, .. } => { - session.prepare(query).unwrap().wait().unwrap() + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { + PreparedQuery::Datastax(session.prepare(query).unwrap().wait().unwrap()) + } + Self::CdrsTokio { session, .. } => { + let query = futures::executor::block_on(session.prepare(query)).unwrap(); + PreparedQuery::CdrsTokio(query) } } } #[allow(dead_code)] - pub fn execute_prepared(&self, statement: &Statement) -> Vec> { + pub fn execute_prepared( + &self, + prepared_query: &PreparedQuery, + value: i32, + ) -> Vec> { match self { - CassandraConnection::Datastax { session, .. } => { - match session.execute(statement).wait() { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { + let mut statement = prepared_query.as_datastax().bind(); + statement.bind_int32(0, value).unwrap(); + match session.execute(&statement).wait() { Ok(result) => result .into_iter() - .map(|x| x.into_iter().map(ResultValue::new).collect()) + .map(|x| x.into_iter().map(ResultValue::new_from_cpp).collect()) .collect(), Err(Error(err, _)) => { panic!("The statement: {statement:?}\nFailed with: {err}") } } } + Self::CdrsTokio { session, .. } => { + let statement = prepared_query.as_cdrs(); + let response = futures::executor::block_on( + session.exec_with_values(statement, query_values!(value)), + ) + .unwrap(); + + Self::process_cdrs_response(response) + } } } #[allow(dead_code)] - pub fn execute_batch(&self, batch: &Batch) { + pub fn execute_batch(&self, queries: Vec) { match self { - CassandraConnection::Datastax { session, .. } => { - match session.execute_batch(batch).wait() { + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { + let mut batch = Batch::new(BatchType::LOGGED); + + for query in queries { + batch.add_statement(&stmt!(query.as_str())).unwrap(); + } + + match session.execute_batch(&batch).wait() { Ok(result) => assert_eq!( result.into_iter().count(), 0, @@ -181,22 +398,83 @@ impl CassandraConnection { Err(Error(err, _)) => panic!("The batch: {batch:?}\nFailed with: {err}"), } } + Self::CdrsTokio { session, .. } => { + let mut builder = BatchQueryBuilder::new(); + + for query in queries { + builder = builder.add_query(query, query_values!()); + } + + let batch = builder.build().unwrap(); + + futures::executor::block_on(session.batch(batch)).unwrap(); + } } } - #[allow(dead_code)] - pub fn execute_batch_expect_err(&self, batch: &Batch) -> CassandraError { + #[allow(dead_code, unused_variables)] + pub fn execute_batch_expect_err(&self, queries: Vec) -> CassandraError { match self { - CassandraConnection::Datastax { session, .. } => { - session.execute_batch(batch).wait().unwrap_err() + #[cfg(feature = "cassandra-cpp-driver-tests")] + Self::Datastax { session, .. } => { + let mut batch = Batch::new(BatchType::LOGGED); + for query in queries { + batch.add_statement(&stmt!(query.as_str())).unwrap(); + } + let error = session.execute_batch(&batch).wait().unwrap_err(); + if let ErrorKind::CassErrorResult(code, message, ..) = error.0 { + return CassandraError { + code: CassandraErrorCode::new_from_cpp(code), + message, + }; + } + + panic!("Did not get an error result for {batch:?}"); + } + Self::CdrsTokio { .. } => todo!(), + } + } + + fn process_cdrs_response(response: Envelope) -> Vec> { + let version = response.version; + let response_body = response.response_body().unwrap(); + + match response_body { + ResponseBody::Error(err) => { + panic!("CQL query Failed with: {err:?}") } + ResponseBody::Result(res_result_body) => match res_result_body { + ResResultBody::Rows(rows) => { + let mut result_values = vec![]; + + for row in &rows.rows_content { + let mut row_result_values = vec![]; + for (i, col_spec) in rows.metadata.col_specs.iter().enumerate() { + let wrapper = wrapper_fn(&col_spec.col_type.id); + let value = ResultValue::new_from_cdrs( + wrapper(&row[i], &col_spec.col_type, version).unwrap(), + version, + ); + + row_result_values.push(value); + } + result_values.push(row_result_values); + } + + result_values + } + ResResultBody::Prepared(_) => todo!(), + ResResultBody::SchemaChange(_) => vec![], + ResResultBody::SetKeyspace(_) => vec![], + ResResultBody::Void => vec![], + }, + _ => todo!(), } } } #[derive(Debug, Clone, PartialOrd, Eq, Ord)] pub enum ResultValue { - Text(String), Varchar(String), Int(i32), Boolean(bool), @@ -210,7 +488,7 @@ pub enum ResultValue { Float(OrderedFloat), Inet(String), SmallInt(i16), - Time(Vec), // TODO should be String + Time(Vec), // TODO shoulbe be String Timestamp(i64), TimeUuid(uuid::Uuid), Counter(i64), @@ -219,8 +497,10 @@ pub enum ResultValue { Date(Vec), // TODO should be string Set(Vec), List(Vec), + #[allow(dead_code)] Tuple(Vec), Map(Vec<(ResultValue, ResultValue)>), + #[allow(dead_code)] Null, /// Never output by the DB /// Can be used by the user in assertions to allow any value. @@ -231,7 +511,6 @@ pub enum ResultValue { impl PartialEq for ResultValue { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::Text(l0), Self::Text(r0)) => l0 == r0, (Self::Varchar(l0), Self::Varchar(r0)) => l0 == r0, (Self::Int(l0), Self::Int(r0)) => l0 == r0, (Self::Boolean(l0), Self::Boolean(r0)) => l0 == r0, @@ -266,12 +545,12 @@ impl PartialEq for ResultValue { impl ResultValue { #[allow(dead_code)] - pub fn new(value: Value) -> ResultValue { + #[cfg(feature = "cassandra-cpp-driver-tests")] + pub fn new_from_cpp(value: Value) -> Self { if value.is_null() { ResultValue::Null } else { match value.get_type() { - ValueType::TEXT => ResultValue::Text(value.get_string().unwrap()), ValueType::VARCHAR => ResultValue::Varchar(value.get_string().unwrap()), ValueType::INT => ResultValue::Int(value.get_i32().unwrap()), ValueType::BOOLEAN => ResultValue::Boolean(value.get_bool().unwrap()), @@ -298,41 +577,98 @@ impl ResultValue { ValueType::COUNTER => ResultValue::Counter(value.get_i64().unwrap()), ValueType::VARINT => ResultValue::VarInt(value.get_bytes().unwrap().to_vec()), ValueType::TINY_INT => ResultValue::TinyInt(value.get_i8().unwrap()), - ValueType::SET => { - ResultValue::Set(value.get_set().unwrap().map(ResultValue::new).collect()) - } + ValueType::SET => ResultValue::Set( + value + .get_set() + .unwrap() + .map(ResultValue::new_from_cpp) + .collect(), + ), // despite the name get_set is used by SET, LIST and TUPLE - ValueType::LIST => { - ResultValue::List(value.get_set().unwrap().map(ResultValue::new).collect()) - } - ValueType::TUPLE => { - ResultValue::Tuple(value.get_set().unwrap().map(ResultValue::new).collect()) - } + ValueType::LIST => ResultValue::List( + value + .get_set() + .unwrap() + .map(ResultValue::new_from_cpp) + .collect(), + ), + ValueType::TUPLE => ResultValue::Tuple( + value + .get_set() + .unwrap() + .map(ResultValue::new_from_cpp) + .collect(), + ), ValueType::MAP => ResultValue::Map( value .get_map() .unwrap() - .map(|(k, v)| (ResultValue::new(k), ResultValue::new(v))) + .map(|(k, v)| (ResultValue::new_from_cpp(k), ResultValue::new_from_cpp(v))) .collect(), ), ValueType::UNKNOWN => todo!(), ValueType::CUSTOM => todo!(), ValueType::UDT => todo!(), + ValueType::TEXT => unimplemented!("text is represented by the same id as varchar at the protocol level and therefore will never be instantiated by the datastax cpp driver. https://github.com/apache/cassandra/blob/703ccdee29f7e8c39aeb976e72e516415d609cf4/doc/native_protocol_v5.spec#L1184"), } } } -} -/// Execute a `query` against the `session` and return result rows -#[allow(dead_code)] -pub fn execute_query(session: &Session, query: &str) -> Vec> { - let statement = stmt!(query); - match session.execute(&statement).wait() { - Ok(result) => result - .into_iter() - .map(|x| x.into_iter().map(ResultValue::new).collect()) - .collect(), - Err(Error(err, _)) => panic!("The CSQL query: {query}\nFailed with: {err}"), + pub fn new_from_cdrs(value: CassandraType, version: Version) -> Self { + match value { + CassandraType::Ascii(ascii) => ResultValue::Ascii(ascii), + CassandraType::Bigint(big_int) => ResultValue::BigInt(big_int), + CassandraType::Blob(blob) => ResultValue::Blob(blob.into_vec()), + CassandraType::Boolean(b) => ResultValue::Boolean(b), + CassandraType::Counter(counter) => ResultValue::Counter(counter), + CassandraType::Decimal(decimal) => { + ResultValue::Decimal(decimal.serialize_to_vec(version)) + } + CassandraType::Double(double) => ResultValue::Double(double.into()), + CassandraType::Float(float) => ResultValue::Float(float.into()), + CassandraType::Int(int) => ResultValue::Int(int), + CassandraType::Timestamp(timestamp) => ResultValue::Timestamp(timestamp), + CassandraType::Uuid(uuid) => ResultValue::Uuid(uuid), + CassandraType::Varchar(varchar) => ResultValue::Varchar(varchar), + CassandraType::Varint(var_int) => ResultValue::VarInt(var_int.to_signed_bytes_be()), + CassandraType::Timeuuid(uuid) => ResultValue::TimeUuid(uuid), + CassandraType::Inet(ip_addr) => ResultValue::Inet(ip_addr.to_string()), + CassandraType::Date(date) => ResultValue::Date(date.serialize_to_vec(version)), + CassandraType::Time(time) => ResultValue::Time(time.serialize_to_vec(version)), + CassandraType::Smallint(small_int) => ResultValue::SmallInt(small_int), + CassandraType::Tinyint(tiny_int) => ResultValue::TinyInt(tiny_int), + CassandraType::Duration(duration) => { + ResultValue::Duration(duration.serialize_to_vec(version)) + } + CassandraType::List(mut list) => ResultValue::List( + list.drain(..) + .map(|element| ResultValue::new_from_cdrs(element, version)) + .collect(), + ), + CassandraType::Map(mut map) => ResultValue::Map( + map.drain(..) + .map(|(k, v)| { + ( + ResultValue::new_from_cdrs(k, version), + ResultValue::new_from_cdrs(v, version), + ) + }) + .collect(), + ), + CassandraType::Set(mut set) => ResultValue::Set( + set.drain(..) + .map(|element| ResultValue::new_from_cdrs(element, version)) + .collect(), + ), + CassandraType::Udt(_) => todo!(), + CassandraType::Tuple(mut tuple) => ResultValue::Tuple( + tuple + .drain(..) + .map(|element| ResultValue::new_from_cdrs(element, version)) + .collect(), + ), + CassandraType::Null => ResultValue::Null, + } } }