diff --git a/shotover-proxy/tests/cassandra_int_tests/mod.rs b/shotover-proxy/tests/cassandra_int_tests/mod.rs index 4681804d1..95644759e 100644 --- a/shotover-proxy/tests/cassandra_int_tests/mod.rs +++ b/shotover-proxy/tests/cassandra_int_tests/mod.rs @@ -6,8 +6,6 @@ use crate::helpers::cassandra::{ }; use crate::helpers::ShotoverManager; #[cfg(feature = "cassandra-cpp-driver-tests")] -use cassandra_cpp::{Error, ErrorKind}; -#[cfg(feature = "cassandra-cpp-driver-tests")] use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType}; use cdrs_tokio::frame::events::{ SchemaChange, SchemaChangeOptions, SchemaChangeTarget, SchemaChangeType, ServerEvent, @@ -472,7 +470,8 @@ async fn peers_rewrite_v4(#[case] driver: CassandraDriver) { #[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] -//#[case::cdrs(CdrsTokio)] // TODO +#[case::cdrs(CdrsTokio)] +#[case::scylla(Scylla)] #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] @@ -491,20 +490,21 @@ async fn peers_rewrite_v3(#[case] driver: CassandraDriver) { // Assert that the error cassandra gives because system.peers_v2 does not exist on cassandra v3 // is passed through shotover unchanged. - let statement = "SELECT data_center, native_port, rack FROM system.peers_v2;"; - let result = connection.execute_expect_err(statement).await; assert_eq!( - result, - ErrorBody { + connection + .execute_fallible("SELECT data_center, native_port, rack FROM system.peers_v2;") + .await, + Err(ErrorBody { ty: ErrorType::Invalid, message: "unconfigured table peers_v2".into() - } + }) ); } #[cfg(feature = "cassandra-cpp-driver-tests")] #[rstest] -//#[case::cdrs(CdrsTokio)] // TODO +#[case::cdrs(CdrsTokio)] +#[case::scylla(Scylla)] #[cfg_attr(feature = "cassandra-cpp-driver-tests", case::datastax(Datastax))] #[tokio::test(flavor = "multi_thread")] #[serial] @@ -545,10 +545,10 @@ async fn request_throttling(#[case] driver: CassandraDriver) { let mut results = join_all(future_list).await; results.retain(|result| match result { Ok(_) => true, - Err(Error( - ErrorKind::CassErrorResult(cassandra_cpp::CassErrorCode::SERVER_OVERLOADED, ..), - _, - )) => false, + Err(ErrorBody { + ty: ErrorType::Overloaded, + .. + }) => false, Err(e) => panic!( "wrong error returned, got {:?}, expected SERVER_OVERLOADED", e @@ -584,7 +584,10 @@ async fn request_throttling(#[case] driver: CassandraDriver) { for i in 0..60 { queries.push(format!("INSERT INTO test_keyspace.my_table (id, lastname, firstname) VALUES ({}, 'text', 'text')", i)); } - let result = connection.execute_batch_expect_err(queries).await; + let result = connection + .execute_batch_fallible(queries) + .await + .unwrap_err(); assert_eq!( result, ErrorBody { diff --git a/shotover-proxy/tests/cassandra_int_tests/prepared_statements_all.rs b/shotover-proxy/tests/cassandra_int_tests/prepared_statements_all.rs index 145414551..ad212eca1 100644 --- a/shotover-proxy/tests/cassandra_int_tests/prepared_statements_all.rs +++ b/shotover-proxy/tests/cassandra_int_tests/prepared_statements_all.rs @@ -55,7 +55,7 @@ async fn insert(connection: &CassandraConnection) { ] ) .await, - Vec::>::new() + Ok(Vec::>::new()) ); } else { let prepared = connection @@ -63,7 +63,7 @@ async fn insert(connection: &CassandraConnection) { .await; assert_eq!( connection.execute_prepared(&prepared, &values()).await, - Vec::>::new() + Ok(Vec::>::new()) ); } } diff --git a/shotover-proxy/tests/cassandra_int_tests/prepared_statements_simple.rs b/shotover-proxy/tests/cassandra_int_tests/prepared_statements_simple.rs index de3528a8e..fa0471fe9 100644 --- a/shotover-proxy/tests/cassandra_int_tests/prepared_statements_simple.rs +++ b/shotover-proxy/tests/cassandra_int_tests/prepared_statements_simple.rs @@ -13,7 +13,7 @@ async fn delete(session: &CassandraConnection) { session .execute_prepared(&prepared, &[ResultValue::Int(1)]) .await, - Vec::>::new() + Ok(Vec::>::new()) ); assert_query_result( @@ -33,21 +33,21 @@ async fn insert(session: &CassandraConnection) { session .execute_prepared(&prepared, &[ResultValue::Int(1)]) .await, - Vec::>::new() + Ok(Vec::>::new()) ); assert_eq!( session .execute_prepared(&prepared, &[ResultValue::Int(2)]) .await, - Vec::>::new() + Ok(Vec::>::new()) ); assert_eq!( session .execute_prepared(&prepared, &[ResultValue::Int(3)]) .await, - Vec::>::new() + Ok(Vec::>::new()) ); } @@ -58,7 +58,8 @@ async fn select(session: &CassandraConnection) { let result_rows = session .execute_prepared(&prepared, &[ResultValue::Int(1)]) - .await; + .await + .unwrap(); assert_rows(result_rows, &[&[ResultValue::Int(1)]]); } @@ -80,13 +81,15 @@ async fn select_cross_connection( assert_rows( connection_before .execute_prepared(&prepared, &[ResultValue::Int(1)]) - .await, + .await + .unwrap(), &[&[ResultValue::Int(1), ResultValue::Int(1)]], ); assert_rows( connection_after .execute_prepared(&prepared, &[ResultValue::Int(1)]) - .await, + .await + .unwrap(), &[&[ResultValue::Int(1), ResultValue::Int(1)]], ); } @@ -106,7 +109,7 @@ async fn use_statement(session: &CassandraConnection) { session .execute_prepared(&prepared, &[ResultValue::Int(358)]) .await, - Vec::>::new() + Ok(Vec::>::new()) ); // observe that the query succeeded despite the keyspace being incorrect at the time. @@ -139,6 +142,6 @@ where if session.is(&[CassandraDriver::Scylla, CassandraDriver::CdrsTokio]) { let cql = "SELECT * FROM system.local WHERE key = 'local'"; let prepared = session.prepare(cql).await; - session.execute_prepared(&prepared, &[]).await; + session.execute_prepared(&prepared, &[]).await.unwrap(); } } diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 7038ce30f..84bf01980 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -5,7 +5,6 @@ use cassandra_cpp::{ PreparedStatement as PreparedStatementCpp, Session as DatastaxSession, Ssl, Statement as StatementCpp, Value, ValueType, }; -#[cfg(feature = "cassandra-cpp-driver-tests")] use cassandra_protocol::frame::message_error::ErrorType; use cassandra_protocol::query::QueryValues; use cassandra_protocol::types::IntoRustByIndex; @@ -26,7 +25,6 @@ use cdrs_tokio::{ query::{BatchQueryBuilder, PreparedQuery as CdrsTokioPreparedQuery}, query_values, transport::TransportTcp, - types::prelude::Error as CdrsError, }; use openssl::ssl::{SslContext, SslMethod}; use ordered_float::OrderedFloat; @@ -35,7 +33,8 @@ use scylla::frame::response::result::CqlValue; use scylla::frame::types::Consistency; use scylla::frame::value::Value as ScyllaValue; use scylla::prepared_statement::PreparedStatement as PreparedStatementScylla; -use scylla::{Session as SessionScylla, SessionBuilder as SessionBuilderScylla}; +use scylla::transport::errors::{DbError, QueryError}; +use scylla::{QueryResult, Session as SessionScylla, SessionBuilder as SessionBuilderScylla}; #[cfg(feature = "cassandra-cpp-driver-tests")] use std::fs::read_to_string; use std::net::IpAddr; @@ -73,18 +72,6 @@ impl PreparedQuery { } } -#[cfg(feature = "cassandra-cpp-driver-tests")] -fn cpp_error_to_cdrs(code: CassErrorCode, message: String) -> ErrorBody { - ErrorBody { - ty: match code { - CassErrorCode::SERVER_INVALID_QUERY => ErrorType::Invalid, - CassErrorCode::SERVER_OVERLOADED => ErrorType::Overloaded, - _ => unimplemented!("{code:?} is not implemented"), - }, - message, - } -} - #[allow(dead_code)] #[derive(Copy, Clone, Eq, PartialEq)] pub enum CassandraDriver { @@ -298,36 +285,25 @@ impl CassandraConnection { #[allow(dead_code)] pub async fn execute(&self, query: &str) -> Vec> { + match self.execute_fallible(query).await { + Ok(result) => result, + Err(err) => panic!("The CQL query: {query}\nFailed with: {err:?}"), + } + } + + #[allow(dead_code)] + pub async fn execute_fallible(&self, query: &str) -> Result>, ErrorBody> { let result = match self { #[cfg(feature = "cassandra-cpp-driver-tests")] Self::Datastax { session, .. } => { let statement = stmt!(query); - match session.execute(&statement).await { - Ok(result) => result - .into_iter() - .map(|x| x.into_iter().map(ResultValue::new_from_cpp).collect()) - .collect(), - Err(Error(err, _)) => panic!("The CQL query: {query}\nFailed with: {err}"), - } + Self::process_datastax_response(session.execute(&statement).await) } Self::CdrsTokio { session, .. } => { - let response = session.query(query).await.unwrap(); - Self::process_cdrs_response(response) + Self::process_cdrs_response(session.query(query).await) } Self::Scylla { session, .. } => { - let rows = session.query(query, ()).await.unwrap().rows; - match rows { - Some(rows) => rows - .into_iter() - .map(|x| { - x.columns - .into_iter() - .map(ResultValue::new_from_scylla) - .collect() - }) - .collect(), - None => vec![], - } + Self::process_scylla_response(session.query(query, ()).await) } }; @@ -340,49 +316,9 @@ impl CassandraConnection { result } - #[allow(dead_code)] - #[cfg(feature = "cassandra-cpp-driver-tests")] - pub async fn execute_fallible(&self, query: &str) -> Result { - match self { - #[cfg(feature = "cassandra-cpp-driver-tests")] - Self::Datastax { session, .. } => { - let statement = stmt!(query); - session.execute(&statement).await - } - Self::CdrsTokio { .. } => todo!(), - Self::Scylla { .. } => todo!(), - } - } - - #[allow(dead_code)] - pub async fn execute_expect_err(&self, query: &str) -> ErrorBody { - match self { - #[cfg(feature = "cassandra-cpp-driver-tests")] - Self::Datastax { session, .. } => { - let statement = stmt!(query); - let error = session.execute(&statement).await.unwrap_err(); - - if let ErrorKind::CassErrorResult(code, msg, ..) = error.0 { - cpp_error_to_cdrs(code, msg) - } else { - panic!("Did not get an error result for {query}"); - } - } - Self::CdrsTokio { session, .. } => { - let error = session.query(query).await.unwrap_err(); - - match error { - CdrsError::Server { body, .. } => body, - _ => todo!(), - } - } - Self::Scylla { .. } => todo!(), - } - } - #[allow(dead_code)] pub async fn execute_expect_err_contains(&self, query: &str, contains: &str) { - let error_msg = self.execute_expect_err(query).await.message; + let error_msg = self.execute_fallible(query).await.unwrap_err().message; assert!( error_msg.contains(contains), "Expected the error to contain '{contains}' but it did not and was instead '{error_msg}'" @@ -478,7 +414,7 @@ impl CassandraConnection { &self, prepared_query: &PreparedQuery, values: &[ResultValue], - ) -> Vec> { + ) -> Result>, ErrorBody> { match self { #[cfg(feature = "cassandra-cpp-driver-tests")] Self::Datastax { session, .. } => { @@ -488,15 +424,7 @@ impl CassandraConnection { } statement.set_tracing(true).unwrap(); - match session.execute(&statement).await { - Ok(result) => result - .into_iter() - .map(|x| x.into_iter().map(ResultValue::new_from_cpp).collect()) - .collect(), - Err(Error(err, _)) => { - panic!("The statement: {statement:?}\nFailed with: {err}") - } - } + Self::process_datastax_response(session.execute(&statement).await) } Self::CdrsTokio { session, .. } => { let statement = prepared_query.as_cdrs(); @@ -515,28 +443,13 @@ impl CassandraConnection { beta_protocol: false, }; - let response = session.exec_with_params(statement, ¶ms).await.unwrap(); - - Self::process_cdrs_response(response) + Self::process_cdrs_response(session.exec_with_params(statement, ¶ms).await) } Self::Scylla { session, .. } => { let statement = prepared_query.as_scylla(); let values = Self::build_values_scylla(values); - let response = session.execute(statement, values).await.unwrap(); - - match response.rows { - Some(rows) => rows - .into_iter() - .map(|row| { - row.columns - .into_iter() - .map(ResultValue::new_from_scylla) - .collect() - }) - .collect(), - None => vec![], - } + Self::process_scylla_response(session.execute(statement, values).await) } } } @@ -620,35 +533,28 @@ impl CassandraConnection { } #[allow(dead_code)] - pub async fn execute_batch(&self, queries: Vec) { + pub async fn execute_batch_fallible( + &self, + queries: Vec, + ) -> Result>, ErrorBody> { match self { #[cfg(feature = "cassandra-cpp-driver-tests")] Self::Datastax { session, .. } => { let mut batch = Batch::new(BatchType::LOGGED); - for query in queries { batch.add_statement(&stmt!(query.as_str())).unwrap(); } - match session.execute_batch(&batch).await { - Ok(result) => assert_eq!( - result.into_iter().count(), - 0, - "Batches should never return results", - ), - Err(Error(err, _)) => panic!("The batch: {batch:?}\nFailed with: {err}"), - } + Self::process_datastax_response(session.execute_batch(&batch).await) } Self::CdrsTokio { session, .. } => { let mut builder = BatchQueryBuilder::new(); - for query in queries { builder = builder.add_query(query, query_values!()); } - let batch = builder.build().unwrap(); - session.batch(batch).await.unwrap(); + Self::process_cdrs_response(session.batch(batch).await) } Self::Scylla { session, .. } => { let mut values = vec![]; @@ -658,66 +564,116 @@ impl CassandraConnection { values.push(()); } - session.batch(&batch, values).await.unwrap(); + Self::process_scylla_response(session.batch(&batch, values).await) } } } - #[allow(dead_code, unused_variables)] - pub async fn execute_batch_expect_err(&self, queries: Vec) -> ErrorBody { - match self { - #[cfg(feature = "cassandra-cpp-driver-tests")] - Self::Datastax { session, .. } => { - let mut batch = Batch::new(BatchType::LOGGED); - for query in queries { - batch.add_statement(&stmt!(query.as_str())).unwrap(); - } - let error = session.execute_batch(&batch).await.unwrap_err(); - if let ErrorKind::CassErrorResult(code, msg, ..) = error.0 { - cpp_error_to_cdrs(code, msg) - } else { - panic!("Did not get an error result for {batch:?}"); - } - } - Self::CdrsTokio { .. } => todo!(), - Self::Scylla { .. } => todo!(), + #[allow(dead_code)] + pub async fn execute_batch(&self, queries: Vec) { + let result = self.execute_batch_fallible(queries).await.unwrap(); + assert_eq!(result.len(), 0, "Batches should never return results"); + } + + // allow reason: micro performance doesnt matter in a test and boxing ErrorBody makes matches unergonomic + #[allow(clippy::result_large_err)] + #[cfg(feature = "cassandra-cpp-driver-tests")] + fn process_datastax_response( + response: Result, + ) -> Result>, ErrorBody> { + match response { + Ok(result) => Ok(result + .into_iter() + .map(|x| x.into_iter().map(ResultValue::new_from_cpp).collect()) + .collect()), + Err(Error(ErrorKind::CassErrorResult(code, message, ..), _)) => Err(ErrorBody { + ty: match code { + CassErrorCode::SERVER_OVERLOADED => ErrorType::Overloaded, + CassErrorCode::SERVER_SERVER_ERROR => ErrorType::Server, + CassErrorCode::SERVER_INVALID_QUERY => ErrorType::Invalid, + code => todo!("Implement handling for cassandra_cpp err: {code:?}"), + }, + message, + }), + Err(err) => panic!("Unexpected cassandra_cpp error: {err}"), } } - fn process_cdrs_response(response: Envelope) -> Vec> { - let version = response.version; - let response_body = response.response_body().unwrap(); + // allow reason: micro performance doesnt matter in a test and boxing ErrorBody makes matches unergonomic + #[allow(clippy::result_large_err)] + fn process_scylla_response( + response: Result, + ) -> Result>, ErrorBody> { + match response { + Ok(value) => Ok(match value.rows { + Some(rows) => rows + .into_iter() + .map(|x| { + x.columns + .into_iter() + .map(ResultValue::new_from_scylla) + .collect() + }) + .collect(), + None => vec![], + }), + Err(QueryError::DbError(code, message)) => Err(ErrorBody { + ty: match code { + DbError::Overloaded => ErrorType::Overloaded, + DbError::ServerError => ErrorType::Server, + DbError::Invalid => ErrorType::Invalid, + code => todo!("Implement handling for cassandra_cpp err: {code:?}"), + }, + message, + }), + Err(err) => panic!("Unexpected scylla error: {err:?}"), + } + } - match response_body { - ResponseBody::Error(err) => { - panic!("CQL query Failed with: {err:?}") - } - ResponseBody::Result(res_result_body) => match res_result_body { - ResResultBody::Rows(rows) => { - let mut result_values = vec![]; - - for row in &rows.rows_content { - let mut row_result_values = vec![]; - for (i, col_spec) in rows.metadata.col_specs.iter().enumerate() { - let wrapper = wrapper_fn(&col_spec.col_type.id); - let value = ResultValue::new_from_cdrs( - wrapper(&row[i], &col_spec.col_type, version).unwrap(), - version, - ); - - row_result_values.push(value); - } - result_values.push(row_result_values); + // allow reason: micro performance doesnt matter in a test and boxing ErrorBody makes matches unergonomic + #[allow(clippy::result_large_err)] + fn process_cdrs_response( + response: Result, + ) -> Result>, ErrorBody> { + match response { + Ok(response) => { + let version = response.version; + let response_body = response.response_body().unwrap(); + + Ok(match response_body { + ResponseBody::Error(err) => { + panic!("CQL query Failed with: {err:?}") } - - result_values - } - ResResultBody::Prepared(_) => todo!(), - ResResultBody::SchemaChange(_) => vec![], - ResResultBody::SetKeyspace(_) => vec![], - ResResultBody::Void => vec![], - }, - _ => todo!(), + ResponseBody::Result(res_result_body) => match res_result_body { + ResResultBody::Rows(rows) => { + let mut result_values = vec![]; + + for row in &rows.rows_content { + let mut row_result_values = vec![]; + for (i, col_spec) in rows.metadata.col_specs.iter().enumerate() { + let wrapper = wrapper_fn(&col_spec.col_type.id); + let value = ResultValue::new_from_cdrs( + wrapper(&row[i], &col_spec.col_type, version).unwrap(), + version, + ); + + row_result_values.push(value); + } + result_values.push(row_result_values); + } + + result_values + } + ResResultBody::Prepared(_) => todo!(), + ResResultBody::SchemaChange(_) => vec![], + ResResultBody::SetKeyspace(_) => vec![], + ResResultBody::Void => vec![], + }, + _ => todo!(), + }) + } + Err(cassandra_protocol::Error::Server { body, .. }) => Err(body), + Err(err) => panic!("Unexpected cdrs-tokio error: {err:?}"), } } }